From 7beb5688e83a49c9e0f2270fe413c2f2fff39ecd Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 19 Nov 2016 07:03:20 -0800 Subject: [PATCH] Remove RandomForestClassificationModelParams and RandomForestRegressionModelParams. --- .../ml/classification/GBTClassifier.scala | 5 ++ .../RandomForestClassifier.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 5 ++ .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/tree/treeModels.scala | 5 -- .../org/apache/spark/ml/tree/treeParams.scala | 63 +++++++------------ 6 files changed, 34 insertions(+), 48 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f8f164e8c14bd..d7d4612e6d171 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -201,6 +201,11 @@ class GBTClassificationModel private[ml]( @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees + /** + * Number of trees in ensemble + */ + val getNumTrees: Int = trees.length + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 3c784d5555a4a..f0035acb1d656 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] ( @Since("1.6.0") override val numFeatures: Int, @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] - with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel] + with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel] with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index fa69d60836e68..0f6e178bf3184 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -183,6 +183,11 @@ class GBTRegressionModel private[ml]( @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees + /** + * Number of trees in ensemble + */ + val getNumTrees: Int = trees.length + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index b123bc9360b9e..e9230b88e8cfe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -144,7 +144,7 @@ class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index d3cbc363799a5..0d6e9034e5ce4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] { /** Trees in this ensemble. Warning: These have null parent Estimators. */ def trees: Array[M] - /** - * Number of trees in ensemble - */ - val getNumTrees: Int = trees.length - /** Weights for each tree, zippable with [[trees]] */ def treeWeights: Array[Double] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index c1dfa50ed3ce4..311580040d1d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -317,8 +317,28 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { } } -/** Used for [[RandomForestParams]] */ -private[ml] trait HasFeatureSubsetStrategy extends Params { +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) + + /** @group setParam */ + def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group getParam */ + final def getNumTrees: Int = $(numTrees) /** * The number of features to consider for splits at each tree node. @@ -364,38 +384,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase } -/** - * Used for [[RandomForestParams]]. - * This is separated out from [[RandomForestParams]] because of an issue with the - * `numTrees` method conflicting with this Param in the Estimator. - */ -private[ml] trait HasNumTrees extends Params { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) - - setDefault(numTrees -> 20) - - /** @group setParam */ - def setNumTrees(value: Int): this.type = set(numTrees, value) - - /** @group getParam */ - final def getNumTrees: Int = $(numTrees) -} - -/** - * Parameters for Random Forest algorithms. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with HasNumTrees - private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = @@ -405,15 +393,9 @@ private[spark] object RandomForestParams { private[ml] trait RandomForestClassifierParams extends RandomForestParams with TreeClassifierParams -private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with TreeClassifierParams - private[ml] trait RandomForestRegressorParams extends RandomForestParams with TreeRegressorParams -private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with TreeRegressorParams - /** * Parameters for Gradient-Boosted Tree algorithms. * @@ -437,7 +419,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - /** * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each * estimator.