diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 5a2e07914b38f..99ed12d560303 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,7 +22,6 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -33,7 +32,6 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.StructType
@@ -138,20 +136,6 @@ class DecisionTreeClassifier @Since("1.4.0") (
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 
-  /** (private[ml]) Train a decision tree on an RDD */
-  private[ml] def train(data: RDD[LabeledPoint],
-      oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
-    val instances = data.map(_.toInstance)
-    instr.logPipelineStage(this)
-    instr.logDataset(instances)
-    instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
-      cacheNodeIds, checkpointInterval, impurity, seed)
-    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
-      featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
-
-    trees.head.asInstanceOf[DecisionTreeClassificationModel]
-  }
-
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index c896b1589a936..aac3dbf6c5a64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -70,7 +70,8 @@ private[spark] object BaggedPoint {
       if (numSubsamples == 1 && subsamplingRate == 1.0) {
         convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
       } else {
-        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples,
+          extractSampleWeight, seed)
       }
     }
   }
@@ -79,6 +80,7 @@ private[spark] object BaggedPoint {
       input: RDD[Datum],
       subsamplingRate: Double,
       numSubsamples: Int,
+      extractSampleWeight: (Datum => Double),
       seed: Long): RDD[BaggedPoint[Datum]] = {
     input.mapPartitionsWithIndex { (partitionIndex, instances) =>
       // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
@@ -93,7 +95,7 @@ private[spark] object BaggedPoint {
           }
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleCounts)
+        new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
       }
     }
   }
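Reviewer note on the `BaggedPoint` change above: before this patch the sampling branch did not receive `extractSampleWeight`, so points bagged under `subsamplingRate < 1.0` were built without their instance weight (falling back to the constructor default). The sketch below is a self-contained illustration of the behavior the patched `convertToBaggedRDDSamplingWithoutReplacement` now has, not Spark's private API; `bagWithoutReplacement` and the simplified `BaggedPoint` are hypothetical names.

```scala
import scala.util.Random

// Simplified stand-in for ml.tree.impl.BaggedPoint: per-subsample 0/1 counts
// plus the instance's own sample weight.
final case class BaggedPoint[D](datum: D, subsampleCounts: Array[Int], sampleWeight: Double)

def bagWithoutReplacement[D](
    instances: Iterator[D],
    subsamplingRate: Double,            // probability of entering each subsample
    numSubsamples: Int,                 // one subsample per tree
    extractSampleWeight: D => Double,   // now threaded through, as in the hunk above
    rng: Random): Iterator[BaggedPoint[D]] =
  instances.map { instance =>
    // Bernoulli draw per subsample, mirroring the 0/1 counts in the real code.
    val counts = Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1 else 0)
    BaggedPoint(instance, counts, extractSampleWeight(instance))
  }
```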
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 6a0301d87da9e..00f409f45d004 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -313,11 +313,12 @@ private[spark] object GradientBoostedTrees extends Logging {
 
     // Initialize tree
     timer.start("building tree 0")
-    val metadata = RandomForest.buildMetadata(input, treeStrategy,
-      numTrees = 1, featureSubsetStrategy)
-    val firstTreeModel = RandomForest.run(input, treeStrategy, numTrees = 1,
-      featureSubsetStrategy, seed = seed, instr = instr,
-      parentUID = None, precomputedMetadata = Some(metadata))
+    val metadata = DecisionTreeMetadata.buildMetadata(
+      input.retag(classOf[Instance]), treeStrategy, numTrees = 1,
+      featureSubsetStrategy)
+    val firstTreeModel = RandomForest.runWithMetadata(input, metadata, treeStrategy,
+      numTrees = 1, featureSubsetStrategy, seed = seed, instr = instr,
+      parentUID = None)
       .head.asInstanceOf[DecisionTreeRegressionModel]
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
@@ -353,9 +354,9 @@ private[spark] object GradientBoostedTrees extends Logging {
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
 
-      val model = RandomForest.run(data, treeStrategy, numTrees = 1,
-        featureSubsetStrategy, seed = seed + m, instr = None,
-        parentUID = None, precomputedMetadata = Some(metadata))
+      val model = RandomForest.runWithMetadata(data, metadata, treeStrategy,
+        numTrees = 1, featureSubsetStrategy, seed = seed + m,
+        instr = None, parentUID = None)
         .head.asInstanceOf[DecisionTreeRegressionModel]
       timer.stop(s"building tree $m")
       // Update partial model
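The intent of the `GradientBoostedTrees` hunks: boosting fits one regression tree per iteration over data whose schema never changes, so the `DecisionTreeMetadata` (bin counts, categorical arities, feature-subset configuration) is now computed once up front and handed to `RandomForest.runWithMetadata` on every iteration rather than being rebuilt per tree. Condensed from the hunks above; surrounding values such as `input`, `data`, `treeStrategy`, and `numIterations` come from the enclosing boosting method:

```scala
// Computed once, before the boosting loop:
val metadata = DecisionTreeMetadata.buildMetadata(
  input.retag(classOf[Instance]), treeStrategy, numTrees = 1, featureSubsetStrategy)

// Reused on every iteration; only the relabeled data and the seed change:
var m = 1
while (m < numIterations) {
  val model = RandomForest.runWithMetadata(data, metadata, treeStrategy,
    numTrees = 1, featureSubsetStrategy, seed = seed + m,
    instr = None, parentUID = None)
    .head.asInstanceOf[DecisionTreeRegressionModel]
  m += 1
}
```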
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index ac91c6fae5cb8..39fc224773327 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -100,43 +100,24 @@ private[spark] object RandomForest extends Logging with Serializable {
     run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
   }
 
-  def buildMetadata(
-      input: RDD[Instance],
-      strategy: OldStrategy,
-      numTrees: Int,
-      featureSubsetStrategy: String): DecisionTreeMetadata = {
-    val retaggedInput = input.retag(classOf[Instance])
-    DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
-  }
-
   /**
    * Train a random forest.
    *
    * @param input Training data: RDD of `Instance`
    * @return an unweighted set of trees
    */
-  def run(
+  def runWithMetadata(
       input: RDD[Instance],
+      metadata: DecisionTreeMetadata,
       strategy: OldStrategy,
       numTrees: Int,
       featureSubsetStrategy: String,
       seed: Long,
       instr: Option[Instrumentation],
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
-      parentUID: Option[String] = None,
-      precomputedMetadata: Option[DecisionTreeMetadata] = None): Array[DecisionTreeModel] = {
-
+      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
     val timer = new TimeTracker()
 
-    timer.start("total")
-
-    timer.start("init")
-
-    val retaggedInput = input.retag(classOf[Instance])
-    val metadata = precomputedMetadata.getOrElse {
-      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
-    }
-
     instr match {
       case Some(instrumentation) =>
         instrumentation.logNumFeatures(metadata.numFeatures)
@@ -150,6 +131,12 @@ private[spark] object RandomForest extends Logging with Serializable {
         logInfo("weightedNumExamples: " + metadata.weightedNumExamples)
     }
 
+    timer.start("total")
+
+    timer.start("init")
+
+    val retaggedInput = input.retag(classOf[Instance])
+
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplits")
@@ -225,7 +212,7 @@ private[spark] object RandomForest extends Logging with Serializable {
       // Collect some nodes to split, and choose features for each node (if subsampling).
       // Each group of nodes may come from one or multiple trees, and at multiple levels.
       val (nodesForGroup, treeToNodeToIndexInfo) =
-          RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+        RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
       // Sanity check (should never occur):
       assert(nodesForGroup.nonEmpty,
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
@@ -285,6 +272,32 @@ private[spark] object RandomForest extends Logging with Serializable {
     }
   }
 
+  /**
+   * Train a random forest.
+   *
+   * @param input Training data: RDD of `Instance`
+   * @return an unweighted set of trees
+   */
+  def run(
+      input: RDD[Instance],
+      strategy: OldStrategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      seed: Long,
+      instr: Option[Instrumentation],
+      prune: Boolean = true, // exposed for testing only, real trees are always pruned
+      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+    val timer = new TimeTracker()
+
+    timer.start("build metadata")
+    val metadata = DecisionTreeMetadata
+      .buildMetadata(input.retag(classOf[Instance]), strategy, numTrees, featureSubsetStrategy)
+    timer.stop("build metadata")
+
+    runWithMetadata(input, metadata, strategy, numTrees, featureSubsetStrategy,
+      seed, instr, prune, parentUID)
+  }
+
   /**
    * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
    *
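Migration note for the `RandomForest` refactor above: the `precomputedMetadata: Option[DecisionTreeMetadata]` parameter and the `buildMetadata` forwarding helper are gone. `run` keeps its old behavior by building metadata itself and delegating to `runWithMetadata`; callers that precompute metadata (as GBT does) now call `runWithMetadata` directly. A sketch of the migration, assuming an `RDD[Instance]` named `instances` and an `OldStrategy` named `strategy` in scope:

```scala
// Before (removed API, shown for contrast):
//   val md = RandomForest.buildMetadata(instances, strategy, numTrees = 1, "all")
//   RandomForest.run(instances, strategy, numTrees = 1, "all", seed = 0L,
//     instr = None, precomputedMetadata = Some(md))

// After (explicit two-step API, matching the new signatures above):
val md = DecisionTreeMetadata.buildMetadata(
  instances.retag(classOf[Instance]), strategy, numTrees = 1, featureSubsetStrategy = "all")
val trees = RandomForest.runWithMetadata(instances, md, strategy, numTrees = 1,
  featureSubsetStrategy = "all", seed = 0L, instr = None)
```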
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 06b4679741f9a..abeb4b5331acf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -493,30 +493,31 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("training with sample weights") {
     val df = binaryDataset
     val numClasses = 2
-    val predEquals = (x: Double, y: Double) => x == y
-    // (maxIter, maxDepth)
+    // (maxIter, maxDepth, subsamplingRate, fractionInTol)
     val testParams = Seq(
-      (5, 5),
-      (5, 10)
+      (5, 5, 1.0, 0.99),
+      (5, 10, 1.0, 0.99),
+      (5, 10, 0.95, 0.9)
     )
 
-    for ((maxIter, maxDepth) <- testParams) {
+    for ((maxIter, maxDepth, subsamplingRate, tol) <- testParams) {
       val estimator = new GBTClassifier()
         .setMaxIter(maxIter)
         .setMaxDepth(maxDepth)
+        .setSubsamplingRate(subsamplingRate)
         .setSeed(seed)
         .setMinWeightFractionPerNode(0.049)
 
       MLTestingUtils.testArbitrarilyScaledWeights[GBTClassificationModel,
         GBTClassifier](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
+        MLTestingUtils.modelPredictionEquals(df, _ == _, tol))
       MLTestingUtils.testOutliersWithSmallWeights[GBTClassificationModel,
         GBTClassifier](df.as[LabeledPoint], estimator,
-        numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
+        numClasses, MLTestingUtils.modelPredictionEquals(df, _ == _, tol),
         outlierRatio = 2)
       MLTestingUtils.testOversamplingVsWeighting[GBTClassificationModel,
         GBTClassifier](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
+        MLTestingUtils.modelPredictionEquals(df, _ == _, tol), seed)
     }
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index b772a3b7737d0..35c0fc9b02b10 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -321,29 +321,31 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
   test("training with sample weights") {
     val df = linearRegressionData
     val numClasses = 0
-    // (maxIter, maxDepth)
+    // (maxIter, maxDepth, subsamplingRate, fractionInTol)
     val testParams = Seq(
-      (5, 5),
-      (5, 10)
+      (5, 5, 1.0, 0.98),
+      (5, 10, 1.0, 0.98),
+      (5, 10, 0.95, 0.6)
     )
 
-    for ((maxIter, maxDepth) <- testParams) {
+    for ((maxIter, maxDepth, subsamplingRate, tol) <- testParams) {
       val estimator = new GBTRegressor()
         .setMaxIter(maxIter)
         .setMaxDepth(maxDepth)
+        .setSubsamplingRate(subsamplingRate)
        .setSeed(seed)
         .setMinWeightFractionPerNode(0.1)
 
       MLTestingUtils.testArbitrarilyScaledWeights[GBTRegressionModel,
         GBTRegressor](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95))
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol))
       MLTestingUtils.testOutliersWithSmallWeights[GBTRegressionModel,
         GBTRegressor](df.as[LabeledPoint], estimator, numClasses,
-        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95),
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol),
         outlierRatio = 2)
       MLTestingUtils.testOversamplingVsWeighting[GBTRegressionModel,
         GBTRegressor](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 0.95), seed)
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol), seed)
     }
   }
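A note on the two test-suite changes: the per-case `fractionInTol` values replace the fixed tolerances because, with `subsamplingRate < 1.0`, reweighting and oversampling are no longer exactly equivalent (each tree trains on a random subsample), so the suites accept a lower fraction of matching predictions in the subsampled cases (0.9 for the classifier, 0.6 for the regressor at rate 0.95). The regressor suite also aligns the oversampling check on `relTol 0.1`, which the other two checks already used. The sketch below restates the semantics that `fractionInTol` appears to drive in `MLTestingUtils.modelPredictionEquals`; it is an assumption about the helper, not its actual implementation:

```scala
// Hypothetical restatement: two prediction vectors "agree" if `check` passes
// on at least `fractionInTol` of the rows.
def fractionAgreeing(
    predsA: Seq[Double],
    predsB: Seq[Double],
    check: (Double, Double) => Boolean): Double = {
  val matching = predsA.zip(predsB).count { case (a, b) => check(a, b) }
  matching.toDouble / math.max(predsA.size, 1)
}

// Exact-match check as in the classifier suite: 2 of 3 rows agree,
// so this passes at fractionInTol = 0.6 but would fail at 0.9.
assert(fractionAgreeing(Seq(0.0, 1.0, 1.0), Seq(0.0, 1.0, 0.0), _ == _) >= 0.6)
```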