From bb0c0d29f4eec137bbd90ae068a7f8a30c92ea9f Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 19 Dec 2017 19:09:32 +0100
Subject: [PATCH] address comments

---
 .../apache/spark/ml/feature/Bucketizer.scala  | 25 +++++++-------
 .../org/apache/spark/ml/param/params.scala    | 14 ++++++++
 .../spark/ml/param/shared/sharedParams.scala  | 32 ++++++++++++++++++
 3 files changed, 59 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index e945909cd4394..4e0d647dd1d71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -34,9 +34,9 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 /**
  * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
  * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
- * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
- * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
- * only used for single column usage, and `splitsArray` is for multiple columns.
+ * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
+ * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
+ * columns.
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@@ -140,15 +140,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
-   * by `inputCol`. A warning will be printed if both are set.
+   * by `inputCol`. An `IllegalArgumentException` is thrown if both are set.
    */
   private[feature] def isBucketizeMultipleColumns(): Boolean = {
-    if (isSet(inputCols) && isSet(inputCol) || isSet(inputCols) && isSet(outputCol) ||
-      isSet(inputCol) && isSet(outputCols)) {
-      throw new IllegalArgumentException("Both `inputCol` and `inputCols` are set, `Bucketizer` " +
-        "only supports setting either `inputCol` or `inputCols`.")
-    } else if (isSet(inputCols)) {
-      true
-    } else {
-      false
+    // Fail fast on any mix of single-column and multi-column params before picking the mode.
+    inputColsSanityCheck()
+    outputColsSanityCheck()
+    if (isSet(inputCol) && isSet(splitsArray)) {
+      raiseIncompatibleParamsException("inputCol", "splitsArray")
+    }
+    if (isSet(inputCols) && isSet(splits)) {
+      raiseIncompatibleParamsException("inputCols", "splits")
     }
+    isSet(inputCols)
   }
 
   @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 1b4b401ac4aa0..c5af53e91f4cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -834,6 +834,20 @@ trait Params extends Identifiable with Serializable {
     }
     to
   }
+
+  /**
+   * Throws an `IllegalArgumentException` reporting that two mutually exclusive params are both
+   * set on this instance. The message is a single line naming both params and this class.
+   *
+   * @param paramName1 name of the first conflicting param
+   * @param paramName2 name of the second conflicting param
+   * @throws IllegalArgumentException always
+   */
+  final def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
+    throw new IllegalArgumentException(
+      s"Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports " +
+        s"setting either `$paramName1` or `$paramName2`.")
+  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 13425dacc9f18..931744d2b5945 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -236,6 +236,22 @@ trait HasInputCols extends Params {
 
   /** @group getParam */
   final def getInputCols: Array[String] = $(inputCols)
+
+  /**
+   * Verifies that `inputCols` is not set together with the single-column params `inputCol` or
+   * `outputCol`, when this instance also mixes in `HasInputCol` / `HasOutputCol`.
+   *
+   * @throws IllegalArgumentException if `inputCols` and one of those params are both set
+   */
+  final def inputColsSanityCheck(): Unit = {
+    this match {
+      case model: HasInputCol if isSet(inputCols) && isSet(model.inputCol) =>
+        raiseIncompatibleParamsException("inputCols", "inputCol")
+      case model: HasOutputCol if isSet(inputCols) && isSet(model.outputCol) =>
+        raiseIncompatibleParamsException("inputCols", "outputCol")
+      case _ =>
+    }
+  }
 }
 
 /**
@@ -272,6 +288,22 @@ trait HasOutputCols extends Params {
 
   /** @group getParam */
   final def getOutputCols: Array[String] = $(outputCols)
+
+  /**
+   * Verifies that `outputCols` is not set together with the single-column params `inputCol` or
+   * `outputCol`, when this instance also mixes in `HasInputCol` / `HasOutputCol`.
+   *
+   * @throws IllegalArgumentException if `outputCols` and one of those params are both set
+   */
+  final def outputColsSanityCheck(): Unit = {
+    this match {
+      case model: HasInputCol if isSet(outputCols) && isSet(model.inputCol) =>
+        raiseIncompatibleParamsException("outputCols", "inputCol")
+      case model: HasOutputCol if isSet(outputCols) && isSet(model.outputCol) =>
+        raiseIncompatibleParamsException("outputCols", "outputCol")
+      case _ =>
+    }
+  }
 }
 
 /**