Skip to content

Commit

Permalink
Remove RandomForestClassificationModelParams and RandomForestRegressionModelParams.
Browse files Browse the repository at this point in the history
  • Loading branch information
yanboliang committed Nov 19, 2016
1 parent 13988b4 commit 7beb568
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in the ensemble (the length of [[trees]]).
*/
val getNumTrees: Int = trees.length

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees

/**
* Number of trees in the ensemble (the length of [[trees]]).
*/
val getNumTrees: Int = trees.length

@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable {

require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
/** Trees in this ensemble. Warning: These have null parent Estimators. */
def trees: Array[M]

/**
* Number of trees in ensemble
*/
val getNumTrees: Int = trees.length

/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]

Expand Down
63 changes: 22 additions & 41 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,28 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}

/** Used for [[RandomForestParams]] */
private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* Parameters for Random Forest algorithms.
*/
private[ml] trait RandomForestParams extends TreeEnsembleParams {

/**
* Number of trees to train (>= 1).
* If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* TODO: Change to always do bootstrapping (simpler). SPARK-7130
* (default = 20)
* @group param
*/
final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
ParamValidators.gtEq(1))

setDefault(numTrees -> 20)

/** @group setParam */
def setNumTrees(value: Int): this.type = set(numTrees, value)

/** @group getParam */
final def getNumTrees: Int = $(numTrees)

/**
* The number of features to consider for splits at each tree node.
Expand Down Expand Up @@ -364,38 +384,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
}

/**
* Used for [[RandomForestParams]].
* This is separated out from [[RandomForestParams]] because of an issue with the
* `numTrees` method conflicting with this Param in the Estimator.
*/
private[ml] trait HasNumTrees extends Params {

/**
* Number of trees to train (>= 1).
* If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* TODO: Change to always do bootstrapping (simpler). SPARK-7130
* (default = 20)
* @group param
*/
final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
ParamValidators.gtEq(1))

// Default of 20, as documented in the Param scaladoc above.
setDefault(numTrees -> 20)

/** @group setParam */
def setNumTrees(value: Int): this.type = set(numTrees, value)

/** @group getParam */
final def getNumTrees: Int = $(numTrees)
}

/**
* Parameters for Random Forest algorithms.
*/
private[ml] trait RandomForestParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with HasNumTrees

private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Expand All @@ -405,15 +393,9 @@ private[spark] object RandomForestParams {
/**
 * Parameters for the random forest classifier and its fitted model:
 * combines the shared [[RandomForestParams]] with [[TreeClassifierParams]].
 */
private[ml] trait RandomForestClassifierParams
extends RandomForestParams with TreeClassifierParams

/**
 * Parameters for [[RandomForestClassificationModel]]: ensemble, feature-subset-strategy,
 * and classifier tree params, without the `numTrees` Param (models expose tree count directly).
 */
private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with TreeClassifierParams

/**
 * Parameters for the random forest regressor and its fitted model:
 * combines the shared [[RandomForestParams]] with [[TreeRegressorParams]].
 */
private[ml] trait RandomForestRegressorParams
extends RandomForestParams with TreeRegressorParams

/**
 * Parameters for [[RandomForestRegressionModel]]: ensemble, feature-subset-strategy,
 * and regressor tree params, without the `numTrees` Param (models expose tree count directly).
 */
private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
with HasFeatureSubsetStrategy with TreeRegressorParams

/**
* Parameters for Gradient-Boosted Tree algorithms.
*
Expand All @@ -437,7 +419,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
/** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)


/**
* Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
* estimator.
Expand Down

0 comments on commit 7beb568

Please sign in to comment.