Commit dac8f69

init

zhengruifeng committed Jan 2, 2020
1 parent 90794b6 commit dac8f69

Showing 6 changed files with 67 additions and 64 deletions.
DecisionTreeClassifier.scala

@@ -22,7 +22,6 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._

 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -33,7 +32,6 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.StructType
@@ -138,20 +136,6 @@ class DecisionTreeClassifier @Since("1.4.0") (
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }

-  /** (private[ml]) Train a decision tree on an RDD */
-  private[ml] def train(data: RDD[LabeledPoint],
-      oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
-    val instances = data.map(_.toInstance)
-    instr.logPipelineStage(this)
-    instr.logDataset(instances)
-    instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
-      cacheNodeIds, checkpointInterval, impurity, seed)
-    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
-      featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
-
-    trees.head.asInstanceOf[DecisionTreeClassificationModel]
-  }
-
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
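
Note: the removed train overload converted its RDD[LabeledPoint] input via data.map(_.toInstance), and the now-unused LabeledPoint and RDD imports go with it. A minimal, self-contained sketch of that conversion, using simplified stand-ins for Spark's ml.feature.LabeledPoint and Instance (plain Array[Double] in place of ml.linalg.Vector):

// Simplified stand-ins for Spark's ml.feature.{LabeledPoint, Instance};
// the real classes use ml.linalg.Vector, not Array[Double].
case class Instance(label: Double, weight: Double, features: Array[Double])

case class LabeledPoint(label: Double, features: Array[Double]) {
  // Mirrors LabeledPoint.toInstance: unweighted data gets weight 1.0,
  // so callers without weights behave exactly as before.
  def toInstance: Instance = Instance(label, 1.0, features)
}

object ToInstanceSketch {
  def main(args: Array[String]): Unit = {
    val data = Seq(LabeledPoint(1.0, Array(0.5, 2.0)), LabeledPoint(0.0, Array(1.5, 0.1)))
    val instances = data.map(_.toInstance) // every point carries unit weight
    instances.foreach(i => println(s"label=${i.label} weight=${i.weight}"))
  }
}
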
BaggedPoint.scala

@@ -70,7 +70,8 @@ private[spark] object BaggedPoint {
       if (numSubsamples == 1 && subsamplingRate == 1.0) {
         convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
       } else {
-        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples,
+          extractSampleWeight, seed)
       }
     }
   }
@@ -79,6 +80,7 @@ private[spark] object BaggedPoint {
       input: RDD[Datum],
       subsamplingRate: Double,
       numSubsamples: Int,
+      extractSampleWeight: (Datum => Double),
       seed: Long): RDD[BaggedPoint[Datum]] = {
     input.mapPartitionsWithIndex { (partitionIndex, instances) =>
       // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
@@ -93,7 +95,7 @@ private[spark] object BaggedPoint {
           }
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleCounts)
+        new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
       }
     }
   }
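
With this change the without-replacement sampling path forwards extractSampleWeight and stores each instance's weight on its BaggedPoint; before, that path left the weight at its default, so subsampled bagging effectively dropped per-instance weights. A sketch of the pattern on plain Scala collections, with java.util-style Random standing in for Spark's partition-seeded XORShiftRandom and a Seq in place of the RDD:

import scala.util.Random

// Simplified stand-in for the private BaggedPoint: the datum, one inclusion
// count per subsample, and (after this fix) the instance's sample weight.
case class BaggedPoint[D](datum: D, subsampleCounts: Array[Int], sampleWeight: Double)

object BaggedPointSketch {
  // Mirrors convertToBaggedRDDSamplingWithoutReplacement on a plain Seq:
  // each datum is kept (1) or dropped (0) independently per subsample.
  def bagWithoutReplacement[D](
      input: Seq[D],
      subsamplingRate: Double,
      numSubsamples: Int,
      extractSampleWeight: D => Double,
      seed: Long): Seq[BaggedPoint[D]] = {
    val rng = new Random(seed)
    input.map { datum =>
      val counts = Array.fill(numSubsamples) {
        if (rng.nextDouble() < subsamplingRate) 1 else 0
      }
      // Before the fix, the weight was not forwarded here, so every
      // BaggedPoint fell back to the class's default weight.
      BaggedPoint(datum, counts, extractSampleWeight(datum))
    }
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(("a", 2.0), ("b", 0.5), ("c", 1.0))
    bagWithoutReplacement(data, 0.95, 3, (d: (String, Double)) => d._2, 42L)
      .foreach(b => println(s"${b.datum._1} counts=${b.subsampleCounts.mkString(",")} w=${b.sampleWeight}"))
  }
}
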
GradientBoostedTrees.scala

@@ -313,11 +313,12 @@ private[spark] object GradientBoostedTrees extends Logging {

     // Initialize tree
     timer.start("building tree 0")
-    val metadata = RandomForest.buildMetadata(input, treeStrategy,
-      numTrees = 1, featureSubsetStrategy)
-    val firstTreeModel = RandomForest.run(input, treeStrategy, numTrees = 1,
-      featureSubsetStrategy, seed = seed, instr = instr,
-      parentUID = None, precomputedMetadata = Some(metadata))
+    val metadata = DecisionTreeMetadata.buildMetadata(
+      input.retag(classOf[Instance]), treeStrategy, numTrees = 1,
+      featureSubsetStrategy)
+    val firstTreeModel = RandomForest.runWithMetadata(input, metadata, treeStrategy,
+      numTrees = 1, featureSubsetStrategy, seed = seed, instr = instr,
+      parentUID = None)
       .head.asInstanceOf[DecisionTreeRegressionModel]
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
@@ -353,9 +354,9 @@ private[spark] object GradientBoostedTrees extends Logging {
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")

-      val model = RandomForest.run(data, treeStrategy, numTrees = 1,
-        featureSubsetStrategy, seed = seed + m, instr = None,
-        parentUID = None, precomputedMetadata = Some(metadata))
+      val model = RandomForest.runWithMetadata(data, metadata, treeStrategy,
+        numTrees = 1, featureSubsetStrategy, seed = seed + m,
+        instr = None, parentUID = None)
         .head.asInstanceOf[DecisionTreeRegressionModel]
       timer.stop(s"building tree $m")
       // Update partial model
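
Both training call sites now share one DecisionTreeMetadata built before tree 0 and pass it to runWithMetadata on every boosting iteration, so the metadata computation over the training data runs once rather than once per tree. A sketch of the shape of that change, with hypothetical stand-ins for the metadata and the trainer:

// Hypothetical stand-ins for DecisionTreeMetadata and the tree trainer; the
// point is the shape of the fix: one metadata pass, reused every iteration.
case class TreeMetadata(numFeatures: Int, numExamples: Long)

object MetadataReuseSketch {
  // Stands in for DecisionTreeMetadata.buildMetadata, a pass over the data.
  def buildMetadata(data: Seq[Array[Double]]): TreeMetadata =
    TreeMetadata(data.head.length, data.length.toLong)

  // Placeholder for RandomForest.runWithMetadata with numTrees = 1.
  def trainTree(data: Seq[Array[Double]], meta: TreeMetadata, seed: Long): String =
    s"tree(seed=$seed, numFeatures=${meta.numFeatures})"

  def boost(data: Seq[Array[Double]], maxIter: Int, seed: Long): Seq[String] = {
    val metadata = buildMetadata(data) // built once, before tree 0
    (0 until maxIter).map(m => trainTree(data, metadata, seed + m)) // reused each round
  }

  def main(args: Array[String]): Unit =
    boost(Seq(Array(1.0, 2.0), Array(3.0, 4.0)), maxIter = 3, seed = 7L).foreach(println)
}
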
RandomForest.scala

@@ -100,43 +100,24 @@ private[spark] object RandomForest extends Logging with Serializable {
     run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
   }

-  def buildMetadata(
-      input: RDD[Instance],
-      strategy: OldStrategy,
-      numTrees: Int,
-      featureSubsetStrategy: String): DecisionTreeMetadata = {
-    val retaggedInput = input.retag(classOf[Instance])
-    DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
-  }
-
   /**
    * Train a random forest.
    *
    * @param input Training data: RDD of `Instance`
    * @return an unweighted set of trees
    */
-  def run(
+  def runWithMetadata(
       input: RDD[Instance],
+      metadata: DecisionTreeMetadata,
       strategy: OldStrategy,
       numTrees: Int,
       featureSubsetStrategy: String,
       seed: Long,
       instr: Option[Instrumentation],
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
-      parentUID: Option[String] = None,
-      precomputedMetadata: Option[DecisionTreeMetadata] = None): Array[DecisionTreeModel] = {
-
+      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
     val timer = new TimeTracker()

-    timer.start("total")
-
-    timer.start("init")
-
-    val retaggedInput = input.retag(classOf[Instance])
-    val metadata = precomputedMetadata.getOrElse {
-      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
-    }
-
     instr match {
       case Some(instrumentation) =>
         instrumentation.logNumFeatures(metadata.numFeatures)
@@ -150,6 +131,12 @@ private[spark] object RandomForest extends Logging with Serializable {
         logInfo("weightedNumExamples: " + metadata.weightedNumExamples)
     }

+    timer.start("total")
+
+    timer.start("init")
+
+    val retaggedInput = input.retag(classOf[Instance])
+
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplits")
@@ -225,7 +212,7 @@ private[spark] object RandomForest extends Logging with Serializable {
       // Collect some nodes to split, and choose features for each node (if subsampling).
      // Each group of nodes may come from one or multiple trees, and at multiple levels.
       val (nodesForGroup, treeToNodeToIndexInfo) =
-          RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
+        RandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng)
       // Sanity check (should never occur):
       assert(nodesForGroup.nonEmpty,
         s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
@@ -285,6 +272,32 @@ private[spark] object RandomForest extends Logging with Serializable {
     }
   }

+  /**
+   * Train a random forest.
+   *
+   * @param input Training data: RDD of `Instance`
+   * @return an unweighted set of trees
+   */
+  def run(
+      input: RDD[Instance],
+      strategy: OldStrategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
+      seed: Long,
+      instr: Option[Instrumentation],
+      prune: Boolean = true, // exposed for testing only, real trees are always pruned
+      parentUID: Option[String] = None): Array[DecisionTreeModel] = {
+    val timer = new TimeTracker()
+
+    timer.start("build metadata")
+    val metadata = DecisionTreeMetadata
+      .buildMetadata(input.retag(classOf[Instance]), strategy, numTrees, featureSubsetStrategy)
+    timer.stop("build metadata")
+
+    runWithMetadata(input, metadata, strategy, numTrees, featureSubsetStrategy,
+      seed, instr, prune, parentUID)
+  }
+
   /**
    * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
   *
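
RandomForest now exposes two entry points: runWithMetadata takes the metadata as a real parameter, replacing both the precomputedMetadata: Option[...] default and the separate buildMetadata helper, while the re-added run keeps its old signature and builds the metadata itself under the "build metadata" timer before delegating. Callers with a precomputed metadata (GradientBoostedTrees above) use the former; everyone else keeps calling run unchanged. A condensed sketch of that delegation, with simplified types (Seq for the RDD, String for DecisionTreeModel):

object RandomForestSketch {
  case class Metadata(numFeatures: Int)

  // Callers that already hold a metadata instance use this directly.
  def runWithMetadata(data: Seq[Array[Double]], metadata: Metadata,
      numTrees: Int, seed: Long): Array[String] =
    Array.tabulate(numTrees)(i => s"tree$i(numFeatures=${metadata.numFeatures}, seed=$seed)")

  // Ordinary callers keep the old signature; metadata is built internally.
  def run(data: Seq[Array[Double]], numTrees: Int, seed: Long): Array[String] = {
    val metadata = Metadata(data.head.length) // stands in for DecisionTreeMetadata.buildMetadata
    runWithMetadata(data, metadata, numTrees, seed)
  }

  def main(args: Array[String]): Unit =
    run(Seq(Array(1.0, 2.0)), numTrees = 2, seed = 0L).foreach(println)
}
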
GBTClassifierSuite.scala

@@ -493,30 +493,31 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("training with sample weights") {
     val df = binaryDataset
     val numClasses = 2
-    val predEquals = (x: Double, y: Double) => x == y
-    // (maxIter, maxDepth)
+    // (maxIter, maxDepth, subsamplingRate, fractionInTol)
     val testParams = Seq(
-      (5, 5),
-      (5, 10)
+      (5, 5, 1.0, 0.99),
+      (5, 10, 1.0, 0.99),
+      (5, 10, 0.95, 0.9)
     )

-    for ((maxIter, maxDepth) <- testParams) {
+    for ((maxIter, maxDepth, subsamplingRate, tol) <- testParams) {
       val estimator = new GBTClassifier()
         .setMaxIter(maxIter)
         .setMaxDepth(maxDepth)
+        .setSubsamplingRate(subsamplingRate)
         .setSeed(seed)
         .setMinWeightFractionPerNode(0.049)

       MLTestingUtils.testArbitrarilyScaledWeights[GBTClassificationModel,
         GBTClassifier](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
+        MLTestingUtils.modelPredictionEquals(df, _ == _, tol))
       MLTestingUtils.testOutliersWithSmallWeights[GBTClassificationModel,
         GBTClassifier](df.as[LabeledPoint], estimator,
-        numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
+        numClasses, MLTestingUtils.modelPredictionEquals(df, _ == _, tol),
         outlierRatio = 2)
       MLTestingUtils.testOversamplingVsWeighting[GBTClassificationModel,
         GBTClassifier](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
+        MLTestingUtils.modelPredictionEquals(df, _ == _, tol), seed)
     }
   }

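
In the suite, the shared predEquals closure and hard-coded fractions (0.7 / 0.8) give way to a per-case fractionInTol carried in each tuple, with a looser 0.9 for the new subsamplingRate = 0.95 case, since resampling perturbs the fitted trees. A plain-Scala sketch of what such a check computes, assuming modelPredictionEquals passes when at least fractionInTol of the paired predictions satisfy the predicate (a simplification of the DataFrame-based test utility):

object PredictionTolSketch {
  // Assumed semantics: two models "agree" when the predicate holds for at
  // least fractionInTol of their paired predictions on the same rows.
  def predictionsMostlyEqual[T](
      predsA: Seq[T],
      predsB: Seq[T],
      predEquals: (T, T) => Boolean,
      fractionInTol: Double): Boolean = {
    require(predsA.length == predsB.length, "prediction columns must align")
    val agreeing = predsA.zip(predsB).count { case (a, b) => predEquals(a, b) }
    agreeing.toDouble / predsA.length >= fractionInTol
  }

  def main(args: Array[String]): Unit = {
    // Classification compares labels exactly; 3 of 4 rows agree here, so the
    // check passes at fractionInTol = 0.7 but would fail at 0.9.
    val a = Seq(1.0, 0.0, 1.0, 1.0)
    val b = Seq(1.0, 0.0, 0.0, 1.0)
    println(predictionsMostlyEqual(a, b, (x: Double, y: Double) => x == y, 0.7))
  }
}
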
GBTRegressorSuite.scala

@@ -321,29 +321,31 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
   test("training with sample weights") {
     val df = linearRegressionData
     val numClasses = 0
-    // (maxIter, maxDepth)
+    // (maxIter, maxDepth, subsamplingRate, fractionInTol)
     val testParams = Seq(
-      (5, 5),
-      (5, 10)
+      (5, 5, 1.0, 0.98),
+      (5, 10, 1.0, 0.98),
+      (5, 10, 0.95, 0.6)
     )

-    for ((maxIter, maxDepth) <- testParams) {
+    for ((maxIter, maxDepth, subsamplingRate, tol) <- testParams) {
       val estimator = new GBTRegressor()
         .setMaxIter(maxIter)
         .setMaxDepth(maxDepth)
+        .setSubsamplingRate(subsamplingRate)
         .setSeed(seed)
         .setMinWeightFractionPerNode(0.1)

       MLTestingUtils.testArbitrarilyScaledWeights[GBTRegressionModel,
         GBTRegressor](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95))
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol))
       MLTestingUtils.testOutliersWithSmallWeights[GBTRegressionModel,
         GBTRegressor](df.as[LabeledPoint], estimator, numClasses,
-        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95),
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol),
         outlierRatio = 2)
       MLTestingUtils.testOversamplingVsWeighting[GBTRegressionModel,
         GBTRegressor](df.as[LabeledPoint], estimator,
-        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 0.95), seed)
+        MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, tol), seed)
     }
   }

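
The regressor suite compares predictions under a relative tolerance rather than exact equality, and the oversampling check's relTol 0.01 is loosened to 0.1 to match the other two checks. A minimal sketch of one common relative-tolerance predicate; this is an assumption, not necessarily the exact TestingUtils formula behind `~= ... relTol`, which may scale by the smaller operand and treat zero specially:

object RelTolSketch {
  // Assumed definition: |x - y| within relTol of the larger magnitude.
  def relTolEquals(x: Double, y: Double, relTol: Double): Boolean = {
    if (x == y) true
    else math.abs(x - y) <= relTol * math.max(math.abs(x), math.abs(y))
  }

  def main(args: Array[String]): Unit = {
    println(relTolEquals(10.0, 10.9, 0.1)) // true: within 10%
    println(relTolEquals(10.0, 12.0, 0.1)) // false: off by roughly 20%
  }
}
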
