Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set #19993

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 19 additions & 25 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
* `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
* `Bucketizer` maps a column of continuous features to a column of feature buckets.
*
* Since 2.3.0,
* `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
* when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
* only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
* only used for single column usage, and `splitsArray` is for multiple columns.
* when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
* `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
* columns.
*/
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
Expand Down Expand Up @@ -134,28 +136,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
* Determines whether this `Bucketizer` is going to map multiple columns. If and only if
* `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
* by `inputCol`. A warning will be printed if both are set.
*/
private[feature] def isBucketizeMultipleColumns(): Boolean = {
if (isSet(inputCols) && isSet(inputCol)) {
logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
"`Bucketizer` only map one column specified by `inputCol`")
false
} else if (isSet(inputCols)) {
true
} else {
false
}
}

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema)

val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
val (inputColumns, outputColumns) = if (isSet(inputCols)) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
Expand All @@ -170,7 +155,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
}
}

val seqOfSplits = if (isBucketizeMultipleColumns()) {
val seqOfSplits = if (isSet(inputCols)) {
$(splitsArray).toSeq
} else {
Seq($(splits))
Expand Down Expand Up @@ -201,9 +186,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
if (isBucketizeMultipleColumns()) {
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
Seq(outputCols, splitsArray))

if (isSet(inputCols)) {
require(getInputCols.length == getOutputCols.length &&
getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
s"equal lengths, but they have different lengths: " +
s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")

var transformedSchema = schema
$(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
$(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
SchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
Expand Down
69 changes: 69 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,75 @@ object ParamValidators {
def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
value.length > lowerBound
}

/**
* Utility for Param validity checks for Transformers which have both single- and multi-column
* support. This utility assumes that `inputCol` indicates single-column usage and
* that `inputCols` indicates multi-column usage.
*
* This checks to ensure that exactly one set of Params has been set, and it
* raises an `IllegalArgumentException` if not.
*
* @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been
* set. This does not need to include `inputCol`.
* @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been
* set. This does not need to include `inputCols`.
*/
def checkSingleVsMultiColumnParams(
model: Params,
singleColumnParams: Seq[Param[_]],
multiColumnParams: Seq[Param[_]]): Unit = {
val name = s"${model.getClass.getSimpleName} $model"

// Validates the Params for one usage mode (single- or multi-column).
//
// `requiredParams` must each be defined (explicitly set or carrying a default) on `model`;
// `excludedParams` belong to the other usage mode and must not be explicitly set.
// `isSingleCol` only selects the prefix of the error message.
//
// Throws an IllegalArgumentException listing every offending Param when any check fails.
def checkExclusiveParams(
    isSingleCol: Boolean,
    requiredParams: Seq[Param[_]],
    excludedParams: Seq[Param[_]]): Unit = {
  val mustUnsetParams = excludedParams.filter(p => model.isSet(p)).map(_.name)
  val mustSetParams = requiredParams.filterNot(p => model.isDefined(p)).map(_.name)

  val unsetProblem = if (mustUnsetParams.nonEmpty) {
    Some("The following Params are not applicable and should not be set: " +
      s"${mustUnsetParams.mkString(", ")}.")
  } else {
    None
  }

  val setProblem = if (mustSetParams.nonEmpty) {
    Some("The following Params must be defined but are not set: " +
      s"${mustSetParams.mkString(", ")}.")
  } else {
    None
  }

  // Join with a space so that the two sentences do not run together when both problems occur.
  val badParamsMsg = (unsetProblem.toSeq ++ setProblem.toSeq).mkString(" ")

  if (badParamsMsg.nonEmpty) {
    val errPrefix = if (isSingleCol) {
      s"$name has the inputCol Param set for single-column transform."
    } else {
      s"$name has the inputCols Param set for multi-column transform."
    }
    throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
  }
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we need to check other exclusive params, e.g., inputCol and splitsArray or inputCols and splits, why not just have a method like:

def checkExclusiveParams(model: Params, params: String*): Unit = {
  if (params.filter(model.isSet(_)).size > 1) {
    val paramString = params.mkString("`", "`, `", "`")
    throw new IllegalArgumentException(s"$paramString are exclusive, but more than one among them are set.")
  }
}
ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
ParamValidators.checkExclusiveParams(this, "inputCol", "splitsArray")
ParamValidators.checkExclusiveParams(this, "inputCols", "splits")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this method too in #20146.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use that method once merged, thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if #20146 will get merged for 2.3 - but I think we must merge this PR for 2.3 because I'd prefer not to have this inconsistency in param error handling between QuantileDiscretizer and Bucketizer. This is a relatively small change, so we can merge it into the branch if we move it quickly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on #20146 (comment) from @WeichenXu123, I think #20146 cannot get merged for 2.3.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this method looks good to you, maybe you can just copy it from #20146 to use here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MLnick @viirya in order to address https://github.com/apache/spark/pull/19993/files#r161682506, I was thinking of leaving this method as it is (just renaming it as per @viirya's suggestion) and only adding an additionalExclusiveParams: (String, String)* argument to the function. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @viirya's method is simpler and more general, so why not use it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya your actual method in #20146 is slightly different (see here). Is that the best version to use?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MLnick Yes. I didn't test the method posted here. The model possibly doesn't have the params, so we need to check it with model.hasParam. Please use the method in #20146.


val inputCol = model.getParam("inputCol")
val inputCols = model.getParam("inputCols")

if (model.isSet(inputCol)) {
require(!model.isSet(inputCols), s"$name requires " +
s"exactly one of inputCol, inputCols Params to be set, but both are set.")

checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
excludedParams = multiColumnParams)
} else if (model.isSet(inputCols)) {
checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
excludedParams = singleColumnParams)
} else {
throw new IllegalArgumentException(s"$name requires " +
s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
}
}
}

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)

assert(bucketizer1.isBucketizeMultipleColumns())

bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
Seq("result1", "result2"),
Expand All @@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result"))
.setSplitsArray(Array(splits(0)))

assert(bucketizer2.isBucketizeMultipleColumns())

withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
intercept[SparkException] {
bucketizer2.transform(badDF1).collect()
Expand Down Expand Up @@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)

assert(bucketizer.isBucketizeMultipleColumns())

BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
Seq("result1", "result2"),
Seq("expected1", "expected2"))
Expand All @@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(splits)

assert(bucketizer.isBucketizeMultipleColumns())

bucketizer.setHandleInvalid("keep")
BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
Seq("result1", "result2"),
Expand Down Expand Up @@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
.setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
assert(t.isBucketizeMultipleColumns())
testDefaultReadWrite(t)
}

Expand All @@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCols(Array("result1", "result2"))
.setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))

assert(bucket.isBucketizeMultipleColumns())

val pl = new Pipeline()
.setStages(Array(bucket))
.fit(df)
Expand Down Expand Up @@ -401,15 +390,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}

test("Both inputCol and inputCols are set") {
val bucket = new Bucketizer()
.setInputCol("feature1")
.setOutputCol("result")
.setSplits(Array(-0.5, 0.0, 0.5))
.setInputCols(Array("feature1", "feature2"))

// When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
assert(bucket.isBucketizeMultipleColumns() == false)
test("assert exception is thrown if both multi-column and single-column params are set") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also test the other exclusive params (input cols and splits params) as per https://github.com/apache/spark/pull/19993/files#r159133936

val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
("inputCols", Array("feature1", "feature2")))
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "result1"),
("outputCols", Array("result1", "result2")))
ParamsSuite.testExclusiveParams(new Bucketizer, df, ("splits", Array(-0.5, 0.0, 0.5)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only comment I have is that I believe this line is not testing what you may think.

As I read the checkSingleVsMultiColumnParams method, in this test case it will throw the error, not because both splits and splitsArray are set, but rather because both inputCol & inputCols are unset.

Actually it applies to the line above too.

@jkbradley

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MLnick actually it will fail for both reasons. We can add more test cases to check each of these two cases if you think it is needed.

("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))
}
}

Expand Down
23 changes: 23 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ package org.apache.spark.ml.param
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these are used any longer?

import org.apache.spark.ml.util.MyParams
import org.apache.spark.sql.Dataset

class ParamsSuite extends SparkFunSuite {

Expand Down Expand Up @@ -430,4 +433,24 @@ object ParamsSuite extends SparkFunSuite {
require(copyReturnType === obj.getClass,
s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}

/**
* Checks that the class throws an exception in case multiple exclusive params are set.
* The params to be checked are passed as arguments with their value.
*/
def testExclusiveParams(
    model: Params,
    dataset: Dataset[_],
    paramsAndValues: (String, Any)*): Unit = {
  // Work on a copy so that the caller's instance is not mutated by the Param assignments below.
  val m = model.copy(ParamMap.empty)
  paramsAndValues.foreach { case (paramName, paramValue) =>
    m.set(m.getParam(paramName), paramValue)
  }

  // Resolve what to run before entering `intercept`, so that an unsupported stage type is
  // reported as such instead of surfacing as a MatchError wrapped by `intercept`.
  val runStage: () => Any = m match {
    case transformer: Transformer => () => transformer.transform(dataset)
    case estimator: Estimator[_] => () => estimator.fit(dataset)
    case other =>
      fail(s"testExclusiveParams supports only Transformer and Estimator instances, " +
        s"but got ${other.getClass.getName}.")
  }

  // Setting the given combination of Params is expected to be rejected.
  intercept[IllegalArgumentException] {
    runStage()
  }
}
}