Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Feb 22, 2025
1 parent 73a8fcc commit 2a7face
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,20 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
}
}
)

val sconf = dataset.sparkSession.conf
val rmmEnabled: Boolean = try {
dataset.sparkSession.conf.get("spark.rapids.memory.gpu.pooling.enabled").toBoolean
sconf.get("spark.rapids.memory.gpu.pooling.enabled").toBoolean &&
sconf.get("spark.rapids.memory.gpu.pool").trim.toLowerCase != "none"
} catch {
case _: Throwable => false // Any exception will return false
}

(rdd, Map("use_rmm" -> rmmEnabled).asInstanceOf[Map[String, AnyRef]])
val configs = if (rmmEnabled) {
Map("use_rmm" -> rmmEnabled).asInstanceOf[Map[String, AnyRef]]
} else {
Map.empty[String, AnyRef]
}
(rdd, configs)
}

override def transform[M <: XGBoostModel[M]](model: XGBoostModel[M],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,11 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
val features = Array("c1", "c2")
val classifier = new XGBoostClassifier().setDevice("cuda").setFeaturesCol(features)
val (_, configs) = PluginUtils.getPlugin.get.buildRddWatches(classifier, df)
assert(configs("use_rmm") == false)
assert(configs.isEmpty)
}

val conf = new SparkConf().set("spark.rapids.memory.gpu.pooling.enabled", "true")
.set("spark.rapids.memory.gpu.pool", "ASYNC")
withGpuSparkSession(conf) { spark =>
import spark.implicits._

Expand Down

0 comments on commit 2a7face

Please sign in to comment.