Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator #20132

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
Expand Down Expand Up @@ -66,10 +66,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
def getDropLast: Boolean = $(dropLast)

protected def validateAndTransformSchema(
schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
schema: StructType,
dropLast: Boolean,
keepInvalid: Boolean): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
val existingFields = schema.fields

require(inputColNames.length == outputColNames.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
Expand Down Expand Up @@ -197,6 +198,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
override def load(path: String): OneHotEncoderEstimator = super.load(path)
}

/**
* @param categorySizes Original number of categories for each feature being encoded.
* The array contains one value for each input column, in order.
*/
@Since("2.3.0")
class OneHotEncoderModel private[ml] (
@Since("2.3.0") override val uid: String,
Expand All @@ -205,60 +210,58 @@ class OneHotEncoderModel private[ml] (

import OneHotEncoderModel._

// Returns the category size for a given index with `dropLast` and `handleInvalid`
// Returns the category size for each index with `dropLast` and `handleInvalid`
// taken into account.
private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
private def getConfigedCategorySizes: Array[Int] = {
val dropLast = getDropLast
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID

if (!dropLast && keepInvalid) {
// When `handleInvalid` is "keep", an extra category is added as last category
// for invalid data.
orgCategorySize + 1
categorySizes.map(_ + 1)
} else if (dropLast && !keepInvalid) {
// When `dropLast` is true, the last category is removed.
orgCategorySize - 1
categorySizes.map(_ - 1)
} else {
// When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
// data is removed. Thus, it is the same as the plain number of categories.
orgCategorySize
categorySizes
}
}

private def encoder: UserDefinedFunction = {
val oneValue = Array(1.0)
val emptyValues = Array.empty[Double]
val emptyIndices = Array.empty[Int]
val dropLast = getDropLast
val handleInvalid = getHandleInvalid
val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
val configedSizes = getConfigedCategorySizes
val localCategorySizes = categorySizes

// The udf performed on input data. The first parameter is the input value. The second
// parameter is the index of input.
udf { (label: Double, idx: Int) =>
val plainNumCategories = categorySizes(idx)
val size = configedCategorySize(plainNumCategories, idx)

if (label < 0) {
throw new SparkException(s"Negative value: $label. Input can't be negative.")
} else if (label == size && dropLast && !keepInvalid) {
// When `dropLast` is true and `handleInvalid` is not "keep",
// the last category is removed.
Vectors.sparse(size, emptyIndices, emptyValues)
} else if (label >= plainNumCategories && keepInvalid) {
// When `handleInvalid` is "keep", encodes invalid data to last category (and removed
// if `dropLast` is true)
if (dropLast) {
Vectors.sparse(size, emptyIndices, emptyValues)
// parameter is the index in inputCols of the column being encoded.
udf { (label: Double, colIdx: Int) =>
val origCategorySize = localCategorySizes(colIdx)
// idx: index in vector of the single 1-valued element
val idx = if (label >= 0 && label < origCategorySize) {
label
} else {
if (keepInvalid) {
origCategorySize
} else {
Vectors.sparse(size, Array(size - 1), oneValue)
if (label < 0) {
throw new SparkException(s"Negative value: $label. Input can't be negative. " +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a question. Since we don't allow negative values when fitting, should we allow them during transform(), even when handleInvalid is KEEP_INVALID?

Copy link
Member Author

@jkbradley jkbradley Jan 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point — it is unclear. I do think it'd be good to be robust during transform(). As for fitting, I could see going either way (forcing data validation vs. being robust to small issues). I'd like to keep this strict during fitting (throwing errors) and robust during transform(), but let me know what you think.

I'll clarify this in the documentation.

s"To handle invalid values, set Param handleInvalid to " +
s"${OneHotEncoderEstimator.KEEP_INVALID}")
} else {
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
}
}
} else if (label < plainNumCategories) {
Vectors.sparse(size, Array(label.toInt), oneValue)
}

val size = configedSizes(colIdx)
if (idx < size) {
Vectors.sparse(size, Array(idx.toInt), Array(1.0))
} else {
assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
}
}
}
Expand All @@ -282,7 +285,6 @@ class OneHotEncoderModel private[ml] (
@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)

require(inputColNames.length == categorySizes.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
Expand All @@ -300,6 +302,7 @@ class OneHotEncoderModel private[ml] (
* account. Mismatched numbers will cause exception.
*/
private def verifyNumOfValues(schema: StructType): StructType = {
val configedSizes = getConfigedCategorySizes
$(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = $(inputCols)(idx)
val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
Expand All @@ -308,9 +311,9 @@ class OneHotEncoderModel private[ml] (
// comparing with expected category number with `handleInvalid` and
// `dropLast` taken into account.
if (attrGroup.attributes.nonEmpty) {
val numCategories = configedCategorySize(categorySizes(idx), idx)
val numCategories = configedSizes(idx)
require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
s"$numCategories categorical values for input column ${inputColName}, " +
s"$numCategories categorical values for input column $inputColName, " +
s"but the input column had metadata specifying ${attrGroup.size} values.")
}
}
Expand All @@ -322,7 +325,7 @@ class OneHotEncoderModel private[ml] (
val transformedSchema = transformSchema(dataset.schema, logging = true)
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID

val encodedColumns = (0 until $(inputCols).length).map { idx =>
val encodedColumns = $(inputCols).indices.map { idx =>
val inputColName = $(inputCols)(idx)
val outputColName = $(outputCols)(idx)

Expand Down