Skip to content

Commit

Permalink
[SPARK-22799][ML] Bucketizer should throw exception if single- and mu…
Browse files Browse the repository at this point in the history
…lti-column params are both set
  • Loading branch information
mgaido91 committed Dec 15, 2017
1 parent e58f275 commit 8f3581c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
* by `inputCol`. A warning will be printed if both are set.
*/
private[feature] def isBucketizeMultipleColumns(): Boolean = {
if (isSet(inputCols) && isSet(inputCol)) {
logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
"`Bucketizer` only map one column specified by `inputCol`")
false
if (isSet(inputCols) && isSet(inputCol) || isSet(inputCols) && isSet(outputCol) ||
isSet(inputCol) && isSet(outputCols)) {
throw new IllegalArgumentException("Both `inputCol` and `inputCols` are set, `Bucketizer` " +
"only supports setting either `inputCol` or `inputCols`.")
} else if (isSet(inputCols)) {
true
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,33 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}

test("Both inputCol and inputCols are set") {
val bucket = new Bucketizer()
val feature1 = Array(-0.5, -0.3, 0.0, 0.2)
val feature2 = Array(-0.3, -0.2, 0.5, 0.0)
val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2")

val invalid1 = new Bucketizer()
.setInputCol("feature1")
.setOutputCol("result")
.setSplits(Array(-0.5, 0.0, 0.5))
.setInputCols(Array("feature1", "feature2"))

// When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
assert(bucket.isBucketizeMultipleColumns() == false)
val invalid2 = new Bucketizer()
.setOutputCol("result")
.setSplits(Array(-0.5, 0.0, 0.5))
.setInputCols(Array("feature1", "feature2"))

val invalid3 = new Bucketizer()
.setInputCol("feature1")
.setSplits(Array(-0.5, 0.0, 0.5))
.setOutputCols(Array("result1", "result2"))

Seq(invalid1, invalid2, invalid3).foreach { bucketizer =>
// When both inputCol and inputCols are set, we throw Exception.
val e = intercept[IllegalArgumentException] {
bucketizer.transform(df)
}
assert(e.getMessage.contains("Both `inputCol` and `inputCols` are set"))
}
}
}

Expand Down

0 comments on commit 8f3581c

Please sign in to comment.