[SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up
DecisionTree needs to match each example to a node at each iteration.  It currently does this with a set of filters very inefficiently: For each example, it examines each node at the current level and traces up to the root to see if that example should be handled by that node.

Fix: Filter top-down using the partly built tree itself.

Major changes:
* Eliminated Filter class, findBinsForLevel() method.
* Set up node parent links in main loop over levels in train().
* Added predictNodeIndex() for filtering top-down.
* Added DTMetadata class
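
The top-down filtering described above can be sketched roughly as follows. This is only an illustration: the DecisionTree.scala diff is not rendered on this page, so `SimpleSplit`, `SimpleNode`, and `TopDownFilter` are hypothetical stand-ins rather than the MLlib classes, and the node ids are arbitrary labels. The sketch follows the split semantics documented in this commit (go left if a continuous feature is <= threshold, or if a categorical value is in the split's category set) and walks from the root to the frontier of the partly built tree, so each example costs one pass down the tree instead of one upward trace per candidate node.

```scala
import scala.annotation.tailrec

// Hypothetical, simplified stand-ins for MLlib's Split and Node (not the real classes).
sealed trait SimpleSplit { def feature: Int }
case class ContinuousSplit(feature: Int, threshold: Double) extends SimpleSplit
case class CategoricalSplit(feature: Int, categories: Set[Double]) extends SimpleSplit

// A node belongs to the frontier of the partly built tree while it has no split/children yet.
case class SimpleNode(
    id: Int,
    split: Option[SimpleSplit] = None,
    left: Option[SimpleNode] = None,
    right: Option[SimpleNode] = None)

object TopDownFilter {
  /** Walk down from `node` and return the id of the frontier node that receives `features`. */
  @tailrec
  def predictNodeIndex(node: SimpleNode, features: Array[Double]): Int = {
    (node.split, node.left, node.right) match {
      case (Some(s), Some(l), Some(r)) =>
        // Same left/right rule as documented for Split in this commit.
        val goLeft = s match {
          case ContinuousSplit(f, t) => features(f) <= t
          case CategoricalSplit(f, cats) => cats.contains(features(f))
        }
        predictNodeIndex(if (goLeft) l else r, features)
      case _ =>
        // No split yet: this is the node currently being trained for this example.
        node.id
    }
  }
}
```

With a tree of depth d, each example is matched in at most d comparisons, independent of how many nodes exist at the current level.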

Other changes:
* Pre-compute set of unorderedFeatures.

Notes for following expected PR based on [https://issues.apache.org/jira/browse/SPARK-3043]:
* The unorderedFeatures set will next be stored in a metadata structure to simplify function calls (to store other items such as the data in strategy).

I've done initial tests indicating that this speeds things up, but am only now running large-scale ones.

CC: mengxr manishamde chouqin  Any comments are welcome---thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes apache#1975 from jkbradley/dt-opt2 and squashes the following commits:

a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata.  Small doc updates.
3726d20 [Joseph K. Bradley] Small code improvements based on code review.
ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow.
db0d773 [Joseph K. Bradley] scala style fix
6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code
931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level.  Needed to update treePointToNodeIndex with groupShift.
f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint
6b5651e [Joseph K. Bradley] Updates based on code review.  1 major change: persisting to memory + disk, not just memory.
2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used.  Removed debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error.  Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up.  Removed debugging println calls from DecisionTree.  Made TreePoint extend Serialiable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals
jkbradley authored and mengxr committed Aug 17, 2014
1 parent fbad722 commit 73ab7f1
Showing 9 changed files with 615 additions and 630 deletions.
878 changes: 386 additions & 492 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Learning and dataset metadata for DecisionTree.
+ *
+ * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
+ *                      For regression: fixed at 0 (no meaning).
+ * @param featureArity  Map: categorical feature index --> arity.
+ *                      I.e., the feature takes values in {0, ..., arity - 1}.
+ */
+private[tree] class DecisionTreeMetadata(
+    val numFeatures: Int,
+    val numExamples: Long,
+    val numClasses: Int,
+    val maxBins: Int,
+    val featureArity: Map[Int, Int],
+    val unorderedFeatures: Set[Int],
+    val impurity: Impurity,
+    val quantileStrategy: QuantileStrategy) extends Serializable {
+
+  def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+  def isClassification: Boolean = numClasses >= 2
+
+  def isMulticlass: Boolean = numClasses > 2
+
+  def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+  def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+  def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+}
+
+private[tree] object DecisionTreeMetadata {
+
+  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
+
+    val numFeatures = input.take(1)(0).features.size
+    val numExamples = input.count()
+    val numClasses = strategy.algo match {
+      case Classification => strategy.numClassesForClassification
+      case Regression => 0
+    }
+
+    val maxBins = math.min(strategy.maxBins, numExamples).toInt
+    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+
+    val unorderedFeatures = new mutable.HashSet[Int]()
+    if (numClasses > 2) {
+      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+        if (k - 1 < log2MaxBinsp1) {
+          // Note: The above check is equivalent to checking:
+          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
+          unorderedFeatures.add(f)
+        } else {
+          // TODO: Allow this case, where we simply will know nothing about some categories?
+          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+            s"in categorical features (>= $k)")
+        }
+      }
+    } else {
+      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+          s"in categorical features (>= $k)")
+      }
+    }
+
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+      strategy.impurity, strategy.quantileCalculationStrategy)
+  }
+
+}
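
The code comment in `buildMetadata` states that the log-based test `k - 1 < log2MaxBinsp1` is equivalent to the shift-based test `(1 << k - 1) - 1 < maxBins`. A minimal sketch of that equivalence, assuming a hypothetical helper object `UnorderedCheck` that is not part of the commit: since `k - 1 < log2(maxBins + 1)` holds exactly when `2^(k-1) < maxBins + 1`, both forms should agree for arities small enough that the shift does not overflow an `Int`. One caveat under that assumption: when `maxBins + 1` is an exact power of two, floating-point rounding in the log form could in principle tip the comparison, which the integer form avoids.

```scala
// Hypothetical helper, for illustration only; not part of DecisionTreeMetadata.
object UnorderedCheck {

  /** Log-based form, mirroring the expression used in buildMetadata. */
  def viaLog(k: Int, maxBins: Int): Boolean = {
    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
    k - 1 < log2MaxBinsp1
  }

  /** Shift-based form from the code comment; only meaningful while 1 << (k - 1) fits in an Int. */
  def viaShift(k: Int, maxBins: Int): Boolean = {
    require(k >= 1 && k <= 31, s"arity k (= $k) must be in [1, 31] for the Int shift")
    // In Scala, `1 << k - 1` already parses as `1 << (k - 1)`; the parentheses are for readability.
    (1 << (k - 1)) - 1 < maxBins
  }
}
```

For example, with `maxBins = 100` a feature of arity 7 needs `2^6 - 1 = 63` bins for all unordered splits and qualifies, while arity 8 needs 127 and does not.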
@@ -18,7 +18,6 @@
 package org.apache.spark.mllib.tree.impl

 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.model.Bin
 import org.apache.spark.rdd.RDD

@@ -48,50 +47,35 @@ private[tree] object TreePoint {
    * Convert an input dataset into its TreePoint representation,
    * binning feature values in preparation for DecisionTree training.
    * @param input Input dataset.
-   * @param strategy DecisionTree training info, used for dataset metadata.
    * @param bins Bins for features, of size (numFeatures, numBins).
+   * @param metadata Learning and dataset metadata
    * @return TreePoint dataset representation
    */
   def convertToTreeRDD(
       input: RDD[LabeledPoint],
-      strategy: Strategy,
-      bins: Array[Array[Bin]]): RDD[TreePoint] = {
+      bins: Array[Array[Bin]],
+      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
     input.map { x =>
-      TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
-        strategy.categoricalFeaturesInfo)
+      TreePoint.labeledPointToTreePoint(x, bins, metadata)
     }
   }

   /**
    * Convert one LabeledPoint into its TreePoint representation.
    * @param bins Bins for features, of size (numFeatures, numBins).
-   * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
    */
   private def labeledPointToTreePoint(
       labeledPoint: LabeledPoint,
-      isMulticlassClassification: Boolean,
       bins: Array[Array[Bin]],
-      categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
+      metadata: DecisionTreeMetadata): TreePoint = {

     val numFeatures = labeledPoint.features.size
     val numBins = bins(0).size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      val featureInfo = categoricalFeaturesInfo.get(featureIndex)
-      val isFeatureContinuous = featureInfo.isEmpty
-      if (isFeatureContinuous) {
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
-          bins, categoricalFeaturesInfo)
-      } else {
-        val featureCategories = featureInfo.get
-        val isSpaceSufficientForAllCategoricalSplits
-          = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-        val isUnorderedFeature =
-          isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
-          isUnorderedFeature, bins, categoricalFeaturesInfo)
-      }
+      arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
+        metadata.isUnordered(featureIndex), bins, metadata.featureArity)
       featureIndex += 1
     }

18 changes: 14 additions & 4 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
 import org.apache.spark.mllib.tree.configuration.FeatureType._

 /**
- * Used for "binning" the features bins for faster best split calculation. For a continuous
- * feature, a bin is determined by a low and a high "split". For a categorical feature,
- * the a bin is determined using a single label value (category).
+ * Used for "binning" the features bins for faster best split calculation.
+ *
+ * For a continuous feature, the bin is determined by a low and a high split,
+ * where an example with featureValue falls into the bin s.t.
+ * lowSplit.threshold < featureValue <= highSplit.threshold.
+ *
+ * For ordered categorical features, there is a 1-1-1 correspondence between
+ * bins, splits, and feature values. The bin is determined by category/feature value.
+ * However, the bins are not necessarily ordered by feature value;
+ * they are ordered using impurity.
+ * For unordered categorical features, there is a 1-1 correspondence between bins, splits,
+ * where bins and splits correspond to subsets of feature values (in highSplit.categories).
+ *
  * @param lowSplit signifying the lower threshold for the continuous feature to be
  * accepted in the bin
  * @param highSplit signifying the upper threshold for the continuous feature to be
  * accepted in the bin
  * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for binary classification
+ * @param category categorical label value accepted in the bin for ordered features
  */
 private[tree]
 case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
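
The continuous-feature rule in the updated Bin doc (`lowSplit.threshold < featureValue <= highSplit.threshold`) can be illustrated with a small lookup. This is a sketch under assumptions, not MLlib's `findBin`: the object `ContinuousBins` and its flat `thresholds` array are hypothetical, and the first and last bins are assumed to be open-ended (lower bound of negative infinity, upper bound of positive infinity). With sorted split thresholds, the bin index is the smallest `i` such that `value <= thresholds(i)`.

```scala
// Hypothetical helper, for illustration only; not the findBin used by TreePoint.
object ContinuousBins {

  /**
   * Return the bin index for `value`, where `thresholds` is sorted ascending and
   * bin i covers (thresholds(i - 1), thresholds(i)], with open-ended first and last bins.
   */
  def findBin(thresholds: Array[Double], value: Double): Int = {
    // Binary search for the smallest index i with value <= thresholds(i).
    var lo = 0
    var hi = thresholds.length
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (value <= thresholds(mid)) hi = mid else lo = mid + 1
    }
    // If no threshold is >= value, lo == thresholds.length, i.e. the last (open-ended) bin.
    lo
  }
}
```

Note that a value exactly equal to a threshold lands in the bin below it, matching the `<=` on the high split and the left-if-`<=` convention documented for `Split`.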
@@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * @return Double prediction from the trained model
    */
   def predict(features: Vector): Double = {
-    topNode.predictIfLeaf(features)
+    topNode.predict(features)
   }

   /**

This file was deleted.

@@ -69,24 +69,24 @@ class Node (

   /**
    * predict value if node is not leaf
-   * @param feature feature value
+   * @param features feature value
    * @return predicted value
    */
-  def predictIfLeaf(feature: Vector) : Double = {
+  def predict(features: Vector) : Double = {
     if (isLeaf) {
       predict
     } else{
       if (split.get.featureType == Continuous) {
-        if (feature(split.get.feature) <= split.get.threshold) {
-          leftNode.get.predictIfLeaf(feature)
+        if (features(split.get.feature) <= split.get.threshold) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       } else {
-        if (split.get.categories.contains(feature(split.get.feature))) {
-          leftNode.get.predictIfLeaf(feature)
+        if (split.get.categories.contains(features(split.get.feature))) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       }
     }
@@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
  * :: DeveloperApi ::
  * Split applied to a feature
  * @param feature feature index
- * @param threshold threshold for continuous feature
+ * @param threshold Threshold for continuous feature.
+ *                  Split left if feature <= threshold, else right.
  * @param featureType type of feature -- categorical or continuous
- * @param categories accepted values for categorical variables
+ * @param categories Split left if categorical feature value is in this set, else right.
  */
 @DeveloperApi
 case class Split(