[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set #19993
@@ -27,6 +27,7 @@ import scala.collection.mutable
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
+import org.slf4j.LoggerFactory

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
@@ -166,6 +167,8 @@ private[ml] object Param {
 @DeveloperApi
 object ParamValidators {

+  private val LOGGER = LoggerFactory.getLogger(ParamValidators.getClass)
+
   /** (private[param]) Default validation always return true */
   private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true

Review comment (on the LOGGER line): Let's switch this to use the Logging trait, to match other MLlib patterns.
@@ -249,6 +252,25 @@ object ParamValidators {
   def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
     value.length > lowerBound
   }
+
+  /**
+   * Checks that only one of the params passed as arguments is set. If this is not true, an
+   * `IllegalArgumentException` is raised.
+   */
+  def checkExclusiveParams(model: Params, params: String*): Unit = {
+    val (existingParams, nonExistingParams) = params.partition(model.hasParam)
+    if (nonExistingParams.nonEmpty) {
+      val pronoun = if (nonExistingParams.size == 1) "It" else "They"
+      LOGGER.warn(s"Ignored ${nonExistingParams.mkString("`", "`, `", "`")} while checking " +
+        s"exclusive params. $pronoun don't exist for the specified model.")
+    }
Review thread (on checkExclusiveParams):

Review comment: If we need to check other exclusive params, e.g.,

    def checkExclusiveParams(model: Params, params: String*): Unit = {
      if (params.filter(model.isSet(_)).size > 1) {
        val paramString = params.mkString("`", "`, `", "`")
        throw new IllegalArgumentException(s"$paramString are exclusive, but more than one among them are set.")
      }
    }

    ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
    ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
    ParamValidators.checkExclusiveParams(this, "inputCol", "splitsArray")
    ParamValidators.checkExclusiveParams(this, "inputCols", "splits")

Review comment: I added this method too in #20146.

Reply: I think we can use that method once merged, thanks.

Reply: I am not sure if #20146 will get merged for [...]

Reply: Based on #20146 (comment) from @WeichenXu123, I think #20146 cannot get merged for 2.3.

Reply: If this method looks good to you, maybe you can just copy it from #20146 to use here.

Reply: @MLnick @viirya in order to address https://github.com/apache/spark/pull/19993/files#r161682506, I was thinking to leave this method as it is (just renaming it as per @viirya's suggestion) and only adding an [...]

Reply: I think @viirya's method is simpler and more general, so why not use it?
+
+    if (existingParams.count(paramName => model.isSet(model.getParam(paramName))) > 1) {
+      val paramString = existingParams.mkString("`", "`, `", "`")
+      throw new IllegalArgumentException(s"$paramString are exclusive, " +
+        "but more than one among them are set.")
+    }
+  }
 }

 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
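To make the semantics of the new `checkExclusiveParams` easier to follow outside a Spark build, here is a self-contained sketch of the same logic. The `hasParam`/`isSet` function parameters below are stand-ins invented for this example; the real method operates on Spark's `Params` trait and logs via SLF4J instead of printing.

```scala
object ExclusiveParamsSketch {
  // Same flow as the PR's checkExclusiveParams: partition the given names into
  // known and unknown params, warn about the unknown ones, then fail if more
  // than one of the known params is actually set.
  def checkExclusiveParams(
      hasParam: String => Boolean,
      isSet: String => Boolean,
      params: String*): Unit = {
    val (existing, nonExisting) = params.partition(hasParam)
    if (nonExisting.nonEmpty) {
      // The real code logs this at WARN level instead of printing.
      println(s"Ignored ${nonExisting.mkString("`", "`, `", "`")} while checking exclusive params.")
    }
    if (existing.count(isSet) > 1) {
      val paramString = existing.mkString("`", "`, `", "`")
      throw new IllegalArgumentException(
        s"$paramString are exclusive, but more than one among them are set.")
    }
  }
}
```

With both `inputCol` and `inputCols` set the call throws; with only one of them set it returns normally, and unknown names are merely warned about and excluded from the check.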
@@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(splits)

-    assert(bucketizer1.isBucketizeMultipleColumns())
-
     bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
     BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
       Seq("result1", "result2"),
@@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setOutputCols(Array("result"))
       .setSplitsArray(Array(splits(0)))

-    assert(bucketizer2.isBucketizeMultipleColumns())
-
     withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
       intercept[SparkException] {
         bucketizer2.transform(badDF1).collect()
@@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(splits)

-    assert(bucketizer.isBucketizeMultipleColumns())
-
     BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
       Seq("result1", "result2"),
       Seq("expected1", "expected2"))
@@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(splits)

-    assert(bucketizer.isBucketizeMultipleColumns())
-
     bucketizer.setHandleInvalid("keep")
     BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
       Seq("result1", "result2"),
@@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setInputCols(Array("myInputCol"))
       .setOutputCols(Array("myOutputCol"))
       .setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
-    assert(t.isBucketizeMultipleColumns())
     testDefaultReadWrite(t)
   }
@@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))

-    assert(bucket.isBucketizeMultipleColumns())
-
     val pl = new Pipeline()
       .setStages(Array(bucket))
       .fit(df)
@@ -401,15 +390,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
   }

-  test("Both inputCol and inputCols are set") {
-    val bucket = new Bucketizer()
-      .setInputCol("feature1")
-      .setOutputCol("result")
-      .setSplits(Array(-0.5, 0.0, 0.5))
-      .setInputCols(Array("feature1", "feature2"))
-
-    // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
-    assert(bucket.isBucketizeMultipleColumns() == false)
+  test("assert exception is thrown if both multi-column and single-column params are set") {
+    val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("inputCols", Array("feature1", "feature2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "result1"),
+      ("outputCols", Array("result1", "result2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("splits", Array(-0.5, 0.0, 0.5)),
+      ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))
   }
 }

Review comment (on the new test): We should also test the other exclusive params (input cols and splits params) as per https://github.com/apache/spark/pull/19993/files#r159133936

Review comment (on the splits/splitsArray case): Only comment I have is that I believe this line is not testing what you may think. As I read the checkSingleVsMultiColumnParams method, in this test case it will throw the error, not because both [...] Actually it applies to the line above too.

Reply: @MLnick actually it will fail for both reasons. We can add more test cases to check each of these two cases if you think it is needed.
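As a rough illustration of what the rewritten test asserts, the sketch below models the Bucketizer's single-/multi-column pairs with a plain `Set` of set-param names. The `exclusivePairs` table and `validate` helper are invented for this example; the real check is performed by `ParamValidators` when the transformer runs.

```scala
object BucketizerExclusiveDemo {
  // The single-column / multi-column pairs exercised by the new test.
  val exclusivePairs: Seq[Seq[String]] = Seq(
    Seq("inputCol", "inputCols"),
    Seq("outputCol", "outputCols"),
    Seq("splits", "splitsArray"))

  // Throws if more than one param of any pair appears among the set params,
  // mirroring the error message introduced by this PR.
  def validate(setParams: Set[String]): Unit =
    exclusivePairs.foreach { pair =>
      if (pair.count(setParams.contains) > 1) {
        throw new IllegalArgumentException(
          pair.mkString("`", "`, `", "`") +
            " are exclusive, but more than one among them are set.")
      }
    }
}
```

Setting, say, both `inputCol` and `inputCols` trips the check, while a consistent all-multi-column configuration passes.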
@@ -20,8 +20,11 @@ package org.apache.spark.ml.param
 import java.io.{ByteArrayOutputStream, ObjectOutputStream}

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Transformer}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util.MyParams
+import org.apache.spark.sql.Dataset

 class ParamsSuite extends SparkFunSuite {

Review comment (on the shared-params import): I don't think these are used any longer?
@@ -430,4 +433,26 @@ object ParamsSuite extends SparkFunSuite {
     require(copyReturnType === obj.getClass,
       s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
   }
+
+  /**
+   * Checks that the class throws an exception in case multiple exclusive params are set.
+   * The params to be checked are passed as arguments with their value.
+   * The checks are performed only if all the passed params are defined for the given model.
+   */
+  def testExclusiveParams(model: Params, dataset: Dataset[_],
+      paramsAndValues: (String, Any)*): Unit = {
+    val params = paramsAndValues.map(_._1)
+    if (params.forall(model.hasParam)) {
+      paramsAndValues.foreach { case (paramName, paramValue) =>
+        model.set(model.getParam(paramName), paramValue)
+      }
+      val e = intercept[IllegalArgumentException] {
+        model match {
+          case t: Transformer => t.transform(dataset)
+          case e: Estimator[_] => e.fit(dataset)
+        }
+      }
+      assert(e.getMessage.contains("are exclusive, but more than one"))
+    }
+  }
 }
Review comment (on testExclusiveParams): The problem with trying to use a general method like this is that it's hard to capture model-specific requirements. This currently misses checking to make sure that exactly one (not just <= 1) of each pair is available, plus that all of the single-column OR all of the multi-column Params are available. (The same issue occurs in #20146.) It will also be hard to check these items and account for defaults. I'd argue that it's not worth trying to use generic checking functions here.

Reply: My initial implementation (with @hhbyyh's comments) was more generic and checked what you said. Afterwards, @MLnick and @viirya asked to switch to a more generic approach, which is the current one you see. I'm fine with either of those, but I think we need to choose one way and go in that direction, otherwise we just lose time.

Reply: I see. I'll see if I can come up with something which is generic but handles these other checks.
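To see how the `testExclusiveParams` helper behaves end to end without a Spark session, here is a toy re-enactment. `ToyModel` and its `fit` method are invented for this sketch; only the error-message contract ("are exclusive, but more than one") mirrors the PR.

```scala
// Toy stand-in for a Params-bearing model: it records which params are set and
// runs an exclusivity check when "fit".
class ToyModel(validParams: Set[String]) {
  private var setParams = Set.empty[String]
  def hasParam(name: String): Boolean = validParams.contains(name)
  def set(name: String, value: Any): this.type = { setParams += name; this }
  def fit(): Unit =
    if (setParams.size > 1) {
      throw new IllegalArgumentException(
        setParams.toSeq.sorted.mkString("`", "`, `", "`") +
          " are exclusive, but more than one among them are set.")
    }
}

object TestExclusiveParamsSketch {
  // Mirrors the helper's flow: only if every param exists on the model, set
  // them all and assert that running the model fails with the expected message.
  def testExclusiveParams(model: ToyModel, paramsAndValues: (String, Any)*): Unit = {
    val params = paramsAndValues.map(_._1)
    if (params.forall(model.hasParam)) {
      paramsAndValues.foreach { case (name, value) => model.set(name, value) }
      val message =
        try { model.fit(); "" }
        catch { case e: IllegalArgumentException => e.getMessage }
      assert(message.contains("are exclusive, but more than one"))
    }
  }
}
```

Note that a param name the model does not have silently skips the whole check, matching the helper's documented "only if all the passed params are defined" behavior that the review comment above questions.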