Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Dec 21, 2017
1 parent 9872bfd commit d0b8d06
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
ParamValidators.assertColOrCols(this)
ParamValidators.checkMultiColumnParams(this)
if (isSet(inputCol) && isSet(splitsArray)) {
ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ object ParamValidators {
* this is not true, an `IllegalArgumentException` is raised.
* @param model the `Params` instance whose single-column and multi-column params are checked
*/
private[spark] def assertColOrCols(model: Params): Unit = {
private[spark] def checkMultiColumnParams(model: Params): Unit = {
model match {
case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
raiseIncompatibleParamsException("inputCols", "inputCol")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}

test("assert exception is thrown is both multi-column and single-column params are set") {
ParamsSuite.checkMultiColumnParams(classOf[Bucketizer], spark)
val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
ParamsSuite.testMultiColumnParams(classOf[Bucketizer], df)
}
}

Expand Down
28 changes: 12 additions & 16 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util.MyParams
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.Dataset

class ParamsSuite extends SparkFunSuite {

Expand Down Expand Up @@ -441,24 +441,20 @@ object ParamsSuite extends SparkFunSuite {
* `HasInputCol` and both `HasOutputCols` and `HasOutputCol`.
*
* @param paramsClass The Class to be checked
* @param spark A `SparkSession` instance to use
* @param dataset A `Dataset` to use in the tests
*/
def checkMultiColumnParams(paramsClass: Class[_ <: Params], spark: SparkSession): Unit = {
import spark.implicits._
// create fake input Dataset
val feature1 = Array(-1.0, 0.0, 1.0)
val feature2 = Array(1.0, 0.0, -1.0)
val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2")
def testMultiColumnParams(paramsClass: Class[_ <: Params], dataset: Dataset[_]): Unit = {
val cols = dataset.columns

if (paramsClass.isAssignableFrom(classOf[HasInputCols])
&& paramsClass.isAssignableFrom(classOf[HasInputCol])) {
val model = paramsClass.newInstance()
model.set(model.asInstanceOf[HasInputCols].inputCols, Array("feature1", "feature2"))
model.set(model.asInstanceOf[HasInputCol].inputCol, "features1")
model.set(model.asInstanceOf[HasInputCols].inputCols, cols)
model.set(model.asInstanceOf[HasInputCol].inputCol, cols(0))
val e = intercept[IllegalArgumentException] {
model match {
case t: Transformer => t.transform(df)
case e: Estimator[_] => e.fit(df)
case t: Transformer => t.transform(dataset)
case e: Estimator[_] => e.fit(dataset)
}
}
assert(e.getMessage.contains("cannot be both set"))
Expand All @@ -467,12 +463,12 @@ object ParamsSuite extends SparkFunSuite {
if (paramsClass.isAssignableFrom(classOf[HasOutputCols])
&& paramsClass.isAssignableFrom(classOf[HasOutputCol])) {
val model = paramsClass.newInstance()
model.set(model.asInstanceOf[HasOutputCols].outputCols, Array("result1", "result2"))
model.set(model.asInstanceOf[HasOutputCol].outputCol, "result1")
model.set(model.asInstanceOf[HasOutputCols].outputCols, cols)
model.set(model.asInstanceOf[HasOutputCol].outputCol, cols(0))
val e = intercept[IllegalArgumentException] {
model match {
case t: Transformer => t.transform(df)
case e: Estimator[_] => e.fit(df)
case t: Transformer => t.transform(dataset)
case e: Estimator[_] => e.fit(dataset)
}
}
assert(e.getMessage.contains("cannot be both set"))
Expand Down

0 comments on commit d0b8d06

Please sign in to comment.