Optimizations + Bug fix for DecisionTree
Optimization: Added a TreePoint representation so that findBin is called only once per (example, feature) pair.
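
For illustration, a minimal sketch of the idea, assuming a simplified TreePoint shape (the class, helper, and findBin signature below are hypothetical, not the exact code added in this commit):

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // Label plus pre-binned feature values: findBin runs once per (example, feature)
    // at conversion time instead of once per (node, example, feature) during aggregation.
    case class TreePointSketch(label: Double, binnedFeatures: Array[Int])

    def toTreePoints(
        input: RDD[LabeledPoint],
        numFeatures: Int,
        findBin: (Int, Double) => Int  // hypothetical (featureIndex, rawValue) => bin index
      ): RDD[TreePointSketch] = {
      input.map { lp =>
        val binned = Array.tabulate(numFeatures)(f => findBin(f, lp.features(f)))
        TreePointSketch(lp.label, binned)
      }
    }

With the bins precomputed per example, the per-node aggregation below can work directly with bin indices.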

Also, calculateGainsForAllNodeSplits now only searches over actual splits, not empty/unused ones.
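
For reference, the per-feature search bound works roughly as below (this mirrors the getNumSplitsForFeature helper added in this diff; the standalone signature and parameter names here are illustrative):

    def numSplitsForFeature(
        featureIndex: Int,
        numBins: Int,
        categoricalFeaturesInfo: Map[Int, Int],
        isMulticlassClassification: Boolean): Int = {
      categoricalFeaturesInfo.get(featureIndex) match {
        case None =>
          numBins - 1  // continuous feature: one candidate split per bin boundary
        case Some(featureCategories) =>
          val spaceSufficient = numBins > math.pow(2, featureCategories - 1) - 1
          if (isMulticlassClassification && spaceSufficient) {
            math.pow(2.0, featureCategories - 1).toInt - 1  // unordered: non-trivial category subsets
          } else {
            featureCategories  // ordered categorical feature
          }
      }
    }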

BUG FIX: isSampleValid
* isSampleValid used to treat unordered categorical features incorrectly: it treated the bins as if they were indexed by feature values, rather than by subsets of values/categories.
* The bug was exhibited for unordered features (multi-class classification with categorical features of low arity).
* Fix: Index bins correctly for unordered categorical features (see the sketch after this list).
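
For illustration, a minimal sketch of the corrected lookup, pulled out into a hypothetical standalone helper (the actual change lives inside isSampleValid in the diff below):

    // Pick the value to test against a categorical split's category set.
    // binCategoryOf maps a bin index back to its single category (only meaningful
    // for ordered features).
    def valueForCategoricalFilter(
        storedValue: Int,              // the per-feature value held in the TreePoint
        isUnorderedFeature: Boolean,
        binCategoryOf: Int => Double): Double = {
      if (isUnorderedFeature) {
        // Unordered feature: bins index subsets of categories, so the stored value
        // (the raw category) must be used directly; looking it up in bins was the bug.
        storedValue.toDouble
      } else {
        // Ordered feature: the stored value is a bin index; map it back to its category.
        binCategoryOf(storedValue)
      }
    }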

Also: this commit leaves some commented-out debugging println calls in DecisionTree, to be removed later.
jkbradley committed Aug 8, 2014
1 parent 3211f02 commit 0f676e2
Showing 1 changed file with 95 additions and 28 deletions.
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala (123 changes: 95 additions & 28 deletions)
@@ -252,6 +252,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// noting the parents filters for the child nodes
val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
//println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}")
for (filter <- filters(nodeIndex)) {
logDebug("Filter = " + filter)
}
@@ -477,7 +478,7 @@ object DecisionTree extends Serializable with Logging {
* @param splits possible splits for all features
* @param bins possible bins for all features
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @return array of splits with best splits for all nodes at a given level.
* @return array (over nodes) of splits with best split for each node at a given level.
*/
protected[tree] def findBestSplits(
input: RDD[TreePoint],
@@ -490,6 +491,7 @@
maxLevelForSingleGroup: Int,
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
// split into groups to avoid memory overflow during aggregation
//println(s"findBestSplits: level = $level")
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
// the nodes are divided into multiple groups at each level with the number of groups
@@ -617,22 +619,32 @@
val featureIndex = filter.split.feature
val comparison = filter.comparison
val isFeatureContinuous = filter.split.featureType == Continuous
val binId = treePoint.features(featureIndex)
val bin = bins(featureIndex)(binId)
if (isFeatureContinuous) {
val binId = treePoint.features(featureIndex)
val bin = bins(featureIndex)(binId)
val featureValue = bin.highSplit.threshold
val threshold = filter.split.threshold
comparison match {
case -1 => if (featureValue > threshold) return false
case 1 => if (featureValue <= threshold) return false
}
} else {
val containsFeature = filter.split.categories.contains(bin.category)
val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val isSpaceSufficientForAllCategoricalSplits =
numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
val isUnorderedFeature =
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
val featureValue = if (isUnorderedFeature) {
treePoint.features(featureIndex)
} else {
val binId = treePoint.features(featureIndex)
bins(featureIndex)(binId).category
}
val containsFeature = filter.split.categories.contains(featureValue)
comparison match {
case -1 => if (!containsFeature) return false
case 1 => if (containsFeature) return false
}

}
}

@@ -669,6 +681,7 @@
val parentFilters = findParentFilters(nodeIndex)
// Find out whether the sample qualifies for the particular node.
val sampleValid = isSampleValid(parentFilters, treePoint)
//println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}")
val shift = 1 + numFeatures * nodeIndex
if (!sampleValid) {
// Mark one bin as -1 is sufficient.
@@ -739,6 +752,7 @@
label: Double,
agg: Array[Double],
rightChildShift: Int): Unit = {
//println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
// Find the bin index for this feature.
val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
val featureValue = arr(arrIndex).toInt
@@ -792,6 +806,8 @@
}
}

val rightChildShift = numClasses * numBins * numFeatures * numNodes

/**
* Helper for binSeqOp.
*
@@ -814,8 +830,11 @@
// Check whether the instance was valid for this nodeIndex.
val validSignalIndex = 1 + numFeatures * nodeIndex
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
if (level == 1) {
val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
//println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}")
}
if (isSampleValidForNode) {
val rightChildShift = numClasses * numBins * numFeatures * numNodes
// actual class label
val label = arr(0)
// Iterate over all features.
@@ -874,7 +893,7 @@
val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
agg(aggIndex) = agg(aggIndex) + 1
agg(aggIndex + 1) = agg(aggIndex + 1) + label
agg(aggIndex + 2) = agg(aggIndex + 2) + label*label
agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
featureIndex += 1
}
}
@@ -944,6 +963,29 @@
logDebug("binAggregates.length = " + binAggregates.length)

timer.binAggregatesTime += timer.elapsed()
//2 * numClasses * numBins * numFeatures * numNodes for unordered features.
// (left/right, node, feature, bin, label)
/*
println(s"binAggregates:")
for (i <- Range(0,2)) {
for (n <- Range(0,numNodes)) {
for (f <- Range(0,numFeatures)) {
for (b <- Range(0,4)) {
for (c <- Range(0,numClasses)) {
val idx = i * numClasses * numBins * numFeatures * numNodes +
n * numClasses * numBins * numFeatures +
f * numBins * numFeatures +
b * numFeatures +
c
if (binAggregates(idx) != 0) {
println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}")
}
}
}
}
}
}
*/

/**
* Calculates the information gain for all splits based upon left/right split aggregates.
@@ -985,6 +1027,7 @@
val totalCount = leftTotalCount + rightTotalCount
if (totalCount == 0) {
// Return arbitrary prediction.
//println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0")
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
}

@@ -997,13 +1040,23 @@
def indexOfLargestArrayElement(array: Array[Double]): Int = {
val result = array.foldLeft(-1, Double.MinValue, 0) {
case ((maxIndex, maxValue, currentIndex), currentValue) =>
if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
else (maxIndex, maxValue, currentIndex + 1)
if (currentValue > maxValue) {
(currentIndex, currentValue, currentIndex + 1)
} else {
(maxIndex, maxValue, currentIndex + 1)
}
}
if (result._1 < 0) 0 else result._1
if (result._1 < 0) {
throw new RuntimeException("DecisionTree internal error:" +
" calculateGainForSplit failed in indexOfLargestArrayElement")
}
result._1
}

val predict = indexOfLargestArrayElement(leftRightCounts)
if (predict == 0 && featureIndex == 0 && splitIndex == 0) {
//println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}")
}
val prob = leftRightCounts(predict) / totalCount

val leftImpurity = if (leftTotalCount == 0) {
@@ -1023,6 +1076,7 @@
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)

case Regression =>
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
@@ -1140,6 +1194,7 @@

val rightChildShift = numClasses * numBins * numFeatures
var splitIndex = 0
var TMPDEBUG = 0.0
while (splitIndex < numBins - 1) {
var classIndex = 0
while (classIndex < numClasses) {
@@ -1149,10 +1204,12 @@
val rightBinValue = binData(rightChildShift + shift + classIndex)
leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
TMPDEBUG += leftBinValue + rightBinValue
classIndex += 1
}
splitIndex += 1
}
//println(s"found Agg: $TMPDEBUG")
}

def findAggForRegression(
@@ -1247,14 +1304,36 @@
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)

for (featureIndex <- 0 until numFeatures) {
for (splitIndex <- 0 until numBins - 1) {
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
for (splitIndex <- 0 until numSplitsForFeature) {
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
splitIndex, rightNodeAgg, nodeImpurity)
}
}
gains
}

/**
* Get the number of splits for a feature.
*/
def getNumSplitsForFeature(featureIndex: Int): Int = {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
numBins - 1
} else {
// Categorical feature
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val isSpaceSufficientForAllCategoricalSplits =
numBins > math.pow(2, featureCategories.toInt - 1) - 1
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
math.pow(2.0, featureCategories - 1).toInt - 1
} else {
// Ordered features
featureCategories
}
}
}

/**
* Find the best split for a node.
* @param binData Bin data slice for this node, given by getBinDataForNode.
@@ -1273,7 +1352,7 @@
// Calculate gains for all splits.
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)

val (bestFeatureIndex,bestSplitIndex, gainStats) = {
val (bestFeatureIndex, bestSplitIndex, gainStats) = {
// Initialize with infeasible values.
var bestFeatureIndex = Int.MinValue
var bestSplitIndex = Int.MinValue
@@ -1283,27 +1362,14 @@
while (featureIndex < numFeatures) {
// Iterate over all splits.
var splitIndex = 0
val maxSplitIndex: Double = {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
numBins - 1
} else { // Categorical feature
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
math.pow(2.0, featureCategories - 1).toInt - 1
} else { // Binary classification
featureCategories
}
}
}
while (splitIndex < maxSplitIndex) {
val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
while (splitIndex < numSplitsForFeature) {
val gainStats = gains(featureIndex)(splitIndex)
if (gainStats.gain > bestGainStats.gain) {
bestGainStats = gainStats
bestFeatureIndex = featureIndex
bestSplitIndex = splitIndex
//println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats")
}
splitIndex += 1
}
@@ -1361,6 +1427,7 @@
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
logDebug("parent node impurity = " + parentNodeImpurity)
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
//println(s"bestSplits(node:$node): ${bestSplits(node)}")
node += 1
}
timer.chooseSplitsTime += timer.elapsed()