address comments
mgaido91 committed Dec 19, 2017
1 parent 8f3581c commit bb0c0d2
Showing 3 changed files with 39 additions and 11 deletions.
22 changes: 11 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -34,9 +34,9 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 /**
  * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
  * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
- * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
- * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
- * only used for single column usage, and `splitsArray` is for multiple columns.
+ * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
+ * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
+ * columns.
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@@ -140,15 +140,15 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
    * by `inputCol`. A warning will be printed if both are set.
    */
   private[feature] def isBucketizeMultipleColumns(): Boolean = {
-    if (isSet(inputCols) && isSet(inputCol) || isSet(inputCols) && isSet(outputCol) ||
-      isSet(inputCol) && isSet(outputCols)) {
-      throw new IllegalArgumentException("Both `inputCol` and `inputCols` are set, `Bucketizer` " +
-        "only supports setting either `inputCol` or `inputCols`.")
-    } else if (isSet(inputCols)) {
-      true
-    } else {
-      false
+    inputColsSanityCheck()
+    outputColsSanityCheck()
+    if (isSet(inputCol) && isSet(splitsArray)) {
+      raiseIncompatibleParamsException("inputCol", "splitsArray")
+    }
+    if (isSet(inputCols) && isSet(splits)) {
+      raiseIncompatibleParamsException("inputCols", "splits")
     }
+    isSet(inputCols)
   }

   @Since("2.0.0")
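As a quick illustration (a sketch, not part of this diff) of how the new Bucketizer behaviour surfaces when both the single- and multi-column params are set: the column names, the splits, and the DataFrame `df` below are assumed purely for the example.

import org.apache.spark.ml.feature.Bucketizer

// Both the single-column param (`inputCol`) and the multi-column param (`inputCols`)
// are set on the same instance.
val bucketizer = new Bucketizer()
  .setInputCol("f1")
  .setInputCols(Array("f1", "f2"))
  .setOutputCol("f1_bucketed")
  .setSplits(Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity))

// Before this commit: only a log warning, with `inputCol` taking effect.
// After this commit: isBucketizeMultipleColumns() rejects the combination, so an
// IllegalArgumentException is thrown once the params are validated, e.g. when
// calling bucketizer.transform(df) on an assumed DataFrame `df`.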
8 changes: 8 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -834,6 +834,14 @@ trait Params extends Identifiable with Serializable {
     }
     to
   }
+
+  final def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
+    throw new IllegalArgumentException(
+      s"""
+         |Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports
+         |setting either `$paramName1` or `$paramName2`.
+       """.stripMargin)
+  }
 }

 /**
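For context, a sketch (not part of this diff) of calling the new helper from an arbitrary Params implementation; `MyParams`, its two params, and `validateParamCompatibility` are hypothetical and exist only to show the helper in use.

import org.apache.spark.ml.param.{Param, ParamMap, Params, StringArrayParam}
import org.apache.spark.ml.util.Identifiable

// Hypothetical Params holder mixing a single-column and a multi-column param.
class MyParams(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("myParams"))

  val inputCol = new Param[String](this, "inputCol", "single input column")
  val inputCols = new StringArrayParam(this, "inputCols", "multiple input columns")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  // Rejects the incompatible combination via the helper added in this commit; the
  // resulting IllegalArgumentException message names both params and this class.
  def validateParamCompatibility(): Unit = {
    if (isSet(inputCol) && isSet(inputCols)) {
      raiseIncompatibleParamsException("inputCol", "inputCols")
    }
  }

  override def copy(extra: ParamMap): MyParams = defaultCopy(extra)
}

// new MyParams().setInputCol("a").setInputCols(Array("a", "b")).validateParamCompatibility()
// throws IllegalArgumentException: both `inputCol` and `inputCols` are set.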
20 changes: 20 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -236,6 +236,16 @@ trait HasInputCols extends Params {

   /** @group getParam */
   final def getInputCols: Array[String] = $(inputCols)
+
+  final def inputColsSanityCheck(): Unit = {
+    this match {
+      case model: HasInputCol if isSet(inputCols) && isSet(model.inputCol) =>
+        raiseIncompatibleParamsException("inputCols", "inputCol")
+      case model: HasOutputCol if isSet(inputCols) && isSet(model.outputCol) =>
+        raiseIncompatibleParamsException("inputCols", "outputCol")
+      case _ =>
+    }
+  }
 }

 /**
@@ -272,6 +282,16 @@ trait HasOutputCols extends Params {

   /** @group getParam */
   final def getOutputCols: Array[String] = $(outputCols)
+
+  final def outputColsSanityCheck(): Unit = {
+    this match {
+      case model: HasInputCol if isSet(outputCols) && isSet(model.inputCol) =>
+        raiseIncompatibleParamsException("outputCols", "inputCol")
+      case model: HasOutputCol if isSet(outputCols) && isSet(model.outputCol) =>
+        raiseIncompatibleParamsException("outputCols", "outputCol")
+      case _ =>
+    }
+  }
 }

 /**
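Finally, a sketch (again not part of this diff) of how these trait-level checks surface on `Bucketizer`, which carries both the single-column params (`inputCol`, `outputCol`) and the multi-column ones (`inputCols`, `outputCols`); the column names and splits below are invented.

import org.apache.spark.ml.feature.Bucketizer

// `inputCols` (multi-column) is combined with `outputCol` (single-column).
val b = new Bucketizer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCol("bucketed")
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 10.0, Double.PositiveInfinity)))

// inputColsSanityCheck() matches the `HasOutputCol` case and calls
// raiseIncompatibleParamsException("inputCols", "outputCol"), so validating the
// params (e.g. via b.transform on some DataFrame) throws IllegalArgumentException.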
