Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-30377][ML] Make Regressors extend abstract class Regressor #27168

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
Expand All @@ -44,9 +44,8 @@ import org.apache.spark.storage.StorageLevel
/**
* Params for accelerated failure time (AFT) regression.
*/
private[regression] trait AFTSurvivalRegressionParams extends Params
with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
private[regression] trait AFTSurvivalRegressionParams extends PredictorParams
with HasMaxIter with HasTol with HasFitIntercept with HasAggregationDepth with Logging {

/**
* Param for censor column name.
Expand Down Expand Up @@ -126,28 +125,16 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
*/
@Since("1.6.0")
class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String)
extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
with DefaultParamsWritable with Logging {
extends Regressor[Vector, AFTSurvivalRegression, AFTSurvivalRegressionModel]
with AFTSurvivalRegressionParams with DefaultParamsWritable with Logging {

@Since("1.6.0")
def this() = this(Identifiable.randomUID("aftSurvReg"))

/** @group setParam */
@Since("1.6.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)

/** @group setParam */
@Since("1.6.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/** @group setParam */
@Since("1.6.0")
def setCensorCol(value: String): this.type = set(censorCol, value)

/** @group setParam */
@Since("1.6.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/** @group setParam */
@Since("1.6.0")
def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
Expand Down Expand Up @@ -207,9 +194,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
}
}

@Since("2.0.0")
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)
@Since("3.0.0")
override def train(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr =>
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
Expand Down Expand Up @@ -281,7 +267,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
val coefficients = Vectors.dense(rawCoefficients)
val intercept = parameters(1)
val scale = math.exp(parameters(0))
copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale).setParent(this))
new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
}

@Since("1.6.0")
Expand Down Expand Up @@ -309,18 +295,11 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("1.6.0") val intercept: Double,
@Since("1.6.0") val scale: Double)
extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable {
extends RegressionModel[Vector, AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
with MLWritable {

@Since("3.0.0")
lazy val numFeatures: Int = coefficients.size

/** @group setParam */
@Since("1.6.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)

/** @group setParam */
@Since("1.6.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def numFeatures: Int = coefficients.size

/** @group setParam */
@Since("1.6.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.StructType
*/
@Since("1.4.0")
class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
extends Regressor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
with DecisionTreeRegressorParams with DefaultParamsWritable {

@Since("1.4.0")
Expand Down Expand Up @@ -159,7 +159,7 @@ class DecisionTreeRegressionModel private[ml] (
override val uid: String,
override val rootNode: Node,
override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
extends RegressionModel[Vector, DecisionTreeRegressionModel]
with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {

/** @group setParam */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ private[regression] trait FMRegressorParams extends FactorizationMachinesParams
@Since("3.0.0")
class FMRegressor @Since("3.0.0") (
@Since("3.0.0") override val uid: String)
extends Predictor[Vector, FMRegressor, FMRegressionModel]
extends Regressor[Vector, FMRegressor, FMRegressionModel]
with FactorizationMachines with FMRegressorParams with DefaultParamsWritable with Logging {

@Since("3.0.0")
Expand Down Expand Up @@ -454,7 +454,7 @@ class FMRegressionModel private[regression] (
@Since("3.0.0") val intercept: Double,
@Since("3.0.0") val linear: Vector,
@Since("3.0.0") val factors: Matrix)
extends PredictionModel[Vector, FMRegressionModel]
extends RegressionModel[Vector, FMRegressionModel]
with FMRegressorParams with MLWritable {

@Since("3.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import org.apache.spark.sql.types.StructType
*/
@Since("1.4.0")
class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
extends Regressor[Vector, GBTRegressor, GBTRegressionModel]
with GBTRegressorParams with DefaultParamsWritable with Logging {

@Since("1.4.0")
Expand Down Expand Up @@ -227,7 +227,7 @@ class GBTRegressionModel private[ml](
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
extends RegressionModel[Vector, GBTRegressionModel]
with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType
*/
@Since("1.4.0")
class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
extends Regressor[Vector, RandomForestRegressor, RandomForestRegressionModel]
with RandomForestRegressorParams with DefaultParamsWritable {

@Since("1.4.0")
Expand Down Expand Up @@ -170,7 +170,7 @@ class RandomForestRegressionModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
extends RegressionModel[Vector, RandomForestRegressionModel]
with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {

Expand Down
10 changes: 9 additions & 1 deletion project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,15 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.appendS3AndSparkHadoopConfigurations"),

// [SPARK-29348] Add observable metrics.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryProgress.this")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryProgress.this"),

// [SPARK-30377][ML] Make AFTSurvivalRegression extend Regressor
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setFeaturesCol"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setPredictionCol"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setFeaturesCol"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setLabelCol"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setPredictionCol")
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the hierarchy of AFTSurvivalRegression/AFTSurvivalRegressionModel from extending Estimator/Model to Regressor/RegressionModel caused the following MiMa errors:

[error]  * method setFeaturesCol(java.lang.String)org.apache.spark.ml.regression.AFTSurvivalRegressionModel in class org.apache.spark.ml.regression.AFTSurvivalRegressionModel has a different result type in current version, where it is org.apache.spark.ml.PredictionModel rather than org.apache.spark.ml.regression.AFTSurvivalRegressionModel
[error]    filter with: ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setFeaturesCol")
[error]  * method setPredictionCol(java.lang.String)org.apache.spark.ml.regression.AFTSurvivalRegressionModel in class org.apache.spark.ml.regression.AFTSurvivalRegressionModel has a different result type in current version, where it is org.apache.spark.ml.PredictionModel rather than org.apache.spark.ml.regression.AFTSurvivalRegressionModel
[error]    filter with: ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setPredictionCol")
[error]  * method fit(org.apache.spark.sql.Dataset)org.apache.spark.ml.regression.AFTSurvivalRegressionModel in class org.apache.spark.ml.regression.AFTSurvivalRegression has a different result type in current version, where it is org.apache.spark.ml.Model rather than org.apache.spark.ml.regression.AFTSurvivalRegressionModel
[error]    filter with: ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit")
[error]  * method setFeaturesCol(java.lang.String)org.apache.spark.ml.regression.AFTSurvivalRegression in class org.apache.spark.ml.regression.AFTSurvivalRegression has a different result type in current version, where it is org.apache.spark.ml.Predictor rather than org.apache.spark.ml.regression.AFTSurvivalRegression
[error]    filter with: ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setFeaturesCol")
[error]  * method setLabelCol(java.lang.String)org.apache.spark.ml.regression.AFTSurvivalRegression in class org.apache.spark.ml.regression.AFTSurvivalRegression has a different result type in current version, where it is org.apache.spark.ml.Predictor rather than org.apache.spark.ml.regression.AFTSurvivalRegression
[error]    filter with: ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setLabelCol")
[error]  * method setPredictionCol(java.lang.String)org.apache.spark.ml.regression.AFTSurvivalRegression in class org.apache.spark.ml.regression.AFTSurvivalRegression has a different result type in current version, where it is org.apache.spark.ml.Predictor rather than org.apache.spark.ml.regression.AFTSurvivalRegression
[error]    filter with: ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setPredictionCol")

There is no actual API change, though.

Changing the hierarchy from extending Predictor/PredictionModel to Regressor/RegressionModel doesn't cause any MiMa problems.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah same story as last time eh, the type should be identical but it's not super obvious to Mima that this.type is the same as type T = [the model type].


// Exclude rules for 2.4.x
Expand Down
49 changes: 3 additions & 46 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,8 +1515,7 @@ def evaluateEachIteration(self, dataset, loss):
return self._call_java("evaluateEachIteration", dataset, loss)


class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasMaxIter, HasTol, HasFitIntercept,
class _AFTSurvivalRegressionParams(_JavaPredictorParams, HasMaxIter, HasTol, HasFitIntercept,
HasAggregationDepth):
"""
Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
Expand Down Expand Up @@ -1563,7 +1562,7 @@ def getQuantilesCol(self):


@inherit_doc
class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
class AFTSurvivalRegression(JavaPredictor, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
Expand Down Expand Up @@ -1682,27 +1681,6 @@ def setMaxIter(self, value):
"""
return self._set(maxIter=value)

@since("1.6.0")
def setFeaturesCol(self, value):
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)

@since("1.6.0")
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)

@since("1.6.0")
def setLabelCol(self, value):
"""
Sets the value of :py:attr:`labelCol`.
"""
return self._set(labelCol=value)

@since("1.6.0")
def setTol(self, value):
"""
Expand All @@ -1725,28 +1703,14 @@ def setAggregationDepth(self, value):
return self._set(aggregationDepth=value)


class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
class AFTSurvivalRegressionModel(JavaPredictionModel, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`AFTSurvivalRegression`.

.. versionadded:: 1.6.0
"""

@since("3.0.0")
def setFeaturesCol(self, value):
"""
Sets the value of :py:attr:`featuresCol`.
"""
return self._set(featuresCol=value)

@since("3.0.0")
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)

@since("3.0.0")
def setQuantileProbabilities(self, value):
"""
Expand Down Expand Up @@ -1792,13 +1756,6 @@ def predictQuantiles(self, features):
"""
return self._call_java("predictQuantiles", features)

@since("2.0.0")
def predict(self, features):
"""
Predicted value
"""
return self._call_java("predict", features)


class _GeneralizedLinearRegressionParams(_JavaPredictorParams, HasFitIntercept, HasMaxIter,
HasTol, HasRegParam, HasWeightCol, HasSolver,
Expand Down