-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid #18582
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Since | |
import org.apache.spark.ml.{Estimator, Model, Transformer} | ||
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} | ||
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} | ||
import org.apache.spark.ml.util._ | ||
import org.apache.spark.sql.{DataFrame, Dataset} | ||
import org.apache.spark.sql.functions._ | ||
|
@@ -36,7 +36,8 @@ import org.apache.spark.util.collection.OpenHashMap | |
/** | ||
* Base trait for [[StringIndexer]] and [[StringIndexerModel]]. | ||
*/ | ||
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { | ||
private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Perhaps it is better to break the line at |
||
with HasOutputCol { | ||
|
||
/** | ||
* Param for how to handle invalid data (unseen labels or NULL values). | ||
|
@@ -47,18 +48,14 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha | |
* @group param | ||
*/ | ||
@Since("1.6.0") | ||
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " + | ||
"invalid data (unseen labels or NULL values). " + | ||
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", | ||
"How to handle invalid data (unseen labels or NULL values). " + | ||
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " + | ||
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).", | ||
ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) | ||
|
||
setDefault(handleInvalid, StringIndexer.ERROR_INVALID) | ||
|
||
/** @group getParam */ | ||
@Since("1.6.0") | ||
def getHandleInvalid: String = $(handleInvalid) | ||
|
||
/** | ||
* Param for how to order labels of string column. The first label after ordering is assigned | ||
* an index of 0. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,7 +64,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams | |
* | ||
* @group param | ||
*/ | ||
@Since("2.3.0") | ||
@Since("1.6.0") | ||
final override val solver: Param[String] = new Param[String](this, "solver", | ||
"The solver algorithm for optimization. Supported options: " + | ||
s"${supportedSolvers.mkString(", ")}. (Default auto)", | ||
|
@@ -194,7 +194,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
*/ | ||
@Since("1.6.0") | ||
def setSolver(value: String): this.type = set(solver, value) | ||
setDefault(solver -> AUTO) | ||
setDefault(solver -> Auto) | ||
|
||
/** | ||
* Suggested depth for treeAggregate (greater than or equal to 2). | ||
|
@@ -224,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth) | ||
instr.logNumFeatures(numFeatures) | ||
|
||
if (($(solver) == AUTO && | ||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) { | ||
if (($(solver) == Auto && | ||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == Normal) { | ||
// For low dimensional data, WeightedLeastSquares is more efficient since the | ||
// training algorithm only requires one pass through the data. (SPARK-10668) | ||
|
||
|
@@ -460,16 +460,16 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { | |
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES | ||
|
||
/** String name for "auto". */ | ||
private[regression] val AUTO = "auto" | ||
private[regression] val Auto = "auto" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this change intended? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @facaiy It's not related to this PR, just addressed other small issues by the way. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @facaiy We had some offline discussion, and decided to fix some small issues in this PR. |
||
|
||
/** String name for "normal". */ | ||
private[regression] val NORMAL = "normal" | ||
private[regression] val Normal = "normal" | ||
|
||
/** String name for "l-bfgs". */ | ||
private[regression] val LBFGS = "l-bfgs" | ||
|
||
/** Set of solvers that LinearRegression supports. */ | ||
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS) | ||
private[regression] val supportedSolvers = Array(Auto, Normal, LBFGS) | ||
} | ||
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -314,7 +314,8 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) | |
|
||
|
||
@inherit_doc | ||
class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): | ||
class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, | ||
JavaMLReadable, JavaMLWritable): | ||
""" | ||
Maps a column of continuous features to a column of feature buckets. | ||
|
||
|
@@ -398,20 +399,6 @@ def getSplits(self): | |
""" | ||
return self.getOrDefault(self.splits) | ||
|
||
@since("2.1.0") | ||
def setHandleInvalid(self, value): | ||
""" | ||
Sets the value of :py:attr:`handleInvalid`. | ||
""" | ||
return self._set(handleInvalid=value) | ||
|
||
@since("2.1.0") | ||
def getHandleInvalid(self): | ||
""" | ||
Gets the value of :py:attr:`handleInvalid` or its default value. | ||
""" | ||
return self.getOrDefault(self.handleInvalid) | ||
|
||
|
||
@inherit_doc | ||
class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): | ||
|
@@ -1623,7 +1610,8 @@ def getDegree(self): | |
|
||
|
||
@inherit_doc | ||
class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): | ||
class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, | ||
JavaMLReadable, JavaMLWritable): | ||
""" | ||
.. note:: Experimental | ||
|
||
|
@@ -1743,20 +1731,6 @@ def getRelativeError(self): | |
""" | ||
return self.getOrDefault(self.relativeError) | ||
|
||
@since("2.1.0") | ||
def setHandleInvalid(self, value): | ||
""" | ||
Sets the value of :py:attr:`handleInvalid`. | ||
""" | ||
return self._set(handleInvalid=value) | ||
|
||
@since("2.1.0") | ||
def getHandleInvalid(self): | ||
""" | ||
Gets the value of :py:attr:`handleInvalid` or its default value. | ||
""" | ||
return self.getOrDefault(self.handleInvalid) | ||
|
||
def _create_model(self, java_model): | ||
""" | ||
Private method to convert the java_model to a Python model. | ||
|
@@ -2977,7 +2951,8 @@ def explainedVariance(self): | |
|
||
|
||
@inherit_doc | ||
class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable): | ||
class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, HasHandleInvalid, | ||
JavaMLReadable, JavaMLWritable): | ||
""" | ||
.. note:: Experimental | ||
|
||
|
@@ -3020,6 +2995,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM | |
True | ||
>>> loadedRF.getLabelCol() == rf.getLabelCol() | ||
True | ||
>>> loadedRF.getHandleInvalid() == rf.getHandleInvalid() | ||
True | ||
>>> str(loadedRF) | ||
'RFormula(y ~ x + s) (uid=...)' | ||
>>> modelPath = temp_path + "/rFormulaModel" | ||
|
@@ -3058,26 +3035,37 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM | |
"RFormula drops the same category as R when encoding strings.", | ||
typeConverter=TypeConverters.toString) | ||
|
||
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If I understand correctly, this overrides the inherited param. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @facaiy We have to override it to keep the pydoc in line with scala side. |
||
"Options are 'skip' (filter out rows with invalid values), " + | ||
"'error' (throw an error), or 'keep' (put invalid data in a special " + | ||
"additional bucket, at index numLabels).", | ||
typeConverter=TypeConverters.toString) | ||
|
||
@keyword_only | ||
def __init__(self, formula=None, featuresCol="features", labelCol="label", | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", | ||
handleInvalid="error"): | ||
""" | ||
__init__(self, formula=None, featuresCol="features", labelCol="label", \ | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \ | ||
handleInvalid="error") | ||
""" | ||
super(RFormula, self).__init__() | ||
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) | ||
self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") | ||
self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", | ||
handleInvalid="error") | ||
kwargs = self._input_kwargs | ||
self.setParams(**kwargs) | ||
|
||
@keyword_only | ||
@since("1.5.0") | ||
def setParams(self, formula=None, featuresCol="features", labelCol="label", | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", | ||
handleInvalid="error"): | ||
""" | ||
setParams(self, formula=None, featuresCol="features", labelCol="label", \ | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") | ||
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \ | ||
handleInvalid="error") | ||
Sets params for RFormula. | ||
""" | ||
kwargs = self._input_kwargs | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about aligning `with` with `extends`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it was suggested by IDEA.