Skip to content

Commit

Permalink
Eagerly initialize RapidsShuffleManager for SPARK-45762 [databricks] (#…
Browse files Browse the repository at this point in the history
…11904)

Fixes #11107

- Initialize RapidsShuffleManager on construction for Spark 4.0 and Databricks 14.3
- Disable lazy initialization and conf validation in the case above

Signed-off-by: Gera Shegalov <gshegalov@nvidia.com>
  • Loading branch information
gerashegalov authored Dec 24, 2024
1 parent 0f702cd commit 2b7a0e2
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids
import java.util.Locale

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.ShuffleManagerShimUtils

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -105,21 +106,25 @@ object GpuShuffleEnv extends Logging {
//
def initShuffleManager(): Unit = {
val shuffleManager = SparkEnv.get.shuffleManager
shuffleManager match {
case rapidsShuffleManager: RapidsShuffleManagerLike =>
rapidsShuffleManager.initialize
case _ =>
val rsmLoaderViaShuffleManager = shuffleManager.getClass.getSuperclass.getInterfaces
.collectFirst {
case c if c.getName == classOf[RapidsShuffleManagerLike].getName => c.getClassLoader
}
val rsmLoaderDirect = classOf[RapidsShuffleManagerLike].getClassLoader

throw new IllegalStateException(s"Cannot initialize the RAPIDS Shuffle Manager " +
s"${shuffleManager}! Expected: an instance of RapidsShuffleManagerLike loaded by " +
s"${rsmLoaderDirect}. Actual: ${shuffleManager} tagged with RapidsShuffleManagerLike " +
s"loaded by: ${rsmLoaderViaShuffleManager}"
)
if (ShuffleManagerShimUtils.eagerlyInitialized) {
// skip deferred init
} else {
shuffleManager match {
case rapidsShuffleManager: RapidsShuffleManagerLike =>
rapidsShuffleManager.initialize
case _ =>
val rsmLoaderViaShuffleManager = shuffleManager.getClass.getSuperclass.getInterfaces
.collectFirst {
case c if c.getName == classOf[RapidsShuffleManagerLike].getName => c.getClassLoader
}
val rsmLoaderDirect = classOf[RapidsShuffleManagerLike].getClassLoader

throw new IllegalStateException(s"Cannot initialize the RAPIDS Shuffle Manager " +
s"${shuffleManager}! Expected: an instance of RapidsShuffleManagerLike loaded by " +
s"${rsmLoaderDirect}. Actual: ${shuffleManager} tagged with RapidsShuffleManagerLike " +
s"loaded by: ${rsmLoaderViaShuffleManager}"
)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@
{"spark": "343"}
{"spark": "344"}
{"spark": "350"}
{"spark": "350db143"}
{"spark": "351"}
{"spark": "352"}
{"spark": "353"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.$_spark.version.classifier_

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


/*** spark-rapids-shim-json-lines
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "334"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "342"}
{"spark": "343"}
{"spark": "344"}
{"spark": "350"}
{"spark": "351"}
{"spark": "352"}
{"spark": "353"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object ShuffleManagerShimUtils {
def eagerlyInitialized = false
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "350db143"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.$_spark.version.classifier_

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.ProxyRapidsShuffleInternalManagerBase

/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */
sealed class RapidsShuffleManager(
conf: SparkConf,
isDriver: Boolean
) extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) {
initialize
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "350db143"}
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object ShuffleManagerShimUtils {
def eagerlyInitialized = true
}

0 comments on commit 2b7a0e2

Please sign in to comment.