[SPARK-45762][CORE] Support shuffle managers defined in user jars by changing startup order

### What changes were proposed in this pull request?
As reported in https://issues.apache.org/jira/browse/SPARK-45762, `ShuffleManager` instances defined in a user jar cannot be used in all cases unless they are also specified in `extraClassPath`. We would like to avoid requiring this extra configuration when the instance is already included in a jar passed via `--jars`.

Proposed changes:

Refactor the code so that the `ShuffleManager` is initialized later, after jars have been localized. This is especially necessary in the executor, where the initialization must be deferred until after the `replClassLoader` is updated with the jars passed via `--jars`.

Before this change, the `ShuffleManager` is instantiated at `SparkEnv` creation. Instantiating it this early doesn't work: user jars have not been localized in all scenarios, so loading a `ShuffleManager` defined in `--jars` fails. We propose moving the instantiation to `SparkContext` on the driver and to `Executor` on the executors.
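
As a condensed sketch, the new driver-side order looks like this (simplified from the `SparkContext` diff below; the names match the actual code, but the surrounding setup is elided):

```scala
// Inside SparkContext's constructor, after user jars have been registered:
_env = createSparkEnv(_conf, isLocal, listenerBus)  // SparkEnv no longer creates the ShuffleManager
_plugins = PluginContainer(this, _resources.asJava) // plugins come up first
_env.initializeShuffleManager()                     // now safe: user jars are visible
val (sched, ts) = SparkContext.createTaskScheduler(this, master)
```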

### Why are the changes needed?
This is not a new API but a change in startup order. The changes are needed to improve the user experience by reducing the extra configuration otherwise required, which depends on how a Spark application is launched.

### Does this PR introduce _any_ user-facing change?
Yes, but it's backwards compatible. Users no longer need to specify a `ShuffleManager` jar in `extraClassPath`, but they may still do so if they wish.
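
For illustration, a hedged sketch of the simplified launch; the class and jar names here are hypothetical:

```scala
// Hypothetical manager com.example.MyShuffleManager packaged in my-shuffle.jar.
// Before this change, the jar also had to be on spark.{driver,executor}.extraClassPath:
//
//   spark-submit --jars my-shuffle.jar \
//     --conf spark.shuffle.manager=com.example.MyShuffleManager \
//     app.jar
//
// Or, when building the configuration programmatically:
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.manager", "com.example.MyShuffleManager")
  .set("spark.jars", "my-shuffle.jar")
```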

This change is not binary compatible with Spark 3.5.0 (see the MiMa comments below). I have added a rule to MimaExcludes to handle it (commit 970bff4).

### How was this patch tested?
Added a unit test (using local-cluster mode) showing that a test `ShuffleManager` is available when its jar is passed via `--jars`, but not otherwise.

Also tested manually in standalone mode, local-cluster mode, YARN client and cluster modes, and on Kubernetes.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #43627 from abellina/shuffle_manager_initialization_order.

Authored-by: Alessandro Bellina <abellina@nvidia.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
abellina authored and Mridul Muralidharan committed Nov 17, 2023
1 parent 8147620 commit 7c146c9
Showing 9 changed files with 160 additions and 28 deletions.
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -577,6 +577,7 @@ class SparkContext(config: SparkConf) extends Logging {

// Initialize any plugins before the task scheduler is initialized.
_plugins = PluginContainer(this, _resources.asJava)
_env.initializeShuffleManager()

// Create and start the scheduler
val (sched, ts) = SparkContext.createTaskScheduler(this, master)
38 changes: 22 additions & 16 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,13 +18,13 @@
package org.apache.spark

import java.io.File
import java.util.Locale

import scala.collection.concurrent
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Properties

import com.google.common.base.Preconditions
import com.google.common.cache.CacheBuilder
import org.apache.hadoop.conf.Configuration

@@ -63,7 +63,6 @@ class SparkEnv (
val closureSerializer: Serializer,
val serializerManager: SerializerManager,
val mapOutputTracker: MapOutputTracker,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val securityManager: SecurityManager,
@@ -72,6 +71,12 @@
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {

// We initialize the ShuffleManager later in SparkContext and Executor to allow
// user jars to define custom ShuffleManagers.
private var _shuffleManager: ShuffleManager = _

def shuffleManager: ShuffleManager = _shuffleManager

@volatile private[spark] var isStopped = false

/**
@@ -100,7 +105,9 @@
isStopped = true
pythonWorkers.values.foreach(_.stop())
mapOutputTracker.stop()
shuffleManager.stop()
if (shuffleManager != null) {
shuffleManager.stop()
}
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
@@ -186,6 +193,12 @@
releasePythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker)
}

private[spark] def initializeShuffleManager(): Unit = {
Preconditions.checkState(null == _shuffleManager,
"Shuffle manager already initialized to %s", _shuffleManager)
_shuffleManager = ShuffleManager.create(conf, executorId == SparkContext.DRIVER_IDENTIFIER)
}
}

object SparkEnv extends Logging {
@@ -356,16 +369,6 @@
new MapOutputTrackerMasterEndpoint(
rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
"sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
"tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)
val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER)
val shuffleMgrClass =
shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName)
val shuffleManager = Utils.instantiateSerializerOrShuffleManager[ShuffleManager](
shuffleMgrClass, conf, isDriver)

val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)

val blockManagerPort = if (isDriver) {
@@ -403,7 +406,7 @@
None
}, blockManagerInfo,
mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
shuffleManager,
_shuffleManager = null,
isDriver)),
registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
@@ -416,6 +419,10 @@
advertiseAddress, blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint)

// NB: blockManager is not valid until initialize() is called later.
// SPARK-45762 introduces a change where the ShuffleManager is initialized later
// in the SparkContext and Executor, to allow for custom ShuffleManagers defined
// in user jars. The BlockManager uses a lazy val to obtain the
// shuffleManager from the SparkEnv.
val blockManager = new BlockManager(
executorId,
rpcEnv,
@@ -424,7 +431,7 @@
conf,
memoryManager,
mapOutputTracker,
shuffleManager,
_shuffleManager = null,
blockTransferService,
securityManager,
externalShuffleClient)
@@ -463,7 +470,6 @@
closureSerializer,
serializerManager,
mapOutputTracker,
shuffleManager,
broadcastManager,
blockManager,
securityManager,
13 changes: 10 additions & 3 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -330,14 +330,21 @@ private[spark] class Executor(
}
updateDependencies(initialUserFiles, initialUserJars, initialUserArchives, defaultSessionState)

// Plugins need to load using a class loader that includes the executor's user classpath.
// Plugins also needs to be initialized after the heartbeater started
// to avoid blocking to send heartbeat (see SPARK-32175).
// Plugins and shuffle managers need to load using a class loader that includes the executor's
// user classpath. Plugins also need to be initialized after the heartbeater has started,
// to avoid blocking heartbeat sends (see SPARK-32175 and SPARK-45762).
private val plugins: Option[PluginContainer] =
Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
PluginContainer(env, resources.asJava)
}

// Skip local mode because the ShuffleManager is already initialized
if (!isLocal) {
Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
env.initializeShuffleManager()
}
}

metricsPoller.start()

private[executor] def numRunningTasks: Int = runningTasks.size()
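For context, `Utils.withContextClassLoader` used above follows the standard swap-and-restore idiom; a minimal self-contained sketch of that idiom (not the actual Spark implementation):

```scala
// Run `body` with `loader` as the thread's context class loader, restoring the
// previous loader afterwards even if `body` throws.
def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
  val thread = Thread.currentThread()
  val previous = thread.getContextClassLoader
  thread.setContextClassLoader(loader)
  try body finally thread.setContextClassLoader(previous)
}
```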
core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -17,7 +17,11 @@

package org.apache.spark.shuffle

import org.apache.spark.{ShuffleDependency, TaskContext}
import java.util.Locale

import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.internal.config
import org.apache.spark.util.Utils

/**
* Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver
@@ -94,3 +98,23 @@ private[spark] trait ShuffleManager {
/** Shut down this ShuffleManager. */
def stop(): Unit
}

/**
* Utility companion object to create a ShuffleManager given a spark configuration.
*/
private[spark] object ShuffleManager {
def create(conf: SparkConf, isDriver: Boolean): ShuffleManager = {
Utils.instantiateSerializerOrShuffleManager[ShuffleManager](
getShuffleManagerClassName(conf), conf, isDriver)
}

def getShuffleManagerClassName(conf: SparkConf): String = {
val shortShuffleMgrNames = Map(
"sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
"tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)

val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER)
shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName)
}
}
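
A quick usage sketch of the factory above; the assertion assumes the built-in short-name mapping shown in the diff:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.ShuffleManager

val conf = new SparkConf().set("spark.shuffle.manager", "SORT")
// Short names match case-insensitively and resolve to the built-in manager;
// anything else is treated as a fully qualified class name.
assert(ShuffleManager.getShuffleManagerClassName(conf) ==
  "org.apache.spark.shuffle.sort.SortShuffleManager")
// ShuffleManager.create(conf, isDriver = true) would then instantiate it reflectively.
```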

14 changes: 11 additions & 3 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -187,12 +187,17 @@ private[spark] class BlockManager(
val conf: SparkConf,
memoryManager: MemoryManager,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
private val _shuffleManager: ShuffleManager,
val blockTransferService: BlockTransferService,
securityManager: SecurityManager,
externalBlockStoreClient: Option[ExternalBlockStoreClient])
extends BlockDataManager with BlockEvictionHandler with Logging {

// We initialize the ShuffleManager later in SparkContext and Executor, to allow
// user jars to define custom ShuffleManagers; as a result, `_shuffleManager` will be
// null here (except for tests) and we ask the SparkEnv for the instance.
private lazy val shuffleManager = Option(_shuffleManager).getOrElse(SparkEnv.get.shuffleManager)

// same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)`
private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined
private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
@@ -587,12 +592,15 @@

private def registerWithExternalShuffleServer(): Unit = {
logInfo("Registering executor with local external shuffle service.")
// we obtain the class name from the configuration, instead of the ShuffleManager
// instance because the ShuffleManager has not been created at this point.
val shuffleMgrClass = ShuffleManager.getShuffleManagerClassName(conf)
val shuffleManagerMeta =
if (Utils.isPushBasedShuffleEnabled(conf, isDriver = isDriver, checkSerializer = false)) {
s"${shuffleManager.getClass.getName}:" +
s"${shuffleMgrClass}:" +
s"${diskBlockManager.getMergeDirectoryAndAttemptIDJsonString()}}}"
} else {
shuffleManager.getClass.getName
shuffleMgrClass
}
val shuffleConfig = new ExecutorShuffleInfo(
diskBlockManager.localDirsString,
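The `lazy val` is what makes this late wiring safe; here is a self-contained sketch of the deferral with hypothetical names (not Spark classes):

```scala
object LazyInitDemo extends App {
  class Env { var manager: String = _ }

  class Holder(env: Env, injected: String) {
    // The right-hand side runs on first access, not at construction time.
    private lazy val manager: String = Option(injected).getOrElse(env.manager)
    def resolve(): String = manager
  }

  val env = new Env
  val holder = new Holder(env, injected = null) // built before the manager exists
  env.manager = "sort"                          // late initialization, as in this PR
  assert(holder.resolve() == "sort")            // first access happens after init
}
```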
core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -29,7 +29,7 @@ import scala.util.control.NonFatal

import com.google.common.cache.CacheBuilder

import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext}
import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.config.RDD_CACHE_VISIBILITY_TRACKING_ENABLED
@@ -55,10 +55,15 @@ class BlockManagerMasterEndpoint(
externalBlockStoreClient: Option[ExternalBlockStoreClient],
blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo],
mapOutputTracker: MapOutputTrackerMaster,
shuffleManager: ShuffleManager,
private val _shuffleManager: ShuffleManager,
isDriver: Boolean)
extends IsolatedThreadSafeRpcEndpoint with Logging {

// We initialize the ShuffleManager later in SparkContext and Executor, to allow
// user jars to define custom ShuffleManagers; as a result, `_shuffleManager` will be
// null here (except for tests) and we ask the SparkEnv for the instance.
private lazy val shuffleManager = Option(_shuffleManager).getOrElse(SparkEnv.get.shuffleManager)

// Mapping from executor id to the block manager's local disk directories.
private val executorIdToLocalDirs =
CacheBuilder
77 changes: 77 additions & 0 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -1414,6 +1414,83 @@ class SparkSubmitSuite
runSparkSubmit(args)
}

test("SPARK-45762: The ShuffleManager plugin to use can be defined in a user jar") {
val shuffleManagerBody = """
|@Override
|public <K, V, C> org.apache.spark.shuffle.ShuffleHandle registerShuffle(
| int shuffleId,
| org.apache.spark.ShuffleDependency<K, V, C> dependency) {
| throw new java.lang.UnsupportedOperationException("This is a test ShuffleManager!");
|}
|
|@Override
|public <K, V> org.apache.spark.shuffle.ShuffleWriter<K, V> getWriter(
| org.apache.spark.shuffle.ShuffleHandle handle,
| long mapId,
| org.apache.spark.TaskContext context,
| org.apache.spark.shuffle.ShuffleWriteMetricsReporter metrics) {
| throw new java.lang.UnsupportedOperationException("This is a test ShuffleManager!");
|}
|
|@Override
|public <K, C> org.apache.spark.shuffle.ShuffleReader<K, C> getReader(
| org.apache.spark.shuffle.ShuffleHandle handle,
| int startMapIndex,
| int endMapIndex,
| int startPartition,
| int endPartition,
| org.apache.spark.TaskContext context,
| org.apache.spark.shuffle.ShuffleReadMetricsReporter metrics) {
| throw new java.lang.UnsupportedOperationException("This is a test ShuffleManager!");
|}
|
|@Override
|public boolean unregisterShuffle(int shuffleId) {
| throw new java.lang.UnsupportedOperationException("This is a test ShuffleManager!");
|}
|
|@Override
|public org.apache.spark.shuffle.ShuffleBlockResolver shuffleBlockResolver() {
| throw new java.lang.UnsupportedOperationException("This is a test ShuffleManager!");
|}
|
|@Override
|public void stop() {
|}
""".stripMargin

val tempDir = Utils.createTempDir()
val compiledShuffleManager = TestUtils.createCompiledClass(
"TestShuffleManager",
tempDir,
"",
null,
Seq.empty,
Seq("org.apache.spark.shuffle.ShuffleManager"),
shuffleManagerBody)

val jarUrl = TestUtils.createJar(
Seq(compiledShuffleManager),
new File(tempDir, "testplugin.jar"))

val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val argsBase = Seq(
"--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"),
"--name", "testApp",
"--master", "local-cluster[1,1,1024]",
"--conf", "spark.shuffle.manager=TestShuffleManager",
"--conf", "spark.ui.enabled=false")

val argsError = argsBase :+ unusedJar.toString
// check process error exit code
assertResult(1)(runSparkSubmit(argsError, expectFailure = true))

val argsSuccess = (argsBase ++ Seq("--jars", jarUrl.toString)) :+ unusedJar.toString
// check process success exit code
assertResult(0)(
runSparkSubmit(argsSuccess, expectFailure = false))
}

private def testRemoteResources(
enableHttpFs: Boolean,
forceDownloadSchemes: Seq[String] = Nil): Unit = {
core/src/test/scala/org/apache/spark/deploy/SparkSubmitTestUtils.scala
@@ -42,7 +42,8 @@ trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits {
args: Seq[String],
sparkHomeOpt: Option[String] = None,
timeout: Span = defaultSparkSubmitTimeout,
isSparkTesting: Boolean = true): Unit = {
isSparkTesting: Boolean = true,
expectFailure: Boolean = false): Int = {
val sparkHome = sparkHomeOpt.getOrElse(
sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")))
val history = ArrayBuffer.empty[String]
@@ -77,7 +78,7 @@

try {
val exitCode = failAfter(timeout) { process.waitFor() }
if (exitCode != 0) {
if (exitCode != 0 && !expectFailure) {
// include logs in output. Note that logging is async and may not have completed
// at the time this exception is raised
Thread.sleep(1000)
Expand All @@ -90,6 +91,7 @@ trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits {
""".stripMargin
}
}
exitCode
} catch {
case to: TestFailedDueToTimeoutException =>
val historyLog = history.mkString("\n")
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -52,7 +52,9 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.callSite"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.summary"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI$default$3"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI")
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI"),
// [SPARK-45762][CORE] Support shuffle managers defined in user jars by changing startup order
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this")
)

// Default exclude rules
