Updates based on code review. 1 major change: persisting to memory + disk, not just memory.

Details:

DecisionTree
* Changed .cache() -> .persist(StorageLevel.MEMORY_AND_DISK)
** This gave major performance improvements on small tests. E.g., 500K examples, 500 features, depth 5, on a MacBook, took 292 seconds with cache() and 112 seconds when using disk as well. (A sketch of this change follows the list.)
* Changed for loops to while loops (see the second sketch after this list)
* Small cleanups
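
For reference, the storage-level change looks like this; a minimal sketch lifted from the diff below. cache() is shorthand for persist(StorageLevel.MEMORY_ONLY), so the only behavioral difference is what happens to partitions that do not fit in memory:

    import org.apache.spark.storage.StorageLevel

    // Before: memory-only. Partitions that do not fit are dropped and
    // recomputed from the original input on every subsequent pass.
    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()

    // After: partitions that do not fit in memory are spilled to local disk,
    // so later level-wise passes re-read them instead of recomputing them.
    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
      .persist(StorageLevel.MEMORY_AND_DISK)

Avoiding that recomputation is presumably what cut the 500K x 500 test from 292 to 112 seconds.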
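
The for-to-while rewrites follow the pattern below (taken from the gains computation in this diff). A Scala for comprehension over a Range desugars to Range.foreach with a closure per iteration, which was generally slower than a bare while loop on hot inner loops in the Scala of this era:

    // Before: closure-based iteration.
    for (featureIndex <- 0 until numFeatures) {
      // ... per-feature work ...
    }

    // After: explicit index and while loop, as used throughout this commit.
    var featureIndex = 0
    while (featureIndex < numFeatures) {
      // ... per-feature work ...
      featureIndex += 1
    }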

TimeTracker
* Removed useless timing in DecisionTree

TreePoint
* Renamed features to binnedFeatures
jkbradley committed Aug 15, 2014
1 parent 2d2aaaf commit 6b5651e
Showing 4 changed files with 46 additions and 53 deletions.
49 changes: 24 additions & 25 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,7 +17,6 @@

package org.apache.spark.mllib.tree


import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
@@ -32,6 +31,7 @@ import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom


@@ -59,11 +59,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

timer.start("total")

// Cache input RDD for speedup during multiple passes.
timer.start("init")

val retaggedInput = input.retag(classOf[LabeledPoint])
logDebug("algo = " + strategy.algo)
timer.stop("init")

// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
@@ -73,9 +72,9 @@
timer.stop("findSplitsBins")
logDebug("numBins = " + numBins)

timer.start("init")
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
timer.stop("init")
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
.persist(StorageLevel.MEMORY_AND_DISK)

// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -90,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
val numFeatures = treeInput.take(1)(0).features.size
val numFeatures = treeInput.take(1)(0).binnedFeatures.size

// Calculate level for single group construction

@@ -110,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
logDebug("max level for single group = " + maxLevelForSingleGroup)

timer.stop("init")

/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
* reducing the passes over the data from l to log2(l) where l is the total number of nodes.
@@ -126,7 +127,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("level = " + level)
logDebug("#####################################")


// Find best split for all nodes at a level.
timer.start("findBestSplits")
val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
@@ -167,8 +167,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

timer.stop("total")

logDebug("Internal timing for DecisionTree:")
logDebug(s"$timer")
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")

new DecisionTreeModel(topNode, strategy.algo)
}
@@ -226,7 +226,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
}
}


object DecisionTree extends Serializable with Logging {

/**
@@ -536,7 +535,7 @@ object DecisionTree extends Serializable with Logging {
logDebug("numNodes = " + numNodes)

// Find the number of features by looking at the first sample.
val numFeatures = input.first().features.size
val numFeatures = input.first().binnedFeatures.size
logDebug("numFeatures = " + numFeatures)

// numBins: Number of bins = 1 + number of possible splits
@@ -578,12 +577,12 @@ object DecisionTree extends Serializable with Logging {
}

// Apply each filter and check sample validity. Return false when invalid condition found.
for (filter <- parentFilters) {
parentFilters.foreach { filter =>
val featureIndex = filter.split.feature
val comparison = filter.comparison
val isFeatureContinuous = filter.split.featureType == Continuous
if (isFeatureContinuous) {
val binId = treePoint.features(featureIndex)
val binId = treePoint.binnedFeatures(featureIndex)
val bin = bins(featureIndex)(binId)
val featureValue = bin.highSplit.threshold
val threshold = filter.split.threshold
@@ -598,9 +597,9 @@ object DecisionTree extends Serializable with Logging {
val isUnorderedFeature =
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
val featureValue = if (isUnorderedFeature) {
treePoint.features(featureIndex)
treePoint.binnedFeatures(featureIndex)
} else {
val binId = treePoint.features(featureIndex)
val binId = treePoint.binnedFeatures(featureIndex)
bins(featureIndex)(binId).category
}
val containsFeature = filter.split.categories.contains(featureValue)
@@ -648,9 +647,8 @@ object DecisionTree extends Serializable with Logging {
arr(shift) = InvalidBinIndex
} else {
var featureIndex = 0
// TODO: Vectorize this
while (featureIndex < numFeatures) {
arr(shift + featureIndex) = treePoint.features(featureIndex)
arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
featureIndex += 1
}
}
@@ -660,9 +658,8 @@ object DecisionTree extends Serializable with Logging {
}

// Find feature bins for all nodes at a level.
timer.start("findBinsForLevel")
timer.start("aggregation")
val binMappedRDD = input.map(x => findBinsForLevel(x))
timer.stop("findBinsForLevel")

/**
* Increment aggregate in location for (node, feature, bin, label).
@@ -907,13 +904,11 @@ object DecisionTree extends Serializable with Logging {
combinedAggregate
}


// Calculate bin aggregates.
timer.start("binAggregates")
val binAggregates = {
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
}
timer.stop("binAggregates")
timer.stop("aggregation")
logDebug("binAggregates.length = " + binAggregates.length)

/**
@@ -1225,12 +1220,16 @@
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)

for (featureIndex <- 0 until numFeatures) {
var featureIndex = 0
while (featureIndex < numFeatures) {
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
for (splitIndex <- 0 until numSplitsForFeature) {
var splitIndex = 0
while (splitIndex < numSplitsForFeature) {
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
splitIndex, rightNodeAgg, nodeImpurity)
splitIndex += 1
}
featureIndex += 1
}
gains
}
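
For context on the "aggregation" timer above: the bin statistics are computed with RDD.aggregate, which folds each partition's elements with binSeqOp and merges the per-partition partials with binCombOp. A minimal self-contained sketch of that API on toy data (the mean computation is illustrative, not this commit's aggregate arrays):

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("aggregate-sketch"))
    val data = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))

    // aggregate(zero)(seqOp, combOp): seqOp folds one element into a partition's
    // accumulator; combOp merges two per-partition accumulators.
    val (sum, count) = data.aggregate((0.0, 0L))(
      (acc, x) => (acc._1 + x, acc._2 + 1L),
      (a, b) => (a._1 + b._1, a._2 + b._2))
    println(s"mean = ${sum / count}")
    sc.stop()
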
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
@@ -25,8 +25,7 @@ import org.apache.spark.annotation.Experimental
* Time tracker implementation which holds labeled timers.
*/
@Experimental
private[tree]
class TimeTracker extends Serializable {
private[tree] class TimeTracker extends Serializable {

private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()

@@ -36,24 +35,24 @@ class TimeTracker extends Serializable {
* Starts a new timer, or re-starts a stopped timer.
*/
def start(timerLabel: String): Unit = {
val tmpTime = System.nanoTime()
val currentTime = System.nanoTime()
if (starts.contains(timerLabel)) {
throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
s" timerLabel = $timerLabel before that timer was stopped.")
}
starts(timerLabel) = tmpTime
starts(timerLabel) = currentTime
}

/**
* Stops a timer and returns the elapsed time in seconds.
*/
def stop(timerLabel: String): Double = {
val tmpTime = System.nanoTime()
val currentTime = System.nanoTime()
if (!starts.contains(timerLabel)) {
throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
s" timerLabel = $timerLabel, but that timer was not started.")
}
val elapsed = tmpTime - starts(timerLabel)
val elapsed = currentTime - starts(timerLabel)
starts.remove(timerLabel)
if (totals.contains(timerLabel)) {
totals(timerLabel) += elapsed
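
A hypothetical usage sketch of the TimeTracker API above; the labels are illustrative, though DecisionTree.train uses labels like "total", "init", and "aggregation":

    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")
    // ... setup work ...
    timer.stop("init")   // per its doc comment, returns the elapsed time in seconds
    timer.start("init")  // re-starting a stopped timer accumulates into its total
    timer.stop("init")
    timer.stop("total")
    // Calling start("init") twice without a stop in between throws a RuntimeException.
    println(s"$timer")   // DecisionTree logs this summary via logInfo
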
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -35,13 +35,12 @@ import org.apache.spark.rdd.RDD
* or any categorical feature used in regression or binary classification.
*
* @param label Label from LabeledPoint
* @param features Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices.
* @param binnedFeatures Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices.
*/
private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable {
private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) extends Serializable {
}


private[tree] object TreePoint {

/**
@@ -76,7 +75,7 @@ private[tree] object TreePoint {
val numFeatures = labeledPoint.features.size
val numBins = bins(0).size
val arr = new Array[Int](numFeatures)
var featureIndex = 0 // offset by 1 for label
var featureIndex = 0
while (featureIndex < numFeatures) {
val featureInfo = categoricalFeaturesInfo.get(featureIndex)
val isFeatureContinuous = featureInfo.isEmpty
@@ -98,7 +97,6 @@ private[tree] object TreePoint {
new TreePoint(labeledPoint.label, arr)
}


/**
* Find bin for one (labeledPoint, feature).
*
@@ -129,11 +127,9 @@ private[tree] object TreePoint {
val highThreshold = bin.highSplit.threshold
if ((lowThreshold < feature) && (highThreshold >= feature)) {
return mid
}
else if (lowThreshold >= feature) {
} else if (lowThreshold >= feature) {
right = mid - 1
}
else {
} else {
left = mid + 1
}
}
@@ -181,7 +177,8 @@ private[tree] object TreePoint {
// Perform binary search for finding bin for continuous features.
val binIndex = binarySearchForBins()
if (binIndex == -1) {
throw new UnknownError("No bin was found for continuous feature." +
throw new RuntimeException("No bin was found for continuous feature." +
" This error can occur when given invalid data values (such as NaN)." +
s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
}
binIndex
@@ -193,7 +190,8 @@ private[tree] object TreePoint {
sequentialBinSearchForOrderedCategoricalFeature()
}
if (binIndex == -1) {
throw new UnknownError("No bin was found for categorical feature." +
throw new RuntimeException("No bin was found for categorical feature." +
" This error can occur when given invalid data values (such as NaN)." +
s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
}
binIndex
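
The binarySearchForBins helper above locates the bin whose (lowSplit.threshold, highSplit.threshold] interval contains a continuous feature value. A standalone sketch of the same search, with the Bin and Split types simplified to plain threshold pairs (purely illustrative):

    // bins(i) = (lowThreshold, highThreshold]; bins are sorted and contiguous.
    def binarySearchForBins(bins: Array[(Double, Double)], feature: Double): Int = {
      var left = 0
      var right = bins.length - 1
      while (left <= right) {
        val mid = left + (right - left) / 2
        val (low, high) = bins(mid)
        if (low < feature && high >= feature) {
          return mid
        } else if (low >= feature) {
          right = mid - 1
        } else {
          left = mid + 1
        }
      }
      -1 // no bin found; the caller raises an error (e.g., for NaN feature values)
    }
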
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,17 +17,16 @@

package org.apache.spark.mllib.tree

import org.apache.spark.mllib.tree.impl.TreePoint

import scala.collection.JavaConverters._

import org.scalatest.FunSuite

import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.impl.TreePoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.regression.LabeledPoint
@@ -43,10 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
prediction != expected.label
}
val accuracy = (input.length - numOffPredictions).toDouble / input.length
if (accuracy < requiredAccuracy) {
println(s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
}
assert(accuracy >= requiredAccuracy)
assert(accuracy >= requiredAccuracy,
s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
}

def validateRegressor(
@@ -59,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
err * err
}.sum
val mse = squaredError / input.length
assert(mse <= requiredMSE)
assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
}

test("split and bin calculation") {
