Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Feb 22, 2025
1 parent 14b5ba7 commit 3a45504
Showing 1 changed file with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
import spark.implicits._

val df = Seq(
(1.0f, 2.0f, 0.0f),
(5.0f, 6.0f, 0.1f)
(1.0f, 2.0f, 0),
(5.0f, 6.0f, 1)
).toDF("c1", "c2", "label")
val features = Array("c1", "c2")
val classifier = new XGBoostClassifier().setDevice("cuda").setFeaturesCol(features)
Expand All @@ -200,13 +200,19 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
import spark.implicits._

val df = Seq(
(1.0f, 2.0f, 0.0f),
(5.0f, 6.0f, 0.1f)
(1.0f, 2.0f, 0),
(5.0f, 6.0f, 1)
).toDF("c1", "c2", "label")
val features = Array("c1", "c2")
val classifier = new XGBoostClassifier().setDevice("cuda").setFeaturesCol(features)
val classifier = new XGBoostClassifier()
.setDevice("cuda")
.setFeaturesCol(features)
.setNumRound(2)
val (_, configs) = PluginUtils.getPlugin.get.buildRddWatches(classifier, df)
assert(configs("use_rmm") == true)

// No exception
classifier.fit(df)
}
}

Expand Down

0 comments on commit 3a45504

Please sign in to comment.