diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index a3ea9c3172009..c13bf47eacb94 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -32,7 +32,9 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ *
+ * Since 2.3.0,
  * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
  * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
  * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
@@ -184,11 +186,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
-    ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
-    ParamValidators.checkExclusiveParams(this, "splits", "splitsArray")
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
+      Seq(outputCols, splitsArray))
 
     if (isSet(inputCols)) {
+      require(getInputCols.length == getOutputCols.length &&
+        getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
+        s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
+        s"equal lengths, but they have different lengths: " +
+        s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")
+
       var transformedSchema = schema
       $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
         SchemaUtils.checkNumericType(transformedSchema, inputCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8fbb0e1b2a3ba..bd15a8492e3be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -27,7 +27,6 @@ import scala.collection.mutable
 
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
-import org.slf4j.LoggerFactory
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
@@ -167,8 +166,6 @@ private[ml] object Param {
 @DeveloperApi
 object ParamValidators {
 
-  private val LOGGER = LoggerFactory.getLogger(ParamValidators.getClass)
-
   /** (private[param]) Default validation always return true */
   private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
 
@@ -254,21 +251,74 @@ object ParamValidators {
   }
 
   /**
-   * Checks that only one of the params passed as arguments is set. If this is not true, an
-   * `IllegalArgumentException` is raised.
+   * Utility for Param validity checks for Transformers which have both single- and multi-column
+   * support. This utility assumes that `inputCol` indicates single-column usage and
+   * that `inputCols` indicates multi-column usage.
+   *
+   * This checks to ensure that exactly one set of Params has been set, and it
+   * raises an `IllegalArgumentException` if not.
+   *
+   * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been
+   *                           set. This does not need to include `inputCol`.
+   * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been
+   *                          set. This does not need to include `inputCols`.
    */
-  def checkExclusiveParams(model: Params, params: String*): Unit = {
-    val (existingParams, nonExistingParams) = params.partition(model.hasParam)
-    if (nonExistingParams.nonEmpty) {
-      val pronoun = if (nonExistingParams.size == 1) "It" else "They"
-      LOGGER.warn(s"Ignored ${nonExistingParams.mkString("`", "`, `", "`")} while checking " +
-        s"exclusive params. $pronoun don't exist for the specified model the model.")
+  def checkSingleVsMultiColumnParams(
+      model: Params,
+      singleColumnParams: Seq[Param[_]],
+      multiColumnParams: Seq[Param[_]]): Unit = {
+    val name = s"${model.getClass.getSimpleName} $model"
+
+    // Validates one usage mode: excludedParams must be unset, requiredParams must be defined.
+    def checkExclusiveParams(
+        isSingleCol: Boolean,
+        requiredParams: Seq[Param[_]],
+        excludedParams: Seq[Param[_]]): Unit = {
+      val badParamsMsgBuilder = new mutable.StringBuilder()
+
+      val mustUnsetParams = excludedParams.filter(p => model.isSet(p))
+        .map(_.name).mkString(", ")
+      if (mustUnsetParams.nonEmpty) {
+        badParamsMsgBuilder ++=
+          s"The following Params are not applicable and should not be set: $mustUnsetParams."
+      }
+
+      val mustSetParams = requiredParams.filter(p => !model.isDefined(p))
+        .map(_.name).mkString(", ")
+      if (mustSetParams.nonEmpty) {
+        // Keep the two sentences apart when both kinds of violation are reported.
+        if (badParamsMsgBuilder.nonEmpty) badParamsMsgBuilder ++= " "
+        badParamsMsgBuilder ++=
+          s"The following Params must be defined but are not set: $mustSetParams."
+      }
+
+      val badParamsMsg = badParamsMsgBuilder.toString()
+
+      if (badParamsMsg.nonEmpty) {
+        val errPrefix = if (isSingleCol) {
+          s"$name has the inputCol Param set for single-column transform."
+        } else {
+          s"$name has the inputCols Param set for multi-column transform."
+        }
+        throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
+      }
     }
 
-    if (existingParams.count(paramName => model.isSet(model.getParam(paramName))) > 1) {
-      val paramString = existingParams.mkString("`", "`, `", "`")
-      throw new IllegalArgumentException(s"$paramString are exclusive, " +
-        "but more than one among them are set.")
+    val inputCol = model.getParam("inputCol")
+    val inputCols = model.getParam("inputCols")
+
+    if (model.isSet(inputCol)) {
+      require(!model.isSet(inputCols), s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but both are set.")
+
+      checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
+        excludedParams = multiColumnParams)
+    } else if (model.isSet(inputCols)) {
+      checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
+        excludedParams = singleColumnParams)
+    } else {
+      throw new IllegalArgumentException(s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
     }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 94d6125c4c98e..6ecab7cbf6968 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -435,24 +435,23 @@ object ParamsSuite extends SparkFunSuite {
   }
 
   /**
-   * Checks that the class throws an exception in case multiple exclusive params are set
+   * Checks that the class throws an exception in case multiple exclusive params are set.
    * The params to be checked are passed as arguments with their value.
-   * The checks are performed only if all the passed params are defined for the given model.
+   * Only the exception type is asserted, because the message differs per violated rule.
    */
-  def testExclusiveParams(model: Params, dataset: Dataset[_],
+  def testExclusiveParams(
+      model: Params,
+      dataset: Dataset[_],
       paramsAndValues: (String, Any)*): Unit = {
-    val params = paramsAndValues.map(_._1)
-    if (params.forall(model.hasParam)) {
-      paramsAndValues.foreach { case (paramName, paramValue) =>
-        model.set(model.getParam(paramName), paramValue)
-      }
-      val e = intercept[IllegalArgumentException] {
-        model match {
-          case t: Transformer => t.transform(dataset)
-          case e: Estimator[_] => e.fit(dataset)
-        }
+    val m = model.copy(ParamMap.empty)
+    paramsAndValues.foreach { case (paramName, paramValue) =>
+      m.set(m.getParam(paramName), paramValue)
+    }
+    intercept[IllegalArgumentException] {
+      m match {
+        case t: Transformer => t.transform(dataset)
+        case e: Estimator[_] => e.fit(dataset)
       }
-      assert(e.getMessage.contains("are exclusive, but more than one"))
     }
   }
 }