[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set

## What changes were proposed in this pull request?

Currently the behavior is inconsistent across transformers that support both single- and multi-column usage: in some cases an exception is thrown when conflicting params are set, while in others only a warning is logged. In the discussion at https://issues.apache.org/jira/browse/SPARK-8418?focusedCommentId=16275049&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-16275049, the decision was to throw an exception.

This PR makes `Bucketizer` throw an `IllegalArgumentException` when both single- and multi-column params are set, instead of logging a warning and silently ignoring `inputCols`.
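To illustrate the user-facing change, here is a minimal sketch (assuming a local `SparkSession`; the column names and split points are arbitrary):

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").appName("SPARK-22799-demo").getOrCreate()
import spark.implicits._

val df = Seq(0.1, 0.4, 0.9).toDF("feature1")

val bucketizer = new Bucketizer()
  .setInputCol("feature1")                      // single-column param
  .setOutputCol("result")
  .setSplits(Array(0.0, 0.5, 1.0))
  .setInputCols(Array("feature1", "feature2"))  // conflicting multi-column param

// Previously this logged a warning and silently ignored `inputCols`; with this
// patch it fails fast with an IllegalArgumentException ("... requires exactly
// one of inputCol, inputCols Params to be set, but both are set.").
bucketizer.transform(df)
```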

## How was this patch tested?

Modified existing unit tests in `BucketizerSuite` and added a reusable `testExclusiveParams` helper to `ParamsSuite`.

Author: Marco Gaido <marcogaido91@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #19993 from mgaido91/SPARK-22799.
mgaido91 authored and Nick Pentreath committed Jan 26, 2018
1 parent d172181 commit cd3956d
Showing 4 changed files with 131 additions and 45 deletions.
44 changes: 19 additions & 25 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -32,11 +32,13 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ *
+ * Since 2.3.0,
* `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
- * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
- * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
- * only used for single column usage, and `splitsArray` is for multiple columns.
+ * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
+ * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
+ * columns.
*/
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@@ -134,28 +136,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

- /**
- * Determines whether this `Bucketizer` is going to map multiple columns. If and only if
- * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
- * by `inputCol`. A warning will be printed if both are set.
- */
- private[feature] def isBucketizeMultipleColumns(): Boolean = {
- if (isSet(inputCols) && isSet(inputCol)) {
- logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
- "`Bucketizer` only map one column specified by `inputCol`")
- false
- } else if (isSet(inputCols)) {
- true
- } else {
- false
- }
- }

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema)

- val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+ val (inputColumns, outputColumns) = if (isSet(inputCols)) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
@@ -170,7 +155,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
}
}

- val seqOfSplits = if (isBucketizeMultipleColumns()) {
+ val seqOfSplits = if (isSet(inputCols)) {
$(splitsArray).toSeq
} else {
Seq($(splits))
@@ -201,9 +186,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- if (isBucketizeMultipleColumns()) {
+ ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
+ Seq(outputCols, splitsArray))
+
+ if (isSet(inputCols)) {
+ require(getInputCols.length == getOutputCols.length &&
+ getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
+ s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
+ s"equal lengths, but they have different lengths: " +
+ s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")

var transformedSchema = schema
- $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
+ $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
SchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
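For reference, a minimal sketch of the two usage modes, which are now mutually exclusive (hypothetical column names; assumes a `SparkSession` with `spark.implicits._` in scope):

```scala
import org.apache.spark.ml.feature.Bucketizer

val df = Seq((-0.7, 0.2), (0.1, 0.8), (0.6, -0.3)).toDF("f1", "f2")

// Single-column mode: inputCol / outputCol / splits.
val single = new Bucketizer()
  .setInputCol("f1")
  .setOutputCol("b1")
  .setSplits(Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity))

// Multi-column mode: inputCols / outputCols / splitsArray. The three arrays
// must have equal lengths, as enforced by the require above.
val multi = new Bucketizer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("b1", "b2"))
  .setSplitsArray(Array(
    Array(-1.0, 0.0, 1.0),          // splits for f1
    Array(-1.0, 0.0, 0.5, 1.0)))    // splits for f2

multi.transform(df).show()
```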
69 changes: 69 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -249,6 +249,75 @@ object ParamValidators {
def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
value.length > lowerBound
}

+ /**
+ * Utility for Param validity checks for Transformers which have both single- and multi-column
+ * support. This utility assumes that `inputCol` indicates single-column usage and
+ * that `inputCols` indicates multi-column usage.
+ *
+ * This checks to ensure that exactly one set of Params has been set, and it
+ * raises an `IllegalArgumentException` if not.
+ *
+ * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been
+ * set. This does not need to include `inputCol`.
+ * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been
+ * set. This does not need to include `inputCols`.
+ */
+ def checkSingleVsMultiColumnParams(
+ model: Params,
+ singleColumnParams: Seq[Param[_]],
+ multiColumnParams: Seq[Param[_]]): Unit = {
+ val name = s"${model.getClass.getSimpleName} $model"
+
+ def checkExclusiveParams(
+ isSingleCol: Boolean,
+ requiredParams: Seq[Param[_]],
+ excludedParams: Seq[Param[_]]): Unit = {
+ val badParamsMsgBuilder = new mutable.StringBuilder()
+
+ val mustUnsetParams = excludedParams.filter(p => model.isSet(p))
+ .map(_.name).mkString(", ")
+ if (mustUnsetParams.nonEmpty) {
+ badParamsMsgBuilder ++=
+ s"The following Params are not applicable and should not be set: $mustUnsetParams."
+ }
+
+ val mustSetParams = requiredParams.filter(p => !model.isDefined(p))
+ .map(_.name).mkString(", ")
+ if (mustSetParams.nonEmpty) {
+ badParamsMsgBuilder ++=
+ s"The following Params must be defined but are not set: $mustSetParams."
+ }
+
+ val badParamsMsg = badParamsMsgBuilder.toString()
+
+ if (badParamsMsg.nonEmpty) {
+ val errPrefix = if (isSingleCol) {
+ s"$name has the inputCol Param set for single-column transform."
+ } else {
+ s"$name has the inputCols Param set for multi-column transform."
+ }
+ throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
+ }
+ }
+
+ val inputCol = model.getParam("inputCol")
+ val inputCols = model.getParam("inputCols")
+
+ if (model.isSet(inputCol)) {
+ require(!model.isSet(inputCols), s"$name requires " +
+ s"exactly one of inputCol, inputCols Params to be set, but both are set.")
+
+ checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
+ excludedParams = multiColumnParams)
+ } else if (model.isSet(inputCols)) {
+ checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
+ excludedParams = singleColumnParams)
+ } else {
+ throw new IllegalArgumentException(s"$name requires " +
+ s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
+ }
+ }

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
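A sketch of the intended call pattern for this utility, mirroring the `Bucketizer.transformSchema` change above (it assumes the method is public as shown in the diff, and reuses `Bucketizer`, whose `splits`/`splitsArray` params are accessible):

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.ml.param.ParamValidators

val ok = new Bucketizer()
  .setInputCol("f1")
  .setOutputCol("out")
  .setSplits(Array(0.0, 0.5, 1.0))

// Passes: inputCol is set and every single-column param is defined.
ParamValidators.checkSingleVsMultiColumnParams(ok,
  Seq(ok.outputCol, ok.splits), Seq(ok.outputCols, ok.splitsArray))

// splits is never set and has no default, so this throws
// IllegalArgumentException ("... must be defined but are not set: splits.").
val bad = new Bucketizer().setInputCol("f1")
ParamValidators.checkSingleVsMultiColumnParams(bad,
  Seq(bad.outputCol, bad.splits), Seq(bad.outputCols, bad.splitsArray))
```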
41 changes: 21 additions & 20 deletions mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)

- assert(bucketizer1.isBucketizeMultipleColumns())

bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
Seq("result1", "result2"),
@@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result"))
.setSplitsArray(Array(splits(0)))

- assert(bucketizer2.isBucketizeMultipleColumns())

withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
intercept[SparkException] {
bucketizer2.transform(badDF1).collect()
@@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)

- assert(bucketizer.isBucketizeMultipleColumns())

BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
Seq("result1", "result2"),
Seq("expected1", "expected2"))
@@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)

- assert(bucketizer.isBucketizeMultipleColumns())

bucketizer.setHandleInvalid("keep")
BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
Seq("result1", "result2"),
@@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
.setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
- assert(t.isBucketizeMultipleColumns())
testDefaultReadWrite(t)
}

@@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))

- assert(bucket.isBucketizeMultipleColumns())

val pl = new Pipeline()
.setStages(Array(bucket))
.fit(df)
@@ -401,15 +390,27 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}

test("Both inputCol and inputCols are set") {
val bucket = new Bucketizer()
.setInputCol("feature1")
.setOutputCol("result")
.setSplits(Array(-0.5, 0.0, 0.5))
.setInputCols(Array("feature1", "feature2"))

// When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
assert(bucket.isBucketizeMultipleColumns() == false)
test("assert exception is thrown if both multi-column and single-column params are set") {
val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
("inputCols", Array("feature1", "feature2")))
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
("outputCols", Array("result1", "result2")))
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))

// this should fail because at least one of inputCol and inputCols must be set
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"),
("splits", Array(-0.5, 0.0, 0.5)))

// the following should fail because not all the params are set
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
("outputCol", "result1"))
ParamsSuite.testExclusiveParams(new Bucketizer, df,
("inputCols", Array("feature1", "feature2")),
("outputCols", Array("result1", "result2")))
}
}

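For completeness, the "neither is set" branch added in `params.scala` surfaces to users like this (a sketch; assumes `spark.implicits._` is in scope):

```scala
import org.apache.spark.ml.feature.Bucketizer

val df = Seq(0.2, 0.7).toDF("feature1")

// Neither inputCol nor inputCols is set, so transformSchema throws
// IllegalArgumentException ("... requires exactly one of inputCol, inputCols
// Params to be set, but neither is set.").
new Bucketizer()
  .setOutputCol("result")
  .setSplits(Array(0.0, 0.5, 1.0))
  .transform(df)
```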
22 changes: 22 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.ml.param
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

import org.apache.spark.SparkFunSuite
+ import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.MyParams
+ import org.apache.spark.sql.Dataset

class ParamsSuite extends SparkFunSuite {

@@ -430,4 +432,24 @@ object ParamsSuite extends SparkFunSuite {
require(copyReturnType === obj.getClass,
s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}

+ /**
+ * Checks that the class throws an exception in case multiple exclusive params are set.
+ * The params to be checked are passed as arguments with their value.
+ */
+ def testExclusiveParams(
+ model: Params,
+ dataset: Dataset[_],
+ paramsAndValues: (String, Any)*): Unit = {
+ val m = model.copy(ParamMap.empty)
+ paramsAndValues.foreach { case (paramName, paramValue) =>
+ m.set(m.getParam(paramName), paramValue)
+ }
+ intercept[IllegalArgumentException] {
+ m match {
+ case t: Transformer => t.transform(dataset)
+ case e: Estimator[_] => e.fit(dataset)
+ }
+ }
+ }
}
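As a usage note, the helper can also be invoked standalone; a minimal sketch reusing the `Bucketizer` case from `BucketizerSuite` above (assumes `spark.implicits._` for `toDF`):

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.ml.param.ParamsSuite

val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")

// Copies the model, sets both exclusive params, and asserts that transform()
// (or fit(), for an Estimator) throws IllegalArgumentException.
ParamsSuite.testExclusiveParams(new Bucketizer, df,
  ("inputCol", "feature1"),
  ("inputCols", Array("feature1", "feature2")))
```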
