[SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid #18582

Closed

Changes from all commits
mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -36,7 +36,8 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable {
+  extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
+    with DefaultParamsWritable {
Contributor: How about aligning `with` with `extends`?

Contributor: I guess that layout was suggested by IDEA.

@Since("1.4.0")
def this() = this(Identifiable.randomUID("bucketizer"))
@@ -84,17 +85,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
    * Default: "error"
    * @group param
    */
-  // TODO: SPARK-18619 Make Bucketizer inherit from HasHandleInvalid.
   @Since("2.1.0")
-  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
-    "invalid entries. Options are skip (filter out rows with invalid values), " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "how to handle invalid entries. Options are skip (filter out rows with invalid values), " +
     "error (throw an error), or keep (keep invalid values in a special additional bucket).",
     ParamValidators.inArray(Bucketizer.supportedHandleInvalids))

-  /** @group getParam */
-  @Since("2.1.0")
-  def getHandleInvalid: String = $(handleInvalid)
-
   /** @group setParam */
   @Since("2.1.0")
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
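For context, a minimal usage sketch of the refactored API (not part of this diff; the dataset and SparkSession setup are assumed for illustration). Behavior is unchanged; only the provenance of `getHandleInvalid` moves to the shared trait:

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("BucketizerHandleInvalid").getOrCreate()

// Three buckets: (-inf, 0), [0, 10), [10, +inf)
val splits = Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)
val df = spark.createDataFrame(Seq((-1.5, 0), (5.0, 1), (Double.NaN, 2)))
  .toDF("feature", "id")

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("bucket")
  .setSplits(splits)
  .setHandleInvalid("keep") // NaN rows go to an extra bucket (index 3) instead of erroring

// getHandleInvalid is now inherited from HasHandleInvalid
assert(bucketizer.getHandleInvalid == "keep")
bucketizer.transform(df).show()
```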
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.StructType
 /**
  * Params for [[QuantileDiscretizer]].
  */
 private[feature] trait QuantileDiscretizerBase extends Params
-  with HasInputCol with HasOutputCol {
+  with HasHandleInvalid with HasInputCol with HasOutputCol {

   /**
    * Number of buckets (quantiles, or categories) into which data points are grouped. Must
@@ -72,18 +72,13 @@ private[feature] trait QuantileDiscretizerBase extends Params
    * Default: "error"
    * @group param
    */
-  // TODO: SPARK-18619 Make QuantileDiscretizer inherit from HasHandleInvalid.
   @Since("2.1.0")
-  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
-    "invalid entries. Options are skip (filter out rows with invalid values), " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "how to handle invalid entries. Options are skip (filter out rows with invalid values), " +
     "error (throw an error), or keep (keep invalid values in a special additional bucket).",
     ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
   setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

-  /** @group getParam */
-  @Since("2.1.0")
-  def getHandleInvalid: String = $(handleInvalid)
-
 }

 /**
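The same pattern applies to QuantileDiscretizer; a hedged sketch with toy data (assumes a SparkSession named `spark` is already in scope):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

val data = Seq((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, Double.NaN))
val df = spark.createDataFrame(data).toDF("id", "hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("bucket")
  .setNumBuckets(3)
  .setHandleInvalid("skip") // drop the NaN row before bucketing

val bucketed = discretizer.fit(df).transform(df) // 4 rows survive
```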
13 changes: 5 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineS
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types._
@@ -108,7 +108,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
 @Experimental
 @Since("1.5.0")
 class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
-  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
+  extends Estimator[RFormulaModel] with RFormulaBase with HasHandleInvalid
+    with DefaultParamsWritable {

   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("rFormula"))
@@ -141,8 +142,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
    * @group param
    */
   @Since("2.3.0")
-  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
-    "invalid data (unseen labels or NULL values). " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data (unseen labels or NULL values). " +
     "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
     "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
@@ -152,10 +153,6 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("2.3.0")
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

-  /** @group getParam */
-  @Since("2.3.0")
-  def getHandleInvalid: String = $(handleInvalid)
-
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
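An illustrative sketch of the RFormula param in use (assumed toy data, not from the PR); `setHandleInvalid` stays on the class while `getHandleInvalid` now comes from the shared trait:

```scala
import org.apache.spark.ml.feature.RFormula

val training = spark.createDataFrame(Seq(
  (1.0, "a", 1.0),
  (0.0, "b", 2.0)
)).toDF("y", "s", "x")

val formula = new RFormula()
  .setFormula("y ~ x + s")
  .setHandleInvalid("skip") // filter rows with unseen labels or nulls at transform time

val model = formula.fit(training)
```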
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
@@ -36,7 +36,8 @@ import org.apache.spark.util.collection.OpenHashMap
 /**
  * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
  */
-private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol

Contributor: Perhaps it is better to break the line at `extends`.

+  with HasOutputCol {

   /**
    * Param for how to handle invalid data (unseen labels or NULL values).
@@ -47,18 +48,14 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
    * @group param
    */
   @Since("1.6.0")
-  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
-    "invalid data (unseen labels or NULL values). " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data (unseen labels or NULL values). " +
     "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
     "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))

   setDefault(handleInvalid, StringIndexer.ERROR_INVALID)

-  /** @group getParam */
-  @Since("1.6.0")
-  def getHandleInvalid: String = $(handleInvalid)
-
   /**
    * Param for how to order labels of string column. The first label after ordering is assigned
    * an index of 0.
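A short sketch (toy data, assumed `spark` session) of the "keep" option that motivates the richer validator StringIndexer declares over the shared default:

```scala
import org.apache.spark.ml.feature.StringIndexer

val train = spark.createDataFrame(Seq((0, "a"), (1, "b"), (2, "a"))).toDF("id", "category")
val test = spark.createDataFrame(Seq((3, "a"), (4, "c"))).toDF("id", "category")

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep") // unseen label "c" is indexed as numLabels (here 2.0), not an error

val indexed = indexer.fit(train).transform(test)
```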
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -67,7 +67,7 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
         "will filter out rows with bad values), or error (which will throw an error). More " +
         "options may be added later",
-        isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
+        isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))", finalFields = false),
       ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
         " before fitting the model", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -273,7 +273,7 @@ private[ml] trait HasHandleInvalid extends Params {
   * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
   * @group param
   */
-  final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error")))

   /** @group getParam */
   final def getHandleInvalid: String = $(handleInvalid)
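The `finalFields = false` codegen flag above is what makes these overrides legal. A self-contained sketch of the pattern, with simplified stand-in names rather than Spark's actual classes:

```scala
// Before this PR the generated field was `final val`, so subclasses could not
// refine the doc string or the validator; dropping `final` enables `override val`.
trait HasHandleInvalidLike {
  val handleInvalid: String = "options: skip, error"  // overridable
  final def getHandleInvalid: String = handleInvalid  // the getter can stay final
}

class BucketizerLike extends HasHandleInvalidLike {
  override val handleInvalid: String = "options: skip, error, keep" // refined doc
}

object Demo extends App {
  println(new BucketizerLike().getHandleInvalid) // prints the refined text
}
```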
mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -171,7 +171,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    *
    * @group param
    */
-  @Since("2.3.0")
+  @Since("2.0.0")
   final override val solver: Param[String] = new Param[String](this, "solver",
     "The solver algorithm for optimization. Supported options: " +
     s"${supportedSolvers.mkString(", ")}. (Default irls)",
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -64,7 +64,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
    *
    * @group param
    */
-  @Since("2.3.0")
+  @Since("1.6.0")
   final override val solver: Param[String] = new Param[String](this, "solver",
     "The solver algorithm for optimization. Supported options: " +
     s"${supportedSolvers.mkString(", ")}. (Default auto)",
@@ -194,7 +194,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    */
   @Since("1.6.0")
   def setSolver(value: String): this.type = set(solver, value)
-  setDefault(solver -> AUTO)
+  setDefault(solver -> Auto)

   /**
    * Suggested depth for treeAggregate (greater than or equal to 2).
@@ -224,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
     instr.logNumFeatures(numFeatures)

-    if (($(solver) == AUTO &&
-      numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) {
+    if (($(solver) == Auto &&
+      numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) {
       // For low dimensional data, WeightedLeastSquares is more efficient since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
@@ -460,16 +460,16 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
   val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES

   /** String name for "auto". */
-  private[regression] val AUTO = "auto"
+  private[regression] val Auto = "auto"
Contributor: Is solver related to handleInvalid? It seems a little confusing.

Contributor: @facaiy It's not related to this PR; we just addressed other small issues along the way.

Contributor (Author): @facaiy We had some offline discussion and decided to fix some small issues in this PR.
/** String name for "normal". */
private[regression] val NORMAL = "normal"
private[regression] val Normal = "normal"

/** String name for "l-bfgs". */
private[regression] val LBFGS = "l-bfgs"

/** Set of solvers that LinearRegression supports. */
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS)
private[regression] val supportedSolvers = Array(Auto, Normal, LBFGS)
}

/**
Expand Down
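The AUTO/NORMAL to Auto/Normal renaming is internal; user code selects a solver by its string value, as in this brief sketch (the training DataFrame is assumed):

```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(10)
  .setSolver("normal") // matches the string behind LinearRegression.Normal

// val model = lr.fit(trainingDf) // trainingDf: assumed DataFrame of (label, features)
```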
60 changes: 24 additions & 36 deletions python/pyspark/ml/feature.py
@@ -314,7 +314,8 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable)

 @inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
+                 JavaMLReadable, JavaMLWritable):
     """
     Maps a column of continuous features to a column of feature buckets.
@@ -398,20 +399,6 @@ def getSplits(self):
         """
         return self.getOrDefault(self.splits)

-    @since("2.1.0")
-    def setHandleInvalid(self, value):
-        """
-        Sets the value of :py:attr:`handleInvalid`.
-        """
-        return self._set(handleInvalid=value)
-
-    @since("2.1.0")
-    def getHandleInvalid(self):
-        """
-        Gets the value of :py:attr:`handleInvalid` or its default value.
-        """
-        return self.getOrDefault(self.handleInvalid)
-

 @inherit_doc
 class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):

@@ -1623,7 +1610,8 @@ def getDegree(self):


 @inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
+                          JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
@@ -1743,20 +1731,6 @@ def getRelativeError(self):
         """
         return self.getOrDefault(self.relativeError)

-    @since("2.1.0")
-    def setHandleInvalid(self, value):
-        """
-        Sets the value of :py:attr:`handleInvalid`.
-        """
-        return self._set(handleInvalid=value)
-
-    @since("2.1.0")
-    def getHandleInvalid(self):
-        """
-        Gets the value of :py:attr:`handleInvalid` or its default value.
-        """
-        return self.getOrDefault(self.handleInvalid)
-
     def _create_model(self, java_model):
         """
         Private method to convert the java_model to a Python model.

@@ -2977,7 +2951,8 @@ def explainedVariance(self):


 @inherit_doc
-class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable):
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, HasHandleInvalid,
+               JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental

@@ -3020,6 +2995,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
     True
     >>> loadedRF.getLabelCol() == rf.getLabelCol()
     True
+    >>> loadedRF.getHandleInvalid() == rf.getHandleInvalid()
+    True
     >>> str(loadedRF)
     'RFormula(y ~ x + s) (uid=...)'
     >>> modelPath = temp_path + "/rFormulaModel"
@@ -3058,26 +3035,37 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
                                        "RFormula drops the same category as R when encoding strings.",
                                        typeConverter=TypeConverters.toString)

+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+                          "Options are 'skip' (filter out rows with invalid values), " +
+                          "'error' (throw an error), or 'keep' (put invalid data in a special " +
+                          "additional bucket, at index numLabels).",
+                          typeConverter=TypeConverters.toString)

Contributor: If I understand correctly, handleInvalid has already been declared in the HasHandleInvalid mixin, right?

Contributor (Author): @facaiy We have to override it to keep the pydoc in line with the Scala side.

     @keyword_only
     def __init__(self, formula=None, featuresCol="features", labelCol="label",
-                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
+                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
+                 handleInvalid="error"):
         """
         __init__(self, formula=None, featuresCol="features", labelCol="label", \
-                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
+                 forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
+                 handleInvalid="error")
         """
         super(RFormula, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
-        self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
+        self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
+                         handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)

     @keyword_only
     @since("1.5.0")
     def setParams(self, formula=None, featuresCol="features", labelCol="label",
-                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
+                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
+                  handleInvalid="error"):
         """
         setParams(self, formula=None, featuresCol="features", labelCol="label", \
-                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
+                  forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
+                  handleInvalid="error")
         Sets params for RFormula.
         """
         kwargs = self._input_kwargs