Skip to content

Commit

Permalink
Eliminated pre-allocated nodes array in main train() method.
Browse files Browse the repository at this point in the history
* Nodes are constructed and added to the tree structure as needed during training.

Moved tree construction from main train() method into findBestSplitsPerGroup() since there is no need to keep the (split, gain) array for an entire level of nodes. Only one element of that array is needed at a time, so we do not need the array.

findBestSplits() now returns 2 items:
* rootNode (newly created root node on first iteration, same root node on later iterations)
* doneTraining (indicating whether all nodes at that level were leaves)

Also:
* Added Node.deepCopy (private[tree]), used for test suite
* Updated test suite (same functionality)
  • Loading branch information
jkbradley committed Sep 10, 2014
1 parent 2ab763b commit 1a8f0ad
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 172 deletions.
133 changes: 68 additions & 65 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
val maxDepth = strategy.maxDepth
require(maxDepth <= 30,
s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
// Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodesPlus1)

// Calculate level for single group construction

Expand Down Expand Up @@ -118,61 +114,29 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* beforehand and is not used in later levels.
*/

var topNode: Node = null // set on first iteration
var level = 0
var break = false
while (level <= maxDepth && !break) {

logDebug("#####################################")
logDebug("level = " + level)
logDebug("#####################################")

// Find best split for all nodes at a level.
timer.start("findBestSplits")
val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
DecisionTree.findBestSplits(treeInput,
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput,
metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
timer.stop("findBestSplits")

val levelNodeIndexOffset = Node.startIndexInLevel(level)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
val nodeIndex = levelNodeIndexOffset + index

// Extract info for this node (index) at the current level.
timer.start("extractNodeInfo")
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
nodes(nodeIndex) = node
timer.stop("extractNodeInfo")

if (level != 0) {
// Set parent.
val parentNodeIndex = Node.parentIndex(nodeIndex)
if (Node.isLeftChild(nodeIndex)) {
nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
} else {
nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
}
}
if (level < maxDepth) {
logDebug("leftChildIndex = " + Node.leftChildIndex(nodeIndex) +
", impurity = " + stats.leftImpurity)
logDebug("rightChildIndex = " + Node.rightChildIndex(nodeIndex) +
", impurity = " + stats.rightImpurity)
}
logDebug("final best split = " + split)
if (level == 0) {
topNode = tmpTopNode
}
require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
// Check whether all the nodes at the current level are leaves.
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
logDebug("all leaf = " + allLeaf)
if (allLeaf) {
break = true // no more tree construction
} else {
level += 1
if (doneTraining) {
break = true
logDebug("done training")
}

level += 1
}

logDebug("#####################################")
Expand All @@ -184,7 +148,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")

new DecisionTreeModel(nodes(1), strategy.algo)
new DecisionTreeModel(topNode, strategy.algo)
}

}
Expand Down Expand Up @@ -398,17 +362,19 @@ object DecisionTree extends Serializable with Logging {
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
* @return array (over nodes) of splits with best split for each node at a given level.
* @return (root, doneTraining) where:
* root = Root node (which is newly created on the first iteration),
* doneTraining = true if no more internal nodes were created.
*/
private[tree] def findBestSplits(
input: RDD[TreePoint],
metadata: DecisionTreeMetadata,
level: Int,
nodes: Array[Node],
topNode: Node,
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
maxLevelForSingleGroup: Int,
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
Expand All @@ -417,18 +383,18 @@ object DecisionTree extends Serializable with Logging {
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
val numGroups = 1 << level - maxLevelForSingleGroup
logDebug("numGroups = " + numGroups)
var bestSplits = new Array[(Split, InformationGainStats)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
var doneTraining = true
while (groupIndex < numGroups) {
val bestSplitsForGroup = findBestSplitsPerGroup(input, metadata, level,
nodes, splits, bins, timer, numGroups, groupIndex)
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
topNode, splits, bins, timer, numGroups, groupIndex)
doneTraining = doneTraining && doneTrainingGroup
groupIndex += 1
}
bestSplits
(topNode, doneTraining) // Not first iteration, so topNode was already set.
} else {
findBestSplitsPerGroup(input, metadata, level, nodes, splits, bins, timer)
findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer)
}
}

Expand Down Expand Up @@ -570,23 +536,25 @@ object DecisionTree extends Serializable with Logging {
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param level Level of the tree
* @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
* @param topNode Root node of the tree (or invalid node when training first level).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param numGroups total number of node groups at the current level. Default value is set to 1.
* @param groupIndex index of the node group being processed. Default value is set to 0.
* @return array of splits with best splits for all nodes at a given level.
* @return (root, doneTraining) where:
* root = Root node (which is newly created on the first iteration),
* doneTraining = true if no more internal nodes were created.
*/
private def findBestSplitsPerGroup(
input: RDD[TreePoint],
metadata: DecisionTreeMetadata,
level: Int,
nodes: Array[Node],
topNode: Node,
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
timer: TimeTracker,
numGroups: Int = 1,
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
groupIndex: Int = 0): (Node, Boolean) = {

/*
* The high-level descriptions of the best split optimizations are noted here.
Expand Down Expand Up @@ -643,7 +611,7 @@ object DecisionTree extends Serializable with Logging {
0
} else {
val globalNodeIndex =
predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
globalNodeIndex - globalNodeIndexOffset
}
}
Expand Down Expand Up @@ -686,18 +654,53 @@ object DecisionTree extends Serializable with Logging {

// Calculate best splits for all nodes at a given level
timer.start("chooseSplits")
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
// On the first iteration, we need to get and return the newly created root node.
var newTopNode: Node = topNode
// Iterating over all nodes at this level
var nodeIndex = 0
var internalNodeCount = 0
while (nodeIndex < numNodes) {
bestSplits(nodeIndex) =
val (split: Split, stats: InformationGainStats) =
binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
logDebug("best split = " + bestSplits(nodeIndex)._1)
logDebug("best split = " + split)

val globalNodeIndex = globalNodeIndexOffset + nodeIndex

// Extract info for this node at the current level.
timer.start("extractNodeInfo")
val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
val node =
new Node(globalNodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
timer.stop("extractNodeInfo")

if (!isLeaf) {
internalNodeCount += 1
}
if (level == 0) {
newTopNode = node
} else {
// Set parent.
val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
if (Node.isLeftChild(globalNodeIndex)) {
parentNode.leftNode = Some(node)
} else {
parentNode.rightNode = Some(node)
}
}
if (level < metadata.maxDepth) {
logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
", impurity = " + stats.leftImpurity)
logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
", impurity = " + stats.rightImpurity)
}

nodeIndex += 1
}
timer.stop("chooseSplits")

bestSplits
val doneTraining = internalNodeCount == 0
(newTopNode, doneTraining)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ private[tree] class DecisionTreeMetadata(
val unorderedFeatures: Set[Int],
val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy) extends Serializable {
val quantileStrategy: QuantileStrategy,
val maxDepth: Int) extends Serializable {

def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)

Expand Down Expand Up @@ -127,7 +128,7 @@ private[tree] object DecisionTreeMetadata {

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy)
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth)
}

/**
Expand Down
35 changes: 35 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,23 @@ class Node (
}
}

/**
 * Builds a copy of the subtree rooted at this node.
 *
 * Each child that is present is copied recursively; an absent child stays `None`.
 * The remaining constructor arguments (id, predict, isLeaf, split, stats) are
 * passed through to the new node unchanged.
 *
 * @return a new [[Node]] whose children are themselves deep copies
 */
private[tree] def deepCopy(): Node = {
  // Option.map keeps None as None and copies the child when one is defined.
  val copiedLeft = leftNode.map(_.deepCopy())
  val copiedRight = rightNode.map(_.deepCopy())
  new Node(id, predict, isLeaf, split, copiedLeft, copiedRight, stats)
}

/**
* Get the number of nodes in tree below this node, including leaf nodes.
* E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
Expand Down Expand Up @@ -190,4 +207,22 @@ private[tree] object Node {
*/
def startIndexInLevel(level: Int): Int = 1 << level

/**
 * Walks down from `rootNode` to the node identified by `nodeIndex`.
 *
 * The walk takes `indexToLevel(nodeIndex)` steps. At each step one bit of
 * `nodeIndex` is inspected, from the more significant to the least significant:
 * a 0 bit descends into the left child, a 1 bit into the right child.
 *
 * Every node on the path is assumed to exist: children are unwrapped with
 * `.get`, so a missing child fails at that point.
 *
 * @param nodeIndex index of the node to locate
 * @param rootNode  node from which the descent starts
 * @return the node reached after following the path encoded in `nodeIndex`
 */
def getNode(nodeIndex: Int, rootNode: Node): Node = {
  var current: Node = rootNode
  // Position of the bit that decides the next step; counts down to bit 0.
  var bit = indexToLevel(nodeIndex) - 1
  while (bit >= 0) {
    val goRight = (nodeIndex & (1 << bit)) != 0
    current = if (goRight) current.rightNode.get else current.leftNode.get
    bit -= 1
  }
  current
}

}
Loading

0 comments on commit 1a8f0ad

Please sign in to comment.