Skip to content

Commit

Permalink
Merge pull request #1 from jkbradley/mgaido91-SPARK-22799
Browse files Browse the repository at this point in the history
strengthened requirements about exclusive Params for single and multicolumn support
  • Loading branch information
mgaido91 authored Jan 20, 2018
2 parents 25b9bd4 + 18bbf61 commit d9d25b0
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 33 deletions.
15 changes: 11 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
* `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
* `Bucketizer` maps a column of continuous features to a column of feature buckets.
*
* Since 2.3.0,
* `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
* when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
* `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
Expand Down Expand Up @@ -184,11 +186,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
ParamValidators.checkExclusiveParams(this, "splits", "splitsArray")
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
Seq(outputCols, splitsArray))

if (isSet(inputCols)) {
require(getInputCols.length == getOutputCols.length &&
getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
s"equal lengths, but they have different lengths: " +
s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")

var transformedSchema = schema
$(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
SchemaUtils.checkNumericType(transformedSchema, inputCol)
Expand Down
75 changes: 60 additions & 15 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import scala.collection.mutable

import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.slf4j.LoggerFactory

import org.apache.spark.SparkException
import org.apache.spark.annotation.{DeveloperApi, Since}
Expand Down Expand Up @@ -167,8 +166,6 @@ private[ml] object Param {
@DeveloperApi
object ParamValidators {

private val LOGGER = LoggerFactory.getLogger(ParamValidators.getClass)

/** (private[param]) Default validator: a function that returns true for every value. */
private[param] def alwaysTrue[T]: T => Boolean = _ => true

Expand Down Expand Up @@ -254,21 +251,69 @@ object ParamValidators {
}

/**
* Checks that only one of the params passed as arguments is set. If this is not true, an
* `IllegalArgumentException` is raised.
* Utility for Param validity checks for Transformers which have both single- and multi-column
* support. This utility assumes that `inputCol` indicates single-column usage and
* that `inputCols` indicates multi-column usage.
*
* This checks to ensure that exactly one set of Params has been set, and it
* raises an `IllegalArgumentException` if not.
*
* @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been
* set. This does not need to include `inputCol`.
* @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been
* set. This does not need to include `inputCols`.
*/
def checkExclusiveParams(model: Params, params: String*): Unit = {
val (existingParams, nonExistingParams) = params.partition(model.hasParam)
if (nonExistingParams.nonEmpty) {
val pronoun = if (nonExistingParams.size == 1) "It" else "They"
LOGGER.warn(s"Ignored ${nonExistingParams.mkString("`", "`, `", "`")} while checking " +
s"exclusive params. $pronoun don't exist for the specified model the model.")
def checkSingleVsMultiColumnParams(
model: Params,
singleColumnParams: Seq[Param[_]],
multiColumnParams: Seq[Param[_]]): Unit = {
val name = s"${model.getClass.getSimpleName} $model"

/**
 * Validates the Params of `model` for one usage mode (single- or multi-column) and throws an
 * `IllegalArgumentException` naming every offending Param if the check fails.
 *
 * @param isSingleCol true when the single-column mode is being checked, false for the
 *                    multi-column mode. Only used to pick the prefix of the error message.
 * @param requiredParams Params which must be defined (`model.isDefined`) in this mode.
 * @param excludedParams Params which must not be set (`model.isSet`) in this mode.
 */
def checkExclusiveParams(
    isSingleCol: Boolean,
    requiredParams: Seq[Param[_]],
    excludedParams: Seq[Param[_]]): Unit = {
  val mustUnsetParams = excludedParams.filter(p => model.isSet(p)).map(_.name)
  val mustSetParams = requiredParams.filterNot(p => model.isDefined(p)).map(_.name)

  val mustUnsetMsg = if (mustUnsetParams.nonEmpty) {
    Some("The following Params are not applicable and should not be set: " +
      s"${mustUnsetParams.mkString(", ")}.")
  } else {
    None
  }
  val mustSetMsg = if (mustSetParams.nonEmpty) {
    Some(s"The following Params must be defined but are not set: ${mustSetParams.mkString(", ")}.")
  } else {
    None
  }

  // Join with a space so that the two sentences do not run together when both problems
  // are reported in the same exception.
  val badParamsMsg = Seq(mustUnsetMsg, mustSetMsg).flatten.mkString(" ")

  if (badParamsMsg.nonEmpty) {
    val errPrefix = if (isSingleCol) {
      s"$name has the inputCol Param set for single-column transform."
    } else {
      s"$name has the inputCols Param set for multi-column transform."
    }
    throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
  }
}

if (existingParams.count(paramName => model.isSet(model.getParam(paramName))) > 1) {
val paramString = existingParams.mkString("`", "`, `", "`")
throw new IllegalArgumentException(s"$paramString are exclusive, " +
"but more than one among them are set.")
val inputCol = model.getParam("inputCol")
val inputCols = model.getParam("inputCols")

if (model.isSet(inputCol)) {
require(!model.isSet(inputCols), s"$name requires " +
s"exactly one of inputCol, inputCols Params to be set, but both are set.")

checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
excludedParams = multiColumnParams)
} else if (model.isSet(inputCols)) {
checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
excludedParams = singleColumnParams)
} else {
throw new IllegalArgumentException(s"$name requires " +
s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
}
}
}
Expand Down
27 changes: 13 additions & 14 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -435,24 +435,23 @@ object ParamsSuite extends SparkFunSuite {
}

/**
* Checks that the class throws an exception in case multiple exclusive params are set
* Checks that the class throws an exception in case multiple exclusive params are set.
* The params to be checked are passed as arguments with their value.
* The checks are performed only if all the passed params are defined for the given model.
*/
def testExclusiveParams(model: Params, dataset: Dataset[_],
def testExclusiveParams(
model: Params,
dataset: Dataset[_],
paramsAndValues: (String, Any)*): Unit = {
val params = paramsAndValues.map(_._1)
if (params.forall(model.hasParam)) {
paramsAndValues.foreach { case (paramName, paramValue) =>
model.set(model.getParam(paramName), paramValue)
}
val e = intercept[IllegalArgumentException] {
model match {
case t: Transformer => t.transform(dataset)
case e: Estimator[_] => e.fit(dataset)
}
val m = model.copy(ParamMap.empty)
paramsAndValues.foreach { case (paramName, paramValue) =>
m.set(m.getParam(paramName), paramValue)
}
val e = intercept[IllegalArgumentException] {
m match {
case t: Transformer => t.transform(dataset)
case e: Estimator[_] => e.fit(dataset)
}
assert(e.getMessage.contains("are exclusive, but more than one"))
}
assert(e.getMessage.contains("are exclusive, but more than one"))
}
}

0 comments on commit d9d25b0

Please sign in to comment.