-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-18518][ML] HasSolver supports override #16028
Changes from all commits
b907314
39ea8e1
6bb1daf
ebfa9c0
d15ea65
51cb43b
2e280d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,7 +143,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty | ||
} | ||
|
||
import GeneralizedLinearRegression._ | ||
/** | ||
* The solver algorithm for optimization. | ||
* Supported options: "irls" (iteratively reweighted least squares). | ||
* Default: "irls" | ||
* | ||
* @group param | ||
*/ | ||
@Since("2.3.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2.3.0 -> 2.0.0, please fix it in #17995 . |
||
final override val solver: Param[String] = new Param[String](this, "solver", | ||
"The solver algorithm for optimization. Supported options: " + | ||
s"${supportedSolvers.mkString(", ")}. (Default irls)", | ||
ParamValidators.inArray[String](supportedSolvers)) | ||
|
||
@Since("2.0.0") | ||
override def validateAndTransformSchema( | ||
|
@@ -314,7 +325,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |
*/ | ||
@Since("2.0.0") | ||
def setSolver(value: String): this.type = set(solver, value) | ||
setDefault(solver -> "irls") | ||
setDefault(solver -> IRLS) | ||
|
||
/** | ||
* Sets the link prediction (linear predictor) column name. | ||
|
@@ -400,6 +411,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log | ||
) | ||
|
||
/** String name for "irls" (iteratively reweighted least squares) solver. */ | ||
private[regression] val IRLS = "irls" | ||
|
||
/** Set of solvers that GeneralizedLinearRegression supports. */ | ||
private[regression] val supportedSolvers = Array(IRLS) | ||
|
||
/** Set of family names that GeneralizedLinearRegression supports. */ | ||
private[regression] lazy val supportedFamilyNames = | ||
supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares | |
import org.apache.spark.ml.PredictorParams | ||
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator | ||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} | ||
import org.apache.spark.ml.param.ParamMap | ||
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} | ||
import org.apache.spark.ml.param.shared._ | ||
import org.apache.spark.ml.util._ | ||
import org.apache.spark.mllib.evaluation.RegressionMetrics | ||
|
@@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel | |
private[regression] trait LinearRegressionParams extends PredictorParams | ||
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol | ||
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver | ||
with HasAggregationDepth | ||
with HasAggregationDepth { | ||
|
||
import LinearRegression._ | ||
|
||
/** | ||
* The solver algorithm for optimization. | ||
* Supported options: "l-bfgs", "normal" and "auto". | ||
* Default: "auto" | ||
* | ||
* @group param | ||
*/ | ||
@Since("2.3.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2.3.0 -> 1.6.0 |
||
final override val solver: Param[String] = new Param[String](this, "solver", | ||
"The solver algorithm for optimization. Supported options: " + | ||
s"${supportedSolvers.mkString(", ")}. (Default auto)", | ||
ParamValidators.inArray[String](supportedSolvers)) | ||
} | ||
|
||
/** | ||
* Linear regression. | ||
|
@@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
extends Regressor[Vector, LinearRegression, LinearRegressionModel] | ||
with LinearRegressionParams with DefaultParamsWritable with Logging { | ||
|
||
import LinearRegression._ | ||
|
||
@Since("1.4.0") | ||
def this() = this(Identifiable.randomUID("linReg")) | ||
|
||
|
@@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
* @group setParam | ||
*/ | ||
@Since("1.6.0") | ||
def setSolver(value: String): this.type = { | ||
require(Set("auto", "l-bfgs", "normal").contains(value), | ||
s"Solver $value was not supported. Supported options: auto, l-bfgs, normal") | ||
set(solver, value) | ||
} | ||
setDefault(solver -> "auto") | ||
def setSolver(value: String): this.type = set(solver, value) | ||
setDefault(solver -> AUTO) | ||
|
||
/** | ||
* Suggested depth for treeAggregate (greater than or equal to 2). | ||
|
@@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth) | ||
instr.logNumFeatures(numFeatures) | ||
|
||
if (($(solver) == "auto" && | ||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { | ||
if (($(solver) == AUTO && | ||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) { | ||
// For low dimensional data, WeightedLeastSquares is more efficient since the | ||
// training algorithm only requires one pass through the data. (SPARK-10668) | ||
|
||
|
@@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { | |
*/ | ||
@Since("2.1.0") | ||
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES | ||
|
||
/** String name for "auto". */ | ||
private[regression] val AUTO = "auto" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am probably forgetting, but I thought we decided against introducing constants in a similar situation last month? not sure if it's the same context There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I modified here to follow the way in the companion object of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In MLlib convention, |
||
|
||
/** String name for "normal". */ | ||
private[regression] val NORMAL = "normal" | ||
|
||
/** String name for "l-bfgs". */ | ||
private[regression] val LBFGS = "l-bfgs" | ||
|
||
/** Set of solvers that LinearRegression supports. */ | ||
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS) | ||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this change? is it for consistency? I usually like not importing class members unless it significantly improves readability since it slightly obscures the source, but, this is a side point
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just for consistency.