Skip to content

Commit

Permalink
add checkMultiColumnParams
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Dec 20, 2017
1 parent 64634b5 commit 9872bfd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,33 +401,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}

test("Both inputCol and inputCols are set") {
val feature1 = Array(-0.5, -0.3, 0.0, 0.2)
val feature2 = Array(-0.3, -0.2, 0.5, 0.0)
val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2")

val invalid1 = new Bucketizer()
.setInputCol("feature1")
.setOutputCol("result")
.setSplits(Array(-0.5, 0.0, 0.5))
.setInputCols(Array("feature1", "feature2"))

val invalid2 = new Bucketizer()
.setOutputCol("result")
.setSplits(Array(-0.5, 0.0, 0.5))
.setInputCols(Array("feature1", "feature2"))

val invalid3 = new Bucketizer()
.setInputCol("feature1")
.setSplits(Array(-0.5, 0.0, 0.5))
.setOutputCols(Array("result1", "result2"))

Seq(invalid1, invalid2, invalid3).foreach { bucketizer =>
// When both inputCol/outputCol and inputCols/outputCols are set, we throw Exception.
intercept[IllegalArgumentException] {
bucketizer.transform(df)
}
}
// Delegates to the shared ParamsSuite helper, passing the Bucketizer class and the suite's
// SparkSession, to check the single-column vs multi-column param conflict.
test("assert exception is thrown if both multi-column and single-column params are set") {
  ParamsSuite.checkMultiColumnParams(classOf[Bucketizer], spark)
}
}

Expand Down
48 changes: 48 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ package org.apache.spark.ml.param
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util.MyParams
import org.apache.spark.sql.{Dataset, SparkSession}

class ParamsSuite extends SparkFunSuite {

Expand Down Expand Up @@ -430,4 +433,49 @@ object ParamsSuite extends SparkFunSuite {
require(copyReturnType === obj.getClass,
s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
}

/**
 * Checks that the class throws an exception in case both `inputCols` and `inputCol` are set and
 * in case both `outputCols` and `outputCol` are set.
 * The input check is performed only if the class extends both `HasInputCols` and `HasInputCol`;
 * the output check is performed only if it extends both `HasOutputCols` and `HasOutputCol`.
 *
 * A fresh instance is created through the no-arg constructor for each check. A `Transformer`
 * is exercised through `transform` and an `Estimator` through `fit`; in both cases an
 * `IllegalArgumentException` whose message contains "cannot be both set" is expected.
 *
 * @param paramsClass The Class to be checked
 * @param spark A `SparkSession` instance to use
 */
def checkMultiColumnParams(paramsClass: Class[_ <: Params], spark: SparkSession): Unit = {
  import spark.implicits._
  // create fake input Dataset
  val feature1 = Array(-1.0, 0.0, 1.0)
  val feature2 = Array(1.0, 0.0, -1.0)
  val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2")

  // Runs the stage on the fake Dataset and asserts it is rejected with the expected message.
  def assertConflictIsRejected(model: Params): Unit = {
    // Resolve what to run before entering `intercept`, so that an unsupported class is
    // reported as a plain test failure rather than as an unexpected exception type.
    val run: () => Any = model match {
      case transformer: Transformer => () => transformer.transform(df)
      case estimator: Estimator[_] => () => estimator.fit(df)
      case other =>
        fail(s"${other.getClass.getName} is neither a Transformer nor an Estimator.")
    }
    val thrown = intercept[IllegalArgumentException] {
      run()
    }
    assert(thrown.getMessage.contains("cannot be both set"))
  }

  // `X.isAssignableFrom(paramsClass)` is true when `paramsClass` is a subtype of `X`; the
  // receiver must be the trait, otherwise the check is never true for a concrete class and
  // the assertions below are silently skipped.
  if (classOf[HasInputCols].isAssignableFrom(paramsClass)
    && classOf[HasInputCol].isAssignableFrom(paramsClass)) {
    val model = paramsClass.newInstance()
    model.set(model.asInstanceOf[HasInputCols].inputCols, Array("feature1", "feature2"))
    model.set(model.asInstanceOf[HasInputCol].inputCol, "feature1")
    assertConflictIsRejected(model)
  }

  if (classOf[HasOutputCols].isAssignableFrom(paramsClass)
    && classOf[HasOutputCol].isAssignableFrom(paramsClass)) {
    // NOTE(review): no input column is set here; this assumes the conflict check fires
    // before any input param is read - confirm against the classes being checked.
    val model = paramsClass.newInstance()
    model.set(model.asInstanceOf[HasOutputCols].outputCols, Array("result1", "result2"))
    model.set(model.asInstanceOf[HasOutputCol].outputCol, "result1")
    assertConflictIsRejected(model)
  }
}
}

0 comments on commit 9872bfd

Please sign in to comment.