[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set #19993
```diff
@@ -34,9 +34,9 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 /**
  * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
  * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
- * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
- * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
- * only used for single column usage, and `splitsArray` is for multiple columns.
+ * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
+ * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
+ * columns.
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
```
```diff
@@ -137,18 +137,17 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   /**
    * Determines whether this `Bucketizer` is going to map multiple columns. If and only if
    * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
-   * by `inputCol`. A warning will be printed if both are set.
+   * by `inputCol`. An exception will be thrown if both are set.
    */
   private[feature] def isBucketizeMultipleColumns(): Boolean = {
-    if (isSet(inputCols) && isSet(inputCol)) {
-      logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
-        "`Bucketizer` only map one column specified by `inputCol`")
-      false
-    } else if (isSet(inputCols)) {
-      true
-    } else {
-      false
-    }
+    ParamValidators.assertColOrCols(this)
+    if (isSet(inputCol) && isSet(splitsArray)) {
+      ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray")
+    }
+    if (isSet(inputCols) && isSet(splits)) {
+      ParamValidators.raiseIncompatibleParamsException("inputCols", "splits")
+    }
+    isSet(inputCols)
   }

   @Since("2.0.0")
```

> **Review comment** (on `isSet(inputCols)`): Seems superfluous to now have a separate method for this.
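The validation flow introduced in this hunk can be exercised outside Spark with a minimal stand-in for the `Params` machinery. The `ToyBucketizer` name and its `Option`-based fields below are hypothetical, for illustration only; only the check ordering mirrors the diff:

```scala
// Minimal sketch of the exclusive-parameter checks from the diff above.
// ToyBucketizer is a hypothetical stand-in for Spark's Params machinery:
// a param is "set" when its Option is defined.
object ExclusiveParamsSketch {
  final case class ToyBucketizer(
      inputCol: Option[String] = None,
      inputCols: Option[Seq[String]] = None,
      splits: Option[Array[Double]] = None,
      splitsArray: Option[Array[Array[Double]]] = None) {

    private def raiseIncompatible(p1: String, p2: String): Nothing =
      throw new IllegalArgumentException(s"Both `$p1` and `$p2` are set.")

    // Mirrors the rewritten isBucketizeMultipleColumns(): throw if
    // single- and multi-column params are mixed, else report the mode.
    def isBucketizeMultipleColumns: Boolean = {
      if (inputCol.isDefined && inputCols.isDefined) raiseIncompatible("inputCol", "inputCols")
      if (inputCol.isDefined && splitsArray.isDefined) raiseIncompatible("inputCol", "splitsArray")
      if (inputCols.isDefined && splits.isDefined) raiseIncompatible("inputCols", "splits")
      inputCols.isDefined
    }
  }

  def main(args: Array[String]): Unit = {
    // Single-column usage: valid, reports false.
    assert(!ToyBucketizer(inputCol = Some("f")).isBucketizeMultipleColumns)
    // Multi-column usage: valid, reports true.
    assert(ToyBucketizer(inputCols = Some(Seq("a", "b"))).isBucketizeMultipleColumns)
    // Mixing single- and multi-column params now fails instead of warning.
    val mixed = ToyBucketizer(inputCol = Some("f"), inputCols = Some(Seq("a")))
    assert(scala.util.Try(mixed.isBucketizeMultipleColumns).isFailure)
  }
}
```

The key behavioral change is visible in the last assertion: under the old code the mixed configuration only logged a warning and silently preferred `inputCol`.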
```diff
@@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector}
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.Identifiable
```

```diff
@@ -249,6 +250,29 @@ object ParamValidators {
   def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
     value.length > lowerBound
   }

+  /**
+   * Checks that either `inputCols` and `outputCols` are set, or `inputCol` and `outputCol` are
+   * set. If this is not true, an `IllegalArgumentException` is raised.
+   * @param model
+   */
+  def assertColOrCols(model: Params): Unit = {
+    model match {
+      case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
+        raiseIncompatibleParamsException("inputCols", "inputCol")
+      case m: HasOutputCols with HasInputCol if m.isSet(m.outputCols) && m.isSet(m.inputCol) =>
+        raiseIncompatibleParamsException("outputCols", "inputCol")
+      case m: HasInputCols with HasOutputCol if m.isSet(m.inputCols) && m.isSet(m.outputCol) =>
+        raiseIncompatibleParamsException("inputCols", "outputCol")
+      case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) =>
+        raiseIncompatibleParamsException("outputCols", "outputCol")
+      case _ =>
+    }
+  }
+
+  def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
+    throw new IllegalArgumentException(s"Both `$paramName1` and `$paramName2` are set.")
+  }
 }

 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
```

> **Review comment** (on `assertColOrCols`): should be `private[spark]`.

> **Review comment** (on the `HasOutputCols with HasInputCol` case): This may not necessarily be an error for some classes, but we can keep it for now.

> **Review comment** (on the `HasInputCols with HasOutputCol` case): Sorry to miss it, but I just found that `FeatureHasher` has both `inputCols` and `outputCol`.

> **Review comment:** If we need to check other exclusive params, e.g.:
>
> ```scala
> def checkExclusiveParams(model: Params, params: String*): Unit = {
>   if (params.filter(model.isSet(_)).size > 1) {
>     val paramString = params.mkString("`", "`, `", "`")
>     throw new IllegalArgumentException(
>       s"$paramString are exclusive, but more than one among them are set.")
>   }
> }
> ```
>
> which would be used as:
>
> ```scala
> ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
> ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
> ParamValidators.checkExclusiveParams(this, "inputCol", "splitsArray")
> ParamValidators.checkExclusiveParams(this, "inputCols", "splits")
> ```

> **Review comment:** I added this method too in #20146.

> **Review comment:** I think we can use that method once merged, thanks.

> **Review comment:** I am not sure if #20146 will get merged for 2.3.

> **Review comment:** Based on #20146 (comment) from @WeichenXu123, I think #20146 cannot get merged for 2.3.

> **Review comment:** If this method looks good to you, maybe you can just copy it from #20146 to use here.

> **Review comment:** @MLnick @viirya, in order to address https://github.com/apache/spark/pull/19993/files#r161682506, I was thinking to leave this method as it is (just renaming it per @viirya's suggestion) and only adding an …

> **Review comment:** I think @viirya's method is simpler and more general, so why not use it?

> **Review comment** (on `raiseIncompatibleParamsException`): should be `private[spark]`.

> **Review comment** (on the error message): The error message could be more straightforward, e.g. …
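The `checkExclusiveParams` helper suggested in the thread can be tried standalone. In this sketch, `ToyParams` is a hypothetical stand-in for Spark's `Params` trait that only tracks which parameter names have been set; the helper body is taken from the reviewer's suggestion:

```scala
// Standalone sketch of the reviewer's checkExclusiveParams suggestion.
// ToyParams is a hypothetical stand-in: it only records which param
// names are set, which is all the helper needs.
object CheckExclusiveSketch {
  final case class ToyParams(setNames: Set[String]) {
    def isSet(name: String): Boolean = setNames.contains(name)
  }

  // Fails if more than one of the given params is set on the model.
  def checkExclusiveParams(model: ToyParams, params: String*): Unit = {
    if (params.count(model.isSet) > 1) {
      val paramString = params.mkString("`", "`, `", "`")
      throw new IllegalArgumentException(
        s"$paramString are exclusive, but more than one among them are set.")
    }
  }

  def main(args: Array[String]): Unit = {
    val ok = ToyParams(Set("inputCol", "splits"))
    checkExclusiveParams(ok, "inputCol", "inputCols") // passes: only one of the pair is set
    val bad = ToyParams(Set("inputCol", "inputCols"))
    assert(scala.util.Try(checkExclusiveParams(bad, "inputCol", "inputCols")).isFailure)
  }
}
```

Compared with `assertColOrCols`, this version needs no type tests on the `HasInputCol`/`HasOutputCols` traits, and each caller lists exactly the pairs that are exclusive for it, which sidesteps the `FeatureHasher` problem noted above.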
> **Review comment:** I noticed `isBucketizeMultipleColumns` is invoked in many places, and maybe we can put the checks in other places like `transformSchema`. It also makes the code consistent with the function name.
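That last suggestion, validating once in `transformSchema` so that `isBucketizeMultipleColumns` stays a side-effect-free query, might look like the following toy sketch. All `Toy*` names are hypothetical stand-ins for Spark's `Params`/`StructType` machinery, not the actual API:

```scala
// Sketch of the reviewer's idea: run the exclusive-param checks once,
// at schema-transformation time, instead of inside the query method.
object ValidateInTransformSchemaSketch {
  final case class ToySchema(fields: Vector[String])

  final case class ToyBucketizer(setParams: Set[String]) {
    private def checkExclusiveParams(params: String*): Unit =
      if (params.count(setParams.contains) > 1)
        throw new IllegalArgumentException(
          params.mkString("`", "`, `", "`") +
            " are exclusive, but more than one among them are set.")

    // Pure query: no logging, no throwing, consistent with its name.
    def isBucketizeMultipleColumns: Boolean = setParams.contains("inputCols")

    // Validation happens once here; downstream callers of the query
    // method can then trust the configuration.
    def transformSchema(schema: ToySchema): ToySchema = {
      checkExclusiveParams("inputCol", "inputCols")
      checkExclusiveParams("inputCol", "splitsArray")
      checkExclusiveParams("inputCols", "splits")
      val out = if (isBucketizeMultipleColumns) "bucketedCols" else "bucketedCol"
      ToySchema(schema.fields :+ out)
    }
  }

  def main(args: Array[String]): Unit = {
    val ok = ToyBucketizer(Set("inputCols", "splitsArray"))
    assert(ok.transformSchema(ToySchema(Vector("a"))).fields.last == "bucketedCols")
    val bad = ToyBucketizer(Set("inputCol", "inputCols"))
    assert(scala.util.Try(bad.transformSchema(ToySchema(Vector("a")))).isFailure)
  }
}
```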