Skip to content

Commit

Permalink
use ParamValidators
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Dec 19, 2017
1 parent f593f5b commit 2ecdc73
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
* by `inputCol`. An exception will be thrown if both are set.
*/
private[feature] def isBucketizeMultipleColumns(): Boolean = {
inputColsSanityCheck()
outputColsSanityCheck()
ParamValidators.assertColOrCols(this)
if (isSet(inputCol) && isSet(splitsArray)) {
raiseIncompatibleParamsException("inputCol", "splitsArray")
ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray")
}
if (isSet(inputCols) && isSet(splits)) {
raiseIncompatibleParamsException("inputCols", "splits")
ParamValidators.raiseIncompatibleParamsException("inputCols", "splits")
}
isSet(inputCols)
}
Expand Down
32 changes: 24 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkException
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable

/**
Expand Down Expand Up @@ -249,6 +250,29 @@ object ParamValidators {
def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
value.length > lowerBound
}

/**
* Checks that either inputCols and outputCols are set or inputCol and outputCol are set. If
* this is not true, an `IllegalArgumentException` is raised.
* @param model
*/
def assertColOrCols(model: Params): Unit = {
model match {
case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
raiseIncompatibleParamsException("inputCols", "inputCol")
case m: HasOutputCols with HasInputCol if m.isSet(m.outputCols) && m.isSet(m.inputCol) =>
raiseIncompatibleParamsException("outputCols", "inputCol")
case m: HasInputCols with HasOutputCol if m.isSet(m.inputCols) && m.isSet(m.outputCol) =>
raiseIncompatibleParamsException("inputCols", "outputCol")
case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) =>
raiseIncompatibleParamsException("outputCols", "outputCol")
case _ =>
}
}

def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
throw new IllegalArgumentException(s"Both `$paramName1` and `$paramName2` are set.")
}
}

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
Expand Down Expand Up @@ -834,14 +858,6 @@ trait Params extends Identifiable with Serializable {
}
to
}

protected def raiseIncompatibleParamsException(paramName1: String, paramName2: String): Unit = {
throw new IllegalArgumentException(
s"""
|Both `$paramName1` and `$paramName2` are set, `${this.getClass.getName}` only supports
|setting either `$paramName1` or `$paramName2`.
""".stripMargin)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,6 @@ trait HasInputCols extends Params {

/** @group getParam */
final def getInputCols: Array[String] = $(inputCols)

final def inputColsSanityCheck(): Unit = {
this match {
case model: HasInputCol if isSet(inputCols) && isSet(model.inputCol) =>
raiseIncompatibleParamsException("inputCols", "inputCol")
case model: HasOutputCol if isSet(inputCols) && isSet(model.outputCol) =>
raiseIncompatibleParamsException("inputCols", "outputCol")
case _ =>
}
}
}

/**
Expand Down Expand Up @@ -282,16 +272,6 @@ trait HasOutputCols extends Params {

/** @group getParam */
final def getOutputCols: Array[String] = $(outputCols)

final def outputColsSanityCheck(): Unit = {
this match {
case model: HasInputCol if isSet(outputCols) && isSet(model.inputCol) =>
raiseIncompatibleParamsException("outputCols", "inputCol")
case model: HasOutputCol if isSet(outputCols) && isSet(model.outputCol) =>
raiseIncompatibleParamsException("outputCols", "outputCol")
case _ =>
}
}
}

/**
Expand Down

0 comments on commit 2ecdc73

Please sign in to comment.