[SPARK-49418][CONNECT][SQL] Shared Session Thread Locals #48374

Closed
wants to merge 5 commits
Changes from all commits
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 import java.net.URI
 import java.nio.file.{Files, Paths}
 import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
+import java.util.concurrent.atomic.AtomicLong
 
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
@@ -525,6 +525,8 @@ class SparkSession private[sql] (
   }
 }
 
+  override private[sql] def isUsable: Boolean = client.isSessionValid
+
 implicit class RichColumn(c: Column) {
   def expr: proto.Expression = toExpr(c)
   def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
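A note on the new `isUsable` hook: it is what the shared session companion uses to decide whether a cached default/active session should still be handed out, and here it simply delegates to the gRPC client's validity check. A behavioral sketch of the intent — assuming a reachable local Connect server and that `stop()` invalidates the client:

import org.apache.spark.sql.SparkSession

val session = SparkSession.builder().remote("sc://localhost").getOrCreate()
assert(SparkSession.getActiveSession.contains(session))

session.stop() // closes the underlying client, so isUsable turns false

// Stale sessions are filtered out rather than returned.
assert(SparkSession.getActiveSession.isEmpty)
assert(SparkSession.getDefaultSession.isEmpty)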
@@ -533,7 +535,9 @@

 // The minimal builder needed to create a spark session.
 // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]]
-object SparkSession extends api.SparkSessionCompanion with Logging {
+object SparkSession extends api.BaseSparkSessionCompanion with Logging {
+  override private[sql] type Session = SparkSession
+
 private val MAX_CACHED_SESSIONS = 100
 private val planIdGenerator = new AtomicLong
 private var server: Option[Process] = None
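For orientation, here is a minimal, self-contained sketch of what the shared companion in sql/api presumably provides. `BaseSparkSessionCompanion`, the `Session` type member, and the `isUsable` hook are taken from this diff; the internals mirror the thread-local code removed below, and the validity check is modeled as a companion-side `usable` method purely to keep the sketch compilable:

import java.util.concurrent.atomic.AtomicReference

// Hypothetical reconstruction, not the actual sql/api source.
abstract class BaseSparkSessionCompanion {
  type Session <: AnyRef

  // Implementations report whether a session is still alive; the Connect
  // companion would delegate to the isUsable hook shown above.
  protected def usable(session: Session): Boolean

  // Active session for the current thread; child threads inherit it.
  private val activeThreadSession = new InheritableThreadLocal[Session]

  // Global default session shared by all threads.
  private val defaultSession = new AtomicReference[Session]

  def getActiveSession: Option[Session] =
    Option(activeThreadSession.get()).filter(usable)

  def setActiveSession(session: Session): Unit = activeThreadSession.set(session)

  def clearActiveSession(): Unit = activeThreadSession.remove()

  def getDefaultSession: Option[Session] =
    Option(defaultSession.get()).filter(usable)

  def setDefaultSession(session: Session): Unit = defaultSession.set(session)

  def clearDefaultSession(): Unit = defaultSession.set(null.asInstanceOf[Session])

  // The real implementation now raises a SparkException here
  // (see the test change further down in this PR).
  def active: Session =
    getActiveSession
      .orElse(getDefaultSession)
      .getOrElse(throw new IllegalStateException("No active or default Spark session found"))
}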
@@ -549,29 +553,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
 override def load(c: Configuration): SparkSession = create(c)
 })
 
-/** The active SparkSession for the current thread. */
-private val activeThreadSession = new InheritableThreadLocal[SparkSession]
-
-/** Reference to the root SparkSession. */
-private val defaultSession = new AtomicReference[SparkSession]
-
-/**
- * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
- * they are not set yet or the associated [[SparkConnectClient]] is unusable.
- */
-private def setDefaultAndActiveSession(session: SparkSession): Unit = {
-  val currentDefault = defaultSession.getAcquire
-  if (currentDefault == null || !currentDefault.client.isSessionValid) {
-    // Update `defaultSession` if it is null or the contained session is not valid. There is a
-    // chance that the following `compareAndSet` fails if a new default session has just been set,
-    // but that does not matter since that event has happened after this method was invoked.
-    defaultSession.compareAndSet(currentDefault, session)
-  }
-  if (getActiveSession.isEmpty) {
-    setActiveSession(session)
-  }
-}
-
 /**
  * Create a new Spark Connect server to connect locally.
  */
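The helpers removed above move behind the shared companion, but the observable semantics should be unchanged: the first session created becomes both the global default and the thread-local active session, and `InheritableThreadLocal` propagates the active session to child threads. A usage sketch — the connection string is a placeholder and a running local Connect server is assumed:

import org.apache.spark.sql.SparkSession

val root = SparkSession.builder().remote("sc://localhost").getOrCreate()

// The first session is installed as both default and active.
assert(SparkSession.getDefaultSession.contains(root))
assert(SparkSession.getActiveSession.contains(root))

// Threads started afterwards inherit the active session.
val worker = new Thread(() => {
  assert(SparkSession.getActiveSession.contains(root))
})
worker.start()
worker.join()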
@@ -624,17 +605,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
 new SparkSession(configuration.toSparkConnectClient, planIdGenerator)
 }
 
-/**
- * Hook called when a session is closed.
- */
-private[sql] def onSessionClose(session: SparkSession): Unit = {
-  sessions.invalidate(session.client.configuration)
-  defaultSession.compareAndSet(session, null)
-  if (getActiveSession.contains(session)) {
-    clearActiveSession()
-  }
-}
-
 /**
  * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
  *
@@ -781,71 +751,12 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
 }
 }
 
-/**
- * Returns the default SparkSession. If the previously set default SparkSession becomes
- * unusable, returns None.
- *
- * @since 3.5.0
- */
-def getDefaultSession: Option[SparkSession] =
-  Option(defaultSession.get()).filter(_.client.isSessionValid)
-
-/**
- * Sets the default SparkSession.
- *
- * @since 3.5.0
- */
-def setDefaultSession(session: SparkSession): Unit = {
-  defaultSession.set(session)
-}
-
-/**
- * Clears the default SparkSession.
- *
- * @since 3.5.0
- */
-def clearDefaultSession(): Unit = {
-  defaultSession.set(null)
-}
-
-/**
- * Returns the active SparkSession for the current thread. If the previously set active
- * SparkSession becomes unusable, returns None.
- *
- * @since 3.5.0
- */
-def getActiveSession: Option[SparkSession] =
-  Option(activeThreadSession.get()).filter(_.client.isSessionValid)
-
-/**
- * Changes the SparkSession that will be returned in this thread and its children when
- * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
- * an isolated SparkSession.
- *
- * @since 3.5.0
- */
-def setActiveSession(session: SparkSession): Unit = {
-  activeThreadSession.set(session)
-}
+/** @inheritdoc */
+override def getActiveSession: Option[SparkSession] = super.getActiveSession
 
-/**
- * Clears the active SparkSession for current thread.
- *
- * @since 3.5.0
- */
-def clearActiveSession(): Unit = {
-  activeThreadSession.remove()
-}
+/** @inheritdoc */
+override def getDefaultSession: Option[SparkSession] = super.getDefaultSession
 
-/**
- * Returns the currently active SparkSession, otherwise the default one. If there is no default
- * SparkSession, throws an exception.
- *
- * @since 3.5.0
- */
-def active: SparkSession = {
-  getActiveSession
-    .orElse(getDefaultSession)
-    .getOrElse(throw new IllegalStateException("No active or default Spark session found"))
-}
+/** @inheritdoc */
+override def active: SparkSession = super.active
 }
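The `@inheritdoc` overrides above look redundant, but with the `Session` type member fixed to the Connect `SparkSession` they let the companion object keep concrete, documented entry points whose signatures name the concrete session type, while all logic lives in the parent. A toy illustration of the narrowing pattern (not Spark's actual types):

abstract class Companion {
  type Session
  protected def current: Session
  def active: Session = current
}

final class ConnectSession

object ConnectCompanion extends Companion {
  override type Session = ConnectSession
  override protected def current: ConnectSession = new ConnectSession
  // Delegating override: same behavior, but a stable, documented method
  // on the object whose signature mentions the concrete session type.
  override def active: ConnectSession = super.active
}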
@@ -22,6 +22,7 @@ import scala.util.control.NonFatal

 import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.test.ConnectFunSuite
 import org.apache.spark.util.SparkSerDeUtils

@@ -113,7 +114,7 @@ class SparkSessionSuite extends ConnectFunSuite {
 SparkSession.clearActiveSession()
 assert(SparkSession.getDefaultSession.isEmpty)
 assert(SparkSession.getActiveSession.isEmpty)
-intercept[IllegalStateException](SparkSession.active)
+intercept[SparkException](SparkSession.active)
Contributor Author: This is technically a behavior change.


 // Create a session
 val session1 = SparkSession.builder().remote(connectionString1).getOrCreate()
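On the behavior change flagged in the review comment above: with the consolidated companion, `SparkSession.active` now fails with a `SparkException` instead of a bare `IllegalStateException` when no default or active session exists, so downstream code that matched on the old exception type needs updating. A caller-side sketch:

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

try {
  SparkSession.active
} catch {
  // After this PR the missing-session failure surfaces as SparkException;
  // previously it was an IllegalStateException.
  case e: SparkException =>
    println(s"No active or default session: ${e.getMessage}")
}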
@@ -227,6 +227,8 @@ object CheckConnectJvmClientCompatibility {
   "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"),
 ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"),
 ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"),
+ProblemFilters.exclude[DirectMissingMethodProblem](
+  "org.apache.spark.sql.SparkSession.canUseSession"),
 
 // SparkSession#implicits
 ProblemFilters.exclude[DirectMissingMethodProblem](
6 changes: 6 additions & 0 deletions project/MimaExcludes.scala
@@ -189,6 +189,12 @@ object MimaExcludes {
 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.javalang.typed"),
 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed"),
 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed$"),
+
+// SPARK-49418: Consolidate thread local handling in sql/api
+ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setActiveSession"),
+ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setDefaultSession"),
+ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"),
+ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearDefaultSession"),
 ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++
 loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++
 loggingExcludes("org.apache.spark.sql.SparkSession#Builder")
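For readers unfamiliar with MiMa: each entry suppresses one reported binary incompatibility. The `IncompatibleMethTypeProblem` filters presumably cover `setActiveSession`/`setDefaultSession` now taking the abstract `Session` type member (which erases to `Object` in bytecode) instead of the concrete class, and the `DirectAbstractMethodProblem` filters cover `clearActiveSession`/`clearDefaultSession` gaining concrete implementations in the shared companion. A toy illustration of both incompatibility kinds, with assumed shapes:

// Before: the parameter erased to the concrete class in bytecode.
class SessionV1
class CompanionV1 {
  def setActiveSession(session: SessionV1): Unit = ()
}

// After: an abstract type member erases to Object, so the bytecode
// signature changes -- the IncompatibleMethTypeProblem MiMa reports.
abstract class CompanionV2 {
  type Session <: AnyRef
  def setActiveSession(session: Session): Unit = ()
}

// DirectAbstractMethodProblem: a method that used to be abstract in the
// companion trait now ships with a concrete implementation.
trait CompanionTraitBefore { def clearActiveSession(): Unit }
trait CompanionTraitAfter { def clearActiveSession(): Unit = () }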