address comments
mgaido91 committed Jan 17, 2018
1 parent a0c0fed commit 25b9bd4
Showing 4 changed files with 37 additions and 54 deletions.
@@ -184,16 +184,13 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    ParamValidators.checkMultiColumnParams(this)
-    if (isSet(inputCol) && isSet(splitsArray)) {
-      ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray")
-    }
-    if (isSet(inputCols) && isSet(splits)) {
-      ParamValidators.raiseIncompatibleParamsException("inputCols", "splits")
-    }
+    ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
+    ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
+    ParamValidators.checkExclusiveParams(this, "splits", "splitsArray")
+
     if (isSet(inputCols)) {
       var transformedSchema = schema
-      $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
+      $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
         SchemaUtils.checkNumericType(transformedSchema, inputCol)
         transformedSchema = SchemaUtils.appendColumn(transformedSchema,
           prepOutputField($(splitsArray)(idx), outputCol))
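The Bucketizer change above swaps the hand-written pairwise checks for three calls to the new ParamValidators.checkExclusiveParams helper. A minimal usage sketch of the resulting behaviour (illustrative only, not part of the commit; it assumes a DataFrame `df` with numeric columns "feature1" and "feature2" and the multi-column setters introduced by this PR):

import scala.util.Try
import org.apache.spark.ml.feature.Bucketizer

// `df` stands for any DataFrame with numeric columns "feature1" and "feature2".
val bucketizer = new Bucketizer()
  .setInputCol("feature1")                      // single-column param
  .setInputCols(Array("feature1", "feature2"))  // multi-column counterpart, exclusive with inputCol

// transform invokes transformSchema first, so the conflict is reported before any data is touched.
val result = Try(bucketizer.transform(df))
assert(result.isFailure && result.failed.get.isInstanceOf[IllegalArgumentException])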
32 changes: 16 additions & 16 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -27,11 +27,11 @@ import scala.collection.mutable
 
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
+import org.slf4j.LoggerFactory
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector}
-import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.Identifiable
 
/**
@@ -167,6 +167,8 @@ private[ml] object Param {
 @DeveloperApi
 object ParamValidators {
 
+  private val LOGGER = LoggerFactory.getLogger(ParamValidators.getClass)
+
   /** (private[param]) Default validation always return true */
   private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
 
@@ -252,24 +254,22 @@
   }
 
   /**
-   * Checks that either inputCols and outputCols are set or inputCol and outputCol are set. If
-   * this is not true, an `IllegalArgumentException` is raised.
-   * @param model
+   * Checks that at most one of the params passed as arguments is set. If this is not true, an
+   * `IllegalArgumentException` is raised.
    */
-  private[spark] def checkMultiColumnParams(model: Params): Unit = {
-    model match {
-      case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
-        raiseIncompatibleParamsException("inputCols", "inputCol")
-      case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) =>
-        raiseIncompatibleParamsException("outputCols", "outputCol")
-      case _ =>
+  def checkExclusiveParams(model: Params, params: String*): Unit = {
+    val (existingParams, nonExistingParams) = params.partition(model.hasParam)
+    if (nonExistingParams.nonEmpty) {
+      val pronoun = if (nonExistingParams.size == 1) "It doesn't" else "They don't"
+      LOGGER.warn(s"Ignored ${nonExistingParams.mkString("`", "`, `", "`")} while checking " +
+        s"exclusive params. $pronoun exist for the specified model.")
     }
-  }
 
-  private[spark] def raiseIncompatibleParamsException(
-      paramName1: String,
-      paramName2: String): Unit = {
-    throw new IllegalArgumentException(s"`$paramName1` and `$paramName2` cannot both be set.")
+    if (existingParams.count(paramName => model.isSet(model.getParam(paramName))) > 1) {
+      val paramString = existingParams.mkString("`", "`, `", "`")
+      throw new IllegalArgumentException(s"$paramString are exclusive, " +
+        "but more than one of them is set.")
+    }
   }
 }
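For reference, the exclusivity rule that checkExclusiveParams encodes can be sketched without any Spark machinery; the names below are illustrative only, not Spark API:

def checkExclusive(isSet: String => Boolean, exists: String => Boolean, params: String*): Unit = {
  // Params unknown to the model are only warned about, never enforced.
  val (existing, missing) = params.partition(exists)
  if (missing.nonEmpty) {
    println(s"Ignored ${missing.mkString("`", "`, `", "`")}: not defined for this model")
  }
  // The rule is "at most one of the existing params may be set".
  if (existing.count(isSet) > 1) {
    throw new IllegalArgumentException(
      existing.mkString("`", "`, `", "`") + " are exclusive, but more than one of them is set.")
  }
}

// Only `inputCols` is set, so this passes; setting `inputCol` as well would throw.
checkExclusive(Set("inputCols"), Set("inputCol", "inputCols", "splits"), "inputCol", "inputCols")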

@@ -392,7 +392,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
 
   test("assert exception is thrown if both multi-column and single-column params are set") {
     val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
-    ParamsSuite.testMultiColumnParams(classOf[Bucketizer], df)
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("inputCols", Array("feature1", "feature2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "result1"),
+      ("outputCols", Array("result1", "result2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("splits", Array(-0.5, 0.0, 0.5)),
+      ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))
   }
 }

39 changes: 10 additions & 29 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -435,43 +435,24 @@ object ParamsSuite extends SparkFunSuite {
   }
 
   /**
-   * Checks that the class throws an exception in case both `inputCols` and `inputCol` are set and
-   * in case both `outputCols` and `outputCol` are set.
-   * These checks are performed only when the class extends respectively both `HasInputCols` and
-   * `HasInputCol` and both `HasOutputCols` and `HasOutputCol`.
-   *
-   * @param paramsClass The Class to be checked
-   * @param dataset A `Dataset` to use in the tests
+   * Checks that the class throws an exception when multiple mutually exclusive params are set.
+   * The params to be checked are passed as arguments together with their values.
+   * The checks are performed only if all the passed params are defined for the given model.
    */
-  def testMultiColumnParams(paramsClass: Class[_ <: Params], dataset: Dataset[_]): Unit = {
-    val cols = dataset.columns
-
-    if (paramsClass.isAssignableFrom(classOf[HasInputCols])
-        && paramsClass.isAssignableFrom(classOf[HasInputCol])) {
-      val model = paramsClass.newInstance()
-      model.set(model.asInstanceOf[HasInputCols].inputCols, cols)
-      model.set(model.asInstanceOf[HasInputCol].inputCol, cols(0))
-      val e = intercept[IllegalArgumentException] {
-        model match {
-          case t: Transformer => t.transform(dataset)
-          case e: Estimator[_] => e.fit(dataset)
-        }
+  def testExclusiveParams(model: Params, dataset: Dataset[_],
+      paramsAndValues: (String, Any)*): Unit = {
+    val params = paramsAndValues.map(_._1)
+    if (params.forall(model.hasParam)) {
+      paramsAndValues.foreach { case (paramName, paramValue) =>
+        model.set(model.getParam(paramName), paramValue)
       }
-      assert(e.getMessage.contains("cannot be both set"))
-    }
-
-    if (paramsClass.isAssignableFrom(classOf[HasOutputCols])
-        && paramsClass.isAssignableFrom(classOf[HasOutputCol])) {
-      val model = paramsClass.newInstance()
-      model.set(model.asInstanceOf[HasOutputCols].outputCols, cols)
-      model.set(model.asInstanceOf[HasOutputCol].outputCol, cols(0))
       val e = intercept[IllegalArgumentException] {
         model match {
          case t: Transformer => t.transform(dataset)
          case e: Estimator[_] => e.fit(dataset)
         }
       }
-      assert(e.getMessage.contains("cannot be both set"))
+      assert(e.getMessage.contains("are exclusive, but more than one"))
     }
   }
 }
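A usage sketch for the reworked helper (illustrative, not part of the commit; it assumes a ScalaTest suite with MLlibTestSparkContext and `import spark.implicits._` in scope, as in BucketizerSuite). Because testExclusiveParams only runs its assertion when every listed param exists on the model, an unknown param name turns the call into a no-op:

val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")

// Both params exist on Bucketizer and both are set, so the helper expects an
// IllegalArgumentException from transform.
ParamsSuite.testExclusiveParams(new Bucketizer, df,
  ("inputCol", "feature1"), ("inputCols", Array("feature1", "feature2")))

// "nonExistingParam" is not a Bucketizer param, so params.forall(model.hasParam) is false
// and no check is performed.
ParamsSuite.testExclusiveParams(new Bucketizer, df,
  ("inputCol", "feature1"), ("nonExistingParam", "whatever"))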
