diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 98596569b8c95..56bb8812100a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -87,17 +87,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val maxDepth = strategy.maxDepth require(maxDepth <= 30, s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") - // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1 - val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1) - // Initialize an array to hold parent impurity calculations for each node. - val parentImpurities = new Array[Double](maxNumNodesPlus1) - // dummy value for top node (updated during first split calculation) - val nodes = new Array[Node](maxNumNodesPlus1) // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 + val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") // TODO: Calculate memory usage more precisely. val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) @@ -120,81 +114,35 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * beforehand and is not used in later levels. */ + var topNode: Node = null // set on first iteration var level = 0 var break = false while (level <= maxDepth && !break) { - logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") // Find best split for all nodes at a level. timer.start("findBestSplits") - val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] = - DecisionTree.findBestSplits(treeInput, parentImpurities, - metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) + val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput, + metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = Node.startIndexInLevel(level) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - val nodeIndex = levelNodeIndexOffset + index - - // Extract info for this node (index) at the current level. - timer.start("extractNodeInfo") - val split = nodeSplitStats._1 - val stats = nodeSplitStats._2 - val predict = nodeSplitStats._3.predict - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - nodes(nodeIndex) = node - timer.stop("extractNodeInfo") - - if (level != 0) { - // Set parent. - val parentNodeIndex = Node.parentIndex(nodeIndex) - if (Node.isLeftChild(nodeIndex)) { - nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) - } else { - nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) - } - } - // Extract info for nodes at the next lower level. - timer.start("extractInfoForLowerLevels") - if (level < maxDepth) { - val leftChildIndex = Node.leftChildIndex(nodeIndex) - val leftImpurity = stats.leftImpurity - logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity) - parentImpurities(leftChildIndex) = leftImpurity - - val rightChildIndex = Node.rightChildIndex(nodeIndex) - val rightImpurity = stats.rightImpurity - logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity) - parentImpurities(rightChildIndex) = rightImpurity - } - timer.stop("extractInfoForLowerLevels") - logDebug("final best split = " + split) + if (level == 0) { + topNode = tmpTopNode } - require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length) - // Check whether all the nodes at the current level at leaves. - val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) - logDebug("all leaf = " + allLeaf) - if (allLeaf) { - break = true // no more tree construction - } else { - level += 1 + if (doneTraining) { + break = true + logDebug("done training") } + + level += 1 } logDebug("#####################################") logDebug("Extracting tree model") logDebug("#####################################") - // Initialize the top or root node of the tree. - val topNode = nodes(1) - // Build the full tree using the node info calculated in the level-wise best split calculations. - topNode.build(nodes) - timer.stop("total") logInfo("Internal timing for DecisionTree:") @@ -409,24 +357,26 @@ object DecisionTree extends Serializable with Logging { * multiple groups if the level-wise training task could lead to memory overflow. * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree + * @param topNode Root node of the tree (or invalid node when training first level). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array (over nodes) of splits with best split for each node at a given level. + * @return (root, doneTraining) where: + * root = Root node (which is newly created on the first iteration), + * doneTraining = true if no more internal nodes were created. */ private[tree] def findBestSplits( input: RDD[TreePoint], - parentImpurities: Array[Double], metadata: DecisionTreeMetadata, level: Int, - nodes: Array[Node], + topNode: Node, splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = { + timer: TimeTracker = new TimeTracker): (Node, Boolean) = { + // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, @@ -435,18 +385,18 @@ object DecisionTree extends Serializable with Logging { // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) - var bestSplits = new Array[(Split, InformationGainStats, Predict)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 + var doneTraining = true while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, - nodes, splits, bins, timer, numGroups, groupIndex) - bestSplits = Array.concat(bestSplits, bestSplitsForGroup) + val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, + topNode, splits, bins, timer, numGroups, groupIndex) + doneTraining = doneTraining && doneTrainingGroup groupIndex += 1 } - bestSplits + (topNode, doneTraining) // Not first iteration, so topNode was already set. } else { - findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) + findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer) } } @@ -586,27 +536,27 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for a group of nodes at a given level * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param nodes Array of all nodes in the tree. Used for matching data points to nodes. + * @param topNode Root node of the tree (or invalid node when training first level). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. - * @return array of splits with best splits for all nodes at a given level. + * @return (root, doneTraining) where: + * root = Root node (which is newly created on the first iteration), + * doneTraining = true if no more internal nodes were created. */ private def findBestSplitsPerGroup( input: RDD[TreePoint], - parentImpurities: Array[Double], metadata: DecisionTreeMetadata, level: Int, - nodes: Array[Node], + topNode: Node, splits: Array[Array[Split]], bins: Array[Array[Bin]], timer: TimeTracker, numGroups: Int = 1, - groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = { + groupIndex: Int = 0): (Node, Boolean) = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -663,7 +613,7 @@ object DecisionTree extends Serializable with Logging { 0 } else { val globalNodeIndex = - predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) + predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures) globalNodeIndex - globalNodeIndexOffset } } @@ -706,33 +656,63 @@ object DecisionTree extends Serializable with Logging { // Calculate best splits for all nodes at a given level timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes) - // Iterating over all nodes at this level + // On the first iteration, we need to get and return the newly created root node. + var newTopNode: Node = topNode + + // Iterate over all nodes at this level var nodeIndex = 0 + var internalNodeCount = 0 while (nodeIndex < numNodes) { - val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) - logDebug("node impurity = " + nodeImpurity) - bestSplits(nodeIndex) = - binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits) - logDebug("best split = " + bestSplits(nodeIndex)._1) + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits) + logDebug("best split = " + split) + + val globalNodeIndex = globalNodeIndexOffset + nodeIndex + + // Extract info for this node at the current level. + val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth) + val node = + new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + + if (!isLeaf) { + internalNodeCount += 1 + } + if (level == 0) { + newTopNode = node + } else { + // Set parent. + val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode) + if (Node.isLeftChild(globalNodeIndex)) { + parentNode.leftNode = Some(node) + } else { + parentNode.rightNode = Some(node) + } + } + if (level < metadata.maxDepth) { + logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) + + ", impurity = " + stats.rightImpurity) + } + nodeIndex += 1 } timer.stop("chooseSplits") - bestSplits + val doneTraining = internalNodeCount == 0 + (newTopNode, doneTraining) } /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - topImpurity: Double, level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { val leftCount = leftImpurityCalculator.count @@ -747,14 +727,10 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - // impurity of parent node - val impurity = if (level > 0) { - topImpurity - } else { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - parentNodeAgg.calculate() - } + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + + val impurity = parentNodeAgg.calculate() val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -795,19 +771,15 @@ object DecisionTree extends Serializable with Logging { * Find the best split for a node. * @param binAggregates Bin statistics. * @param nodeIndex Index for node to split in this (level, group). - * @param nodeImpurity Impurity of the node (nodeIndex). * @return tuple for best split: (Split, information gain) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, - nodeImpurity: Double, level: Int, metadata: DecisionTreeMetadata, splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { - logDebug("node impurity = " + nodeImpurity) - // calculate predict only once var predict: Option[Predict] = None @@ -831,8 +803,7 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -845,8 +816,7 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -917,8 +887,7 @@ object DecisionTree extends Serializable with Logging { binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -937,8 +906,8 @@ object DecisionTree extends Serializable with Logging { /** * Get the number of values to be stored per node in the bin aggregates. */ - private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = { - val totalBins = metadata.numBins.sum + private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = { + val totalBins = metadata.numBins.map(_.toLong).sum if (metadata.isClassification) { metadata.numClasses * totalBins } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 987fe632c91ed..31d1e8ac30eea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -75,6 +75,9 @@ class Strategy ( if (algo == Classification) { require(numClassesForClassification >= 2) } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + val isMulticlassClassification = algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 866d85a79bea1..61a94246711bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator( * Offset for each feature for calculating indices into the [[allStats]] array. */ private val featureOffsets: Array[Int] = { - def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { - if (isUnordered(featureIndex)) { - total + 2 * numBins(featureIndex) - } else { - total + numBins(featureIndex) - } - } - Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) } /** @@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator( s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + s" but was called for ordered feature $featureIndex.") val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) - (baseOffset, baseOffset + numBins(featureIndex) * statsSize) + (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 5ceaa8154d11a..b6d49e5555b1a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata( val numBins: Array[Int], val impurity: Impurity, val quantileStrategy: QuantileStrategy, + val maxDepth: Int, val minInstancesPerNode: Int, val minInfoGain: Double) extends Serializable { @@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata { new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy, + strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, strategy.minInstancesPerNode, strategy.minInfoGain) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 0594fd0749d21..271b2c4ad813e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * Predict values for the given data set using the model trained. * * @param features RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction + * @return RDD of predictions for each of the given data points */ def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5b8a4cbed2306..5f0095d23c7ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -55,6 +55,8 @@ class Node ( * build the left node and right nodes if not leaf * @param nodes array of nodes */ + @deprecated("build should no longer be used since trees are constructed on-the-fly in training", + "1.2.0") def build(nodes: Array[Node]): Unit = { logDebug("building node " + id + " at level " + Node.indexToLevel(id)) logDebug("id = " + id + ", split = " + split) @@ -93,6 +95,23 @@ class Node ( } } + /** + * Returns a deep copy of the subtree rooted at this node. + */ + private[tree] def deepCopy(): Node = { + val leftNodeCopy = if (leftNode.isEmpty) { + None + } else { + Some(leftNode.get.deepCopy()) + } + val rightNodeCopy = if (rightNode.isEmpty) { + None + } else { + Some(rightNode.get.deepCopy()) + } + new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + } + /** * Get the number of nodes in tree below this node, including leaf nodes. * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. @@ -190,4 +209,22 @@ private[tree] object Node { */ def startIndexInLevel(level: Int): Int = 1 << level + /** + * Traces down from a root node to get the node with the given node index. + * This assumes the node exists. + */ + def getNode(nodeIndex: Int, rootNode: Node): Node = { + var tmpNode: Node = rootNode + var levelsToGo = indexToLevel(nodeIndex) + while (levelsToGo > 0) { + if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { + tmpNode = tmpNode.leftNode.get + } else { + tmpNode = tmpNode.rightNode.get + } + levelsToGo -= 1 + } + tmpNode + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index fd8547c1660fc..1bd7ea05c46c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -270,19 +270,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 0) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode: Node, doneTraining: Boolean) = + DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10) - val split = bestSplits(0)._1 + val split = rootNode.split.get assert(split.categories === List(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) - val stats = bestSplits(0)._2 - val predict = bestSplits(0)._3 + val stats = rootNode.stats.get assert(stats.gain > 0) - assert(predict.predict === 1) - assert(predict.prob === 0.6) + assert(rootNode.predict === 1) assert(stats.impurity > 0.2) } @@ -303,19 +301,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - val split = bestSplits(0)._1 + val split = rootNode.split.get assert(split.categories.length === 1) assert(split.categories.contains(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) - val stats = bestSplits(0)._2 - val predict = bestSplits(0)._3.predict + val stats = rootNode.stats.get assert(stats.gain > 0) - assert(predict === 0.6) + assert(rootNode.predict === 0.6) assert(stats.impurity > 0.2) } @@ -356,13 +353,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) } test("Binary classification stump with fixed label 1 for Gini") { @@ -382,14 +382,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._3.predict === 1) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -409,14 +412,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._3.predict === 0) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -436,14 +442,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._3.predict === 1) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 1) } test("Second level node building with vs. without groups") { @@ -459,40 +468,46 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClassesForClassification = 2, maxBins = 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val nodes: Array[Node] = new Array[Node](8) - nodes(1) = modelOneNode.topNode - nodes(1).leftNode = None - nodes(1).rightNode = None - - val parentImpurities = Array(0, 0.5, 0.5, 0.5) + val rootNodeCopy1 = modelOneNode.topNode.deepCopy() + val rootNodeCopy2 = modelOneNode.topNode.deepCopy() // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, - splits, bins, 10) - assert(bestSplits.length === 2) - assert(bestSplits(0)._2.gain > 0) - assert(bestSplits(1)._2.gain > 0) + val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, + rootNodeCopy1, splits, bins, 10) + assert(rootNode.leftNode.nonEmpty) + assert(rootNode.rightNode.nonEmpty) + val children1 = new Array[Node](2) + children1(0) = rootNode.leftNode.get + children1(1) = rootNode.rightNode.get // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, - nodes, splits, bins, 0) - assert(bestSplitsWithGroups.length === 2) - assert(bestSplitsWithGroups(0)._2.gain > 0) - assert(bestSplitsWithGroups(1)._2.gain > 0) + val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, + rootNodeCopy2, splits, bins, 0) + assert(rootNode2.leftNode.nonEmpty) + assert(rootNode2.rightNode.nonEmpty) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get // Verify whether the splits obtained using single group and multiple group level // construction strategies are the same. - for (i <- 0 until bestSplits.length) { - assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) - assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) - assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) - assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) - assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) - assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict) + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict === children2(i).predict) } } @@ -508,15 +523,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1)) - assert(bestSplit.featureType === Categorical) + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1)) + assert(split.featureType === Categorical) } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { @@ -573,16 +587,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) - - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1)) - assert(bestSplit.featureType === Categorical) - val gain = bestSplits(0)._2 + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1)) + assert(split.featureType === Categorical) + + val gain = rootNode.stats.get assert(gain.leftImpurity === 0) assert(gain.rightImpurity === 0) } @@ -600,16 +614,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) - - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplit.feature === 1) - assert(bestSplit.featureType === Continuous) - assert(bestSplit.threshold > 1980) - assert(bestSplit.threshold < 2020) + val split = rootNode.split.get + assert(split.feature === 1) + assert(split.featureType === Continuous) + assert(split.threshold > 1980) + assert(split.threshold < 2020) } @@ -627,16 +639,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - - assert(bestSplit.feature === 1) - assert(bestSplit.featureType === Continuous) - assert(bestSplit.threshold > 1980) - assert(bestSplit.threshold < 2020) + val split = rootNode.split.get + assert(split.feature === 1) + assert(split.featureType === Continuous) + assert(split.threshold > 1980) + assert(split.threshold < 2020) } test("Multiclass classification stump with 10-ary (ordered) categorical features") { @@ -652,15 +662,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1.0)) - assert(bestSplit.featureType === Categorical) + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) } test("Multiclass classification tree with 10-ary (ordered) categorical features," + @@ -698,12 +707,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length == 1) - val bestInfoStats = bestSplits(0)._2 - assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) + val gain = rootNode.stats.get + assert(gain == InformationGainStats.invalidInformationGainStats) } test("don't choose split that doesn't satisfy min instance per node requirements") { @@ -722,14 +730,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length == 1) - val bestSplit = bestSplits(0)._1 - val bestSplitStats = bestSplits(0)._1 - assert(bestSplit.feature == 1) - assert(bestSplitStats != InformationGainStats.invalidInformationGainStats) + val split = rootNode.split.get + val gain = rootNode.stats.get + assert(split.feature == 1) + assert(gain != InformationGainStats.invalidInformationGainStats) } test("split must satisfy min info gain requirements") { @@ -754,12 +761,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length == 1) - val bestInfoStats = bestSplits(0)._2 - assert(bestInfoStats == InformationGainStats.invalidInformationGainStats) + val gain = rootNode.stats.get + assert(gain == InformationGainStats.invalidInformationGainStats) } } @@ -786,13 +792,16 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - if (i < 600) { - val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) - arr(i) = lp + val label = if (i < 100) { + 0.0 + } else if (i < 500) { + 1.0 + } else if (i < 900) { + 0.0 } else { - val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) - arr(i) = lp + 1.0 } + arr(i) = new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i)) } arr }