Skip to content

Commit

Permalink
[SPARK-15957][ML] RFormula supports forcing to index label
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
```RFormula``` will index label only when it is string type currently. If the label is numeric type and we use ```RFormula``` to present a classification model, there is no label attributes in label column metadata. The label attributes are useful when making prediction for classification, so we can force to index label by ```StringIndexer``` whether it is numeric or string type for classification. Then SparkR wrappers can extract label attributes from label column metadata successfully. This feature can help us to fix bug similar with [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153).
For regression, we will still to keep label as numeric type.
In this PR, we add a param ```indexLabel``` to control whether to force to index label for ```RFormula```.

## How was this patch tested?
Unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13675 from yanboliang/spark-15957.
  • Loading branch information
yanboliang committed Oct 11, 2016
1 parent b515768 commit 19401a2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
29 changes: 26 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
Expand Down Expand Up @@ -104,6 +104,27 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/**
* Force to index label whether it is numeric or string type.
* Usually we index label only when it is string type.
* If the formula was used by classification algorithms,
* we can force to index label even it is numeric type by setting this param with true.
* Default: false.
* @group param
*/
@Since("2.1.0")
val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel",
"Force to index label whether it is numeric or string")
setDefault(forceIndexLabel -> false)

/** @group getParam */
@Since("2.1.0")
def getForceIndexLabel: Boolean = $(forceIndexLabel)

/** @group setParam */
@Since("2.1.0")
def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)

/** Whether the formula specifies fitting an intercept. */
private[ml] def hasIntercept: Boolean = {
require(isDefined(formula), "Formula must be defined first.")
Expand Down Expand Up @@ -167,8 +188,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)

if (dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) {
if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) {
encoderStages += new StringIndexer()
.setInputCol(resolvedFormula.label)
.setOutputCol($(labelCol))
Expand All @@ -181,6 +202,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
require(!hasLabelCol(schema) || !$(forceIndexLabel),
"If label column already exists, forceIndexLabel can not be set with true.")
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}

test("label column already exists") {
test("label column already exists and forceIndexLabel was set with false") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y")
val model = formula.fit(original)
Expand All @@ -66,6 +66,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(resultSchema.toString == model.transform(original).schema.toString)
}

test("label column already exists but forceIndexLabel was set with true") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y").setForceIndexLabel(true)
val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
intercept[IllegalArgumentException] {
formula.fit(original)
}
}

test("label column already exists but is not numeric type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = Seq((0, true), (2, false)).toDF("x", "y")
Expand Down Expand Up @@ -137,6 +145,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(result.collect() === expected.collect())
}

test("force to index label even it is numeric type") {
val formula = new RFormula().setFormula("id ~ a + b").setForceIndexLabel(true)
val original = spark.createDataFrame(
Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val expected = spark.createDataFrame(
Seq(
(1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
(1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
(0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
(1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
).toDF("id", "a", "b", "features", "label")
assert(result.collect() === expected.collect())
}

test("attribute generation") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
Expand Down

0 comments on commit 19401a2

Please sign in to comment.