[SPARK-49418][CONNECT][SQL] Shared Session Thread Locals #48374
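Summary: this patch consolidates the active (thread-local) and default (global) `SparkSession` bookkeeping, previously duplicated between the Classic and Connect companions, into a new `BaseSparkSessionCompanion` in `sql/api`. The Connect companion keeps only thin `@inheritdoc` overrides, session liveness is abstracted behind a new `private[sql] isUsable` check, and the resulting binary-compatibility breaks are whitelisted in `MimaExcludes`.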

Status: Closed · wants to merge 5 commits · the diff below shows changes from 1 commit
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
SparkSession.scala (Spark Connect Scala client)
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
+import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -494,6 +494,8 @@ class SparkSession private[sql] (
}
}

+  override private[sql] def isUsable: Boolean = client.isSessionValid

implicit class RichColumn(c: Column) {
def expr: proto.Expression = toExpr(c)
def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
@@ -502,7 +504,9 @@

// The minimal builder needed to create a spark session.
// TODO: implement all methods mentioned in the scaladoc of [[SparkSession]]
-object SparkSession extends api.SparkSessionCompanion with Logging {
+object SparkSession extends api.BaseSparkSessionCompanion with Logging {
+  override private[sql] type Session = SparkSession

private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
@@ -518,29 +522,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
override def load(c: Configuration): SparkSession = create(c)
})

-  /** The active SparkSession for the current thread. */
-  private val activeThreadSession = new InheritableThreadLocal[SparkSession]
-
-  /** Reference to the root SparkSession. */
-  private val defaultSession = new AtomicReference[SparkSession]
-
-  /**
-   * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
-   * they are not set yet or the associated [[SparkConnectClient]] is unusable.
-   */
-  private def setDefaultAndActiveSession(session: SparkSession): Unit = {
-    val currentDefault = defaultSession.getAcquire
-    if (currentDefault == null || !currentDefault.client.isSessionValid) {
-      // Update `defaultSession` if it is null or the contained session is not valid. There is a
-      // chance that the following `compareAndSet` fails if a new default session has just been set,
-      // but that does not matter since that event has happened after this method was invoked.
-      defaultSession.compareAndSet(currentDefault, session)
-    }
-    if (getActiveSession.isEmpty) {
-      setActiveSession(session)
-    }
-  }

/**
* Create a new Spark Connect server to connect locally.
*/
@@ -593,17 +574,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
new SparkSession(configuration.toSparkConnectClient, planIdGenerator)
}

-  /**
-   * Hook called when a session is closed.
-   */
-  private[sql] def onSessionClose(session: SparkSession): Unit = {
-    sessions.invalidate(session.client.configuration)
-    defaultSession.compareAndSet(session, null)
-    if (getActiveSession.contains(session)) {
-      clearActiveSession()
-    }
-  }

/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
@@ -750,71 +720,12 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
}
}

-  /**
-   * Returns the default SparkSession. If the previously set default SparkSession becomes
-   * unusable, returns None.
-   *
-   * @since 3.5.0
-   */
-  def getDefaultSession: Option[SparkSession] =
-    Option(defaultSession.get()).filter(_.client.isSessionValid)
-
-  /**
-   * Sets the default SparkSession.
-   *
-   * @since 3.5.0
-   */
-  def setDefaultSession(session: SparkSession): Unit = {
-    defaultSession.set(session)
-  }
-
-  /**
-   * Clears the default SparkSession.
-   *
-   * @since 3.5.0
-   */
-  def clearDefaultSession(): Unit = {
-    defaultSession.set(null)
-  }
-
-  /**
-   * Returns the active SparkSession for the current thread. If the previously set active
-   * SparkSession becomes unusable, returns None.
-   *
-   * @since 3.5.0
-   */
-  def getActiveSession: Option[SparkSession] =
-    Option(activeThreadSession.get()).filter(_.client.isSessionValid)
-
-  /**
-   * Changes the SparkSession that will be returned in this thread and its children when
-   * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
-   * an isolated SparkSession.
-   *
-   * @since 3.5.0
-   */
-  def setActiveSession(session: SparkSession): Unit = {
-    activeThreadSession.set(session)
-  }
+  /** @inheritdoc */
+  override def getActiveSession: Option[SparkSession] = super.getActiveSession
+
-  /**
-   * Clears the active SparkSession for current thread.
-   *
-   * @since 3.5.0
-   */
-  def clearActiveSession(): Unit = {
-    activeThreadSession.remove()
-  }
+  /** @inheritdoc */
+  override def getDefaultSession: Option[SparkSession] = super.getDefaultSession
+
-  /**
-   * Returns the currently active SparkSession, otherwise the default one. If there is no default
-   * SparkSession, throws an exception.
-   *
-   * @since 3.5.0
-   */
-  def active: SparkSession = {
-    getActiveSession
-      .orElse(getDefaultSession)
-      .getOrElse(throw new IllegalStateException("No active or default Spark session found"))
-  }
+  /** @inheritdoc */
+  override def active: SparkSession = super.active
}
project/MimaExcludes.scala (6 additions, 0 deletions)
@@ -189,6 +189,12 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.javalang.typed"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed$"),

+    // SPARK-49418: Consolidate thread local handling in sql/api
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setActiveSession"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setDefaultSession"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearDefaultSession"),
) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++
loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++
loggingExcludes("org.apache.spark.sql.SparkSession#Builder")
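For readers unfamiliar with MiMa: each entry above is a `ProblemFilter` that tells the binary-compatibility checker to ignore one intentional break, keyed by problem type and fully qualified member name. A minimal sketch of how such filters are wired up with sbt-mima-plugin in an ordinary build (the project and member names here are illustrative, not from this PR):

```scala
// build.sbt of a hypothetical project using sbt-mima-plugin
import com.typesafe.tools.mima.core._

// Baseline artifact to compare the current classfiles against.
mimaPreviousArtifacts := Set("org.example" %% "example-lib" % "1.0.0")

// Acknowledge an intentional break so `mimaReportBinaryIssues` still passes.
mimaBinaryIssueFilters ++= Seq(
  ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.example.Api.someMethod"))
```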
sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala (165 additions, 1 deletion)
@@ -24,7 +24,9 @@ import _root_.java.io.Closeable
import _root_.java.lang
import _root_.java.net.URI
import _root_.java.util
+import _root_.java.util.concurrent.atomic.AtomicReference

+import org.apache.spark.SparkException
import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
import org.apache.spark.sql.{Encoder, Row, RuntimeConfig}
import org.apache.spark.sql.types.StructType
@@ -561,9 +563,19 @@ abstract class SparkSession extends Serializable with Closeable {
* @since 2.0.0
*/
def stop(): Unit = close()

+  /**
+   * Check to see if the session is still usable.
+   *
+   * In Classic this means that the underlying `SparkContext` is still active; in Connect it
+   * means that the connection to the server is still open.
+   */
+  private[sql] def isUsable: Boolean
}

object SparkSession extends SparkSessionCompanion {
+  type Session = SparkSession

private[this] val companion: SparkSessionCompanion = {
val cls = SparkClassUtils.classForName("org.apache.spark.sql.SparkSession")
val mirror = scala.reflect.runtime.currentMirror
@@ -573,12 +585,97 @@

/** @inheritdoc */
override def builder(): SparkSessionBuilder = companion.builder()

+  /** @inheritdoc */
+  override def setActiveSession(session: SparkSession): Unit =
+    companion.setActiveSession(session.asInstanceOf[companion.Session])
+
+  /** @inheritdoc */
+  override def clearActiveSession(): Unit = companion.clearActiveSession()
+
+  /** @inheritdoc */
+  override def setDefaultSession(session: SparkSession): Unit =
+    companion.setDefaultSession(session.asInstanceOf[companion.Session])
+
+  /** @inheritdoc */
+  override def clearDefaultSession(): Unit = companion.clearDefaultSession()
+
+  /** @inheritdoc */
+  override def getActiveSession: Option[SparkSession] = companion.getActiveSession
+
+  /** @inheritdoc */
+  override def getDefaultSession: Option[SparkSession] = companion.getDefaultSession
}
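The `companion` value above (its initializer is truncated in this view) is resolved reflectively, so that `sql/api` has no compile-time dependency on a concrete implementation. A rough sketch of the underlying module-mirror pattern; the helper below is ours for illustration, not the PR's code, and it requires the scala-reflect library:

```scala
import scala.reflect.runtime.{currentMirror => mirror}

// Illustrative helper: look up the singleton instance of a Scala `object`
// by its fully qualified name at runtime.
def moduleInstance(fullName: String): Any = {
  val module = mirror.staticModule(fullName) // symbol of the `object`
  mirror.reflectModule(module).instance      // its singleton instance
}

// e.g. moduleInstance("org.apache.spark.sql.SparkSession") would yield the
// Classic or Connect companion, whichever is on the classpath.
```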

/**
- * Companion of a [[SparkSession]].
+ * Interface for a [[SparkSession]] Companion. The companion is responsible for building the
+ * session, and managing the active (thread-local) and default (global) SparkSessions.
*/
private[sql] abstract class SparkSessionCompanion {
private[sql] type Session <: SparkSession

/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* a SparkSession with an isolated session, instead of the global (first created) context.
*
* @since 2.0.0
*/
def setActiveSession(session: Session): Unit

/**
* Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will
* return the first created context instead of a thread-local override.
*
* @since 2.0.0
*/
def clearActiveSession(): Unit

/**
* Sets the default SparkSession that is returned by the builder.
*
* @since 2.0.0
*/
def setDefaultSession(session: Session): Unit

/**
* Clears the default SparkSession that is returned by the builder.
*
* @since 2.0.0
*/
def clearDefaultSession(): Unit

/**
* Returns the active SparkSession for the current thread, returned by the builder.
*
* @note
* Return None, when calling this function on executors
*
* @since 2.2.0
*/
def getActiveSession: Option[Session]

/**
* Returns the default SparkSession that is returned by the builder.
*
* @note
* Return None, when calling this function on executors
*
* @since 2.2.0
*/
def getDefaultSession: Option[Session]

/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
*
* @since 2.4.0
*/
def active: Session = {
getActiveSession.getOrElse(
getDefaultSession.getOrElse(
throw SparkException.internalError("No active or default Spark session found")))
}

/**
* Creates a [[SparkSessionBuilder]] for constructing a [[SparkSession]].
@@ -588,6 +685,73 @@ private[sql] abstract class SparkSessionCompanion {
def builder(): SparkSessionBuilder
}
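To make the thread-local semantics concrete, a small usage sketch against a concrete companion (assuming, as the Connect implementation in this PR does, that `getOrCreate()` registers the new session via `setDefaultAndActiveSession`):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate() // connection/master config elided

// The first session becomes both the global default and this thread's active session.
assert(SparkSession.getDefaultSession.contains(spark))
assert(SparkSession.getActiveSession.contains(spark))

// Child threads inherit the active session, since it is an InheritableThreadLocal.
val t = new Thread(() => assert(SparkSession.getActiveSession.contains(spark)))
t.start(); t.join()

// Clearing the active session only touches the thread-local slot; the default
// survives, so `active` falls back to it.
SparkSession.clearActiveSession()
assert(SparkSession.getActiveSession.isEmpty)
assert(SparkSession.active eq spark)
```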

+/**
+ * Abstract class for [[SparkSession]] companions. This implements active and default session
+ * management.
+ */
+private[sql] abstract class BaseSparkSessionCompanion extends SparkSessionCompanion {
+
+  /** The active SparkSession for the current thread. */
+  private val activeThreadSession = new InheritableThreadLocal[Session]
+
+  /** Reference to the root SparkSession. */
+  private val defaultSession = new AtomicReference[Session]
+
+  /** @inheritdoc */
+  def setActiveSession(session: Session): Unit = {
+    activeThreadSession.set(session)
+  }
+
+  /** @inheritdoc */
+  def clearActiveSession(): Unit = {
+    activeThreadSession.remove()
+  }
+
+  /** @inheritdoc */
+  def setDefaultSession(session: Session): Unit = {
+    defaultSession.set(session)
+  }
+
+  /** @inheritdoc */
+  def clearDefaultSession(): Unit = {
+    defaultSession.set(null.asInstanceOf[Session])
+  }
+
+  /** @inheritdoc */
+  def getActiveSession: Option[Session] = Option(activeThreadSession.get)
+
+  /** @inheritdoc */
+  def getDefaultSession: Option[Session] = Option(defaultSession.get)
+
+  /**
+   * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
+   * they are not set yet or they are not usable.
+   */
+  protected def setDefaultAndActiveSession(session: Session): Unit = {
+    val currentDefault = defaultSession.getAcquire
+    if (currentDefault == null || !currentDefault.isUsable) {
+      // Update `defaultSession` if it is null or the contained session is not usable. There is a
+      // chance that the following `compareAndSet` fails if a new default session has just been
+      // set, but that does not matter since that event happened after this method was invoked.
+      defaultSession.compareAndSet(currentDefault, session)
+    }
+    val active = getActiveSession
+    if (active.isEmpty || !active.get.isUsable) {
+      setActiveSession(session)
+    }
+  }
+
+  /**
+   * When a session is closed, remove it from the active and default references.
+   */
+  private[sql] def onSessionClose(session: Session): Unit = {
+    defaultSession.compareAndSet(session, null.asInstanceOf[Session])
+    if (getActiveSession.contains(session)) {
+      clearActiveSession()
+    }
+  }
+}
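The comment in `setDefaultAndActiveSession` describes the standard read-then-compare-and-set idiom: a failed CAS means a newer value won the race, which is acceptable here. A standalone illustration of the same pattern, with the element type simplified to `String`:

```scala
import java.util.concurrent.atomic.AtomicReference

val ref = new AtomicReference[String](null)

def setIfAbsentOrStale(candidate: String)(stale: String => Boolean): Unit = {
  val current = ref.getAcquire
  if (current == null || stale(current)) {
    // If another thread installed a fresh value between getAcquire and here,
    // this CAS fails and the newer value is kept, which is the desired outcome.
    ref.compareAndSet(current, candidate)
  }
}
```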

/**
* Builder for [[SparkSession]].
*/