[SPARK-3160] [SPARK-3494] [mllib] DecisionTree: eliminate pre-allocated nodes, parentImpurities arrays. Memory calc bug fix.

This PR includes some code simplifications and re-organization which will be helpful for implementing random forests.  The main changes are that the nodes and parentImpurities arrays are no longer pre-allocated in the main train() method.

Also added two bug fixes:
* maxMemoryUsage calculation (an overflow bug; see the sketch after this list)
* over-allocation of space for bins in DTStatsAggregator for unordered features.
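
A standalone sketch of the overflow class of bug behind the first fix (commit 07dd1ee below calls it an overflow; the exact Spark expression is not shown in this diff, and maxMemoryInMB = 4096 is a made-up value):

object MaxMemoryOverflowSketch extends App {
  val maxMemoryInMB = 4096
  // Int arithmetic: 4096 * 1024 * 1024 = 2^32, which wraps around to 0.
  val overflowed: Int = maxMemoryInMB * 1024 * 1024
  // Promoting to Long before multiplying keeps the full value.
  val correct: Long = maxMemoryInMB * 1024L * 1024L
  println(s"Int: $overflowed, Long: $correct")  // Int: 0, Long: 4294967296
}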

Relation to RFs:
* Since RFs will be deeper and therefore more likely to be sparse (not full trees), avoiding pre-allocation of a full tree could be a cost savings.
* The associated re-organization also reduces bookkeeping, which will make RFs easier to implement.
* The return code doneTraining may be generalized to include cases such as nodes ready for local training.

Details:

No longer pre-allocate parentImpurities array in main train() method.
* parentImpurities values are now stored in individual nodes (in Node.stats.impurity).
* These were not really needed.  They were used in calculateGainForSplit(), but they can be calculated anyway using parentNodeAgg (see the sketch below).
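
For context, a standalone sketch of why stored parent impurities are redundant: the parent's impurity is recomputable from the parent's aggregated class counts (what the PR calls parentNodeAgg), which are just the children's counts summed. The Gini function and the counts below are illustrative, not Spark code.

object GainFromParentAggSketch extends App {
  // Gini impurity from per-class counts.
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    1.0 - counts.map(c => (c / total) * (c / total)).sum
  }
  val leftCounts = Array(30.0, 10.0)    // per-class counts in the left child
  val rightCounts = Array(5.0, 55.0)    // per-class counts in the right child
  // The parent's counts are the children's counts summed -- no stored impurity needed.
  val parentCounts = leftCounts.zip(rightCounts).map { case (l, r) => l + r }
  val total = parentCounts.sum
  val gain = gini(parentCounts) -
    (leftCounts.sum / total) * gini(leftCounts) -
    (rightCounts.sum / total) * gini(rightCounts)
  println(f"information gain = $gain%.4f")  // 0.2133
}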

No longer using Node.build since tree structure is constructed on-the-fly.
* Did not eliminate it, since it is a public (Developer) API; marked it as deprecated instead.

Eliminated pre-allocated nodes array in main train() method.
* Nodes are constructed and added to the tree structure as needed during training.
* Moved tree construction from the main train() method into findBestSplitsPerGroup(), since there is no need to keep the (split, gain) array for an entire level of nodes.  Only one element of that array is needed at a time, so we do not need the array.

findBestSplits() now returns two items:
* rootNode (newly created root node on the first iteration; the same root node on later iterations)
* doneTraining (indicating whether all nodes at that level were leaves; see the sketch below)
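
A minimal sketch of the training loop this return pair enables. The real findBestSplits() operates on an RDD plus metadata; everything here is stubbed so the sketch compiles and runs on its own, and the argument list is an assumption, not the actual signature.

object TrainingLoopSketch extends App {
  final class Node(val id: Int)
  // Stub standing in for the real findBestSplits: creates the root on the
  // first call, returns the same root on later calls, and (arbitrarily)
  // reports doneTraining after three levels.
  def findBestSplits(level: Int, root: Option[Node]): (Node, Boolean) = {
    val rootNode = root.getOrElse(new Node(1))
    (rootNode, level >= 2)
  }
  var root: Option[Node] = None
  var doneTraining = false
  var level = 0
  while (!doneTraining) {
    val (r, done) = findBestSplits(level, root)
    root = Some(r)
    doneTraining = done
    level += 1
  }
  println(s"stopped after level ${level - 1}; root id = ${root.get.id}")
}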

Updated DecisionTreeSuite.  Notes:
* Improved test "Second level node building with vs. without groups"
** generateOrderedLabeledPoints() modified so that it really does require 2 levels of internal nodes.
* Related update: Added Node.deepCopy (private[tree]), used for test suite

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #2341 from jkbradley/dt-spark-3160 and squashes the following commits:

07dd1ee [Joseph K. Bradley] Fixed overflow bug with computing maxMemoryUsage in DecisionTree.  Also fixed bug with over-allocating space in DTStatsAggregator for unordered features.
debe072 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training.
2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code:
jkbradley authored and mengxr committed Sep 12, 2014
1 parent 42904b8 commit b8634df
Showing 7 changed files with 268 additions and 256 deletions.
191 changes: 80 additions & 111 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

@@ -75,6 +75,9 @@ class Strategy (
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
   }
+  require(minInstancesPerNode >= 1,
+    s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+
   val isMulticlassClassification =
     algo == Classification && numClassesForClassification > 2
   val isMulticlassWithCategoricalFeatures
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala

@@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
    * Offset for each feature for calculating indices into the [[allStats]] array.
    */
   private val featureOffsets: Array[Int] = {
-    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
-      if (isUnordered(featureIndex)) {
-        total + 2 * numBins(featureIndex)
-      } else {
-        total + numBins(featureIndex)
-      }
-    }
-    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
+    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
   }

  /**
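
A worked example of the new one-line offsets computation, as a standalone re-creation (statsSize and numBins values are made up):

object FeatureOffsetsSketch extends App {
  val statsSize = 3              // stats values stored per bin (made-up value)
  val numBins = Array(4, 8, 4)   // bins per feature (made-up values)
  // scanLeft accumulates a running offset; the extra trailing element is the
  // total allocation per node.
  val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  println(featureOffsets.mkString(", "))  // 0, 12, 36, 48
}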
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
       s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
       s" but was called for ordered feature $featureIndex.")
     val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
-    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }

  /**
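
And the matching left/right offset arithmetic for unordered features, as I read the fix: numBins(featureIndex) now counts the left and right bin sets together (the old code allocated space for both and then doubled it again), so the right set starts halfway through the feature's allocation. The values below are made up.

object LeftRightOffsetsSketch extends App {
  val statsSize = 3
  val numBinsForFeature = 8   // left and right bin sets combined
  val baseOffset = 120        // hypothetical start of this feature in allStats
  val (leftOffset, rightOffset) =
    (baseOffset, baseOffset + (numBinsForFeature >> 1) * statsSize)
  println(s"left = $leftOffset, right = $rightOffset")  // left = 120, right = 132
}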
mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

@@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata(
     val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy,
+    val maxDepth: Int,
     val minInstancesPerNode: Int,
     val minInfoGain: Double) extends Serializable {

@@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata {

     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy,
+      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
       strategy.minInstancesPerNode, strategy.minInfoGain)
   }
mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

@@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * Predict values for the given data set using the model trained.
    *
    * @param features RDD representing data points to be predicted
-   * @return RDD[Int] where each entry contains the corresponding prediction
+   * @return RDD of predictions for each of the given data points
    */
   def predict(features: RDD[Vector]): RDD[Double] = {
     features.map(x => predict(x))
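
A hedged usage sketch of the documented method, spark-shell style: sc (a live SparkContext) and model (a trained DecisionTreeModel) are assumed to already be in scope.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD

val points = sc.parallelize(Seq(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0)))
val predictions: RDD[Double] = model.predict(points)  // one Double per input Vector
predictions.collect().foreach(println)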
37 changes: 37 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -55,6 +55,8 @@ class Node (
    * build the left node and right nodes if not leaf
    * @param nodes array of nodes
    */
+  @deprecated("build should no longer be used since trees are constructed on-the-fly in training",
+    "1.2.0")
   def build(nodes: Array[Node]): Unit = {
     logDebug("building node " + id + " at level " + Node.indexToLevel(id))
     logDebug("id = " + id + ", split = " + split)

@@ -93,6 +95,23 @@ class Node (
       }
     }

+  /**
+   * Returns a deep copy of the subtree rooted at this node.
+   */
+  private[tree] def deepCopy(): Node = {
+    val leftNodeCopy = if (leftNode.isEmpty) {
+      None
+    } else {
+      Some(leftNode.get.deepCopy())
+    }
+    val rightNodeCopy = if (rightNode.isEmpty) {
+      None
+    } else {
+      Some(rightNode.get.deepCopy())
+    }
+    new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+  }
+
   /**
    * Get the number of nodes in tree below this node, including leaf nodes.
    * E.g., if this is a leaf, returns 0.  If both children are leaves, returns 2.

@@ -190,4 +209,22 @@ private[tree] object Node {
    */
   def startIndexInLevel(level: Int): Int = 1 << level

+  /**
+   * Traces down from a root node to get the node with the given node index.
+   * This assumes the node exists.
+   */
+  def getNode(nodeIndex: Int, rootNode: Node): Node = {
+    var tmpNode: Node = rootNode
+    var levelsToGo = indexToLevel(nodeIndex)
+    while (levelsToGo > 0) {
+      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+        tmpNode = tmpNode.leftNode.get
+      } else {
+        tmpNode = tmpNode.rightNode.get
+      }
+      levelsToGo -= 1
+    }
+    tmpNode
+  }
+
 }
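
A standalone illustration of the binary-heap indexing getNode relies on: node 1 is the root, node i has children 2*i and 2*i + 1, so the bits of the index after the leading 1 spell out the left/right path from the root. indexToLevel is re-implemented locally (assumed to be floor(log2(index))) so the example runs on its own.

object NodeIndexSketch extends App {
  def indexToLevel(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex)
  // Walk from the root (index 1) toward node 13 = binary 1101: the bits after
  // the leading 1 are 1, 0, 1 -- right, left, right.
  val nodeIndex = 13
  var levelsToGo = indexToLevel(nodeIndex)  // 3
  val path = new StringBuilder
  while (levelsToGo > 0) {
    path ++= (if ((nodeIndex & (1 << (levelsToGo - 1))) == 0) "L" else "R")
    levelsToGo -= 1
  }
  println(path.toString)  // RLR: node 1 -> 3 -> 6 -> 13
}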