Skip to content

Commit

Permalink
add comments and unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
chouqin committed Oct 13, 2014
1 parent 9e7138e commit 8f46af6
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 17 deletions.
53 changes: 41 additions & 12 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -913,14 +913,12 @@ object DecisionTree extends Serializable with Logging {
val numSplits = metadata.numSplits(featureIndex)
val numBins = metadata.numBins(featureIndex)
if (metadata.isContinuous(featureIndex)) {
val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)

val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
val featureSplits = findSplits(featureSamples, metadata.numSplits(featureIndex))
val numSplits = featureSplits.length
val numBins = numSplits + 1
logDebug("numSplits= " + numSplits)
metadata.setNumBinForFeature(featureIndex, numSplits + 1)
val numBins = metadata.numBins(featureIndex)

splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)

Expand Down Expand Up @@ -1017,12 +1015,23 @@ object DecisionTree extends Serializable with Logging {

/**
* Find splits for a continuous feature
* @param featureSamples
* @param numSplits
* @return
* NOTE: The returned number of splits is determined by `featureSamples`
* and may differ from the requested `numSplits`.
* The metadata's number of splits will be updated accordingly.
* @param featureSamples feature values of each sample
* @param metadata decision tree metadata
* @param featureIndex feature index to find splits
* @return array of splits
*/
private def findSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
/*
private[tree] def findSplitsForContinuousFeature(
featureSamples: Array[Double],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
s"findSplitsForContinuousFeature can only be used " +
s"to find splits for a continuous feature.")

/**
* Get count for each distinct value
*/
def getValueCount(arr: Array[Double]): Array[(Double, Int)] = {
Expand All @@ -1040,24 +1049,41 @@ object DecisionTree extends Serializable with Logging {
}
index += 1
}
valueCount.append((currentValue, currentCount))
// last value is not put into valueCount
// because we should not use it as a split threshold

valueCount.toArray
}

val valueCount = getValueCount(featureSamples)
val numSplits = metadata.numSplits(featureIndex)

// sort feature samples first
val sortedFeatureSamples = featureSamples.sorted

// get count for each distinct value
val valueCount = getValueCount(sortedFeatureSamples)
if (valueCount.length <= numSplits) {
return valueCount.map(_._1)
}

// stride between splits
val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
logDebug("stride = " + stride)

// iterate `valueCount` to find splits
val splits = new ArrayBuffer[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCount(0)._2
// expectedCount: expected value for `currentCount`.
// If `currentCount` is closest value to `expectedCount`,
// then current value is a split threshold.
// After finding a split threshold, `expectedCount` is increased by `stride`.
var expectedCount = stride
while (index < valueCount.length) {
// If adding the count of the current value to `currentCount`
// would move `currentCount` farther from `expectedCount`,
// then the previous value is chosen as a split threshold.
if (math.abs(currentCount - expectedCount) <
math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
splits.append(valueCount(index-1)._1)
Expand All @@ -1067,6 +1093,9 @@ object DecisionTree extends Serializable with Logging {
index += 1
}

// set number of splits accordingly
metadata.setNumSplits(featureIndex, splits.length)

splits.toArray
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ private[tree] class DecisionTreeMetadata(


/**
*
* Set number of splits for a continuous feature.
* For a continuous feature, number of bins is number of splits plus 1.
*/
def setNumBinForFeature(featureIndex: Int, numBin: Int) {
def setNumSplits(featureIndex: Int, numSplits: Int) {
require(isContinuous(featureIndex),
s"Can only set number of bin for continuous feature.")
numBins(featureIndex) = numBin
s"Only number of bin for a continuous feature can be set.")
numBins(featureIndex) = numSplits + 1
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
Expand Down Expand Up @@ -102,6 +102,37 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
}

test("find splits for a continuous feature") {
  // Normal case: many distinct sample values, so the requested number of
  // splits (5) should be found, and the metadata should record 5 splits
  // and 6 bins for the feature.
  {
    val metadata = new DecisionTreeMetadata(1, 0, 0, 0,
      Map(), Set(),
      Array(5), Gini, QuantileStrategy.Sort,
      0, 0, 0.0, 0, 0
    )
    val samples = Array.fill(200000)(math.random)
    val foundSplits = DecisionTree.findSplitsForContinuousFeature(samples, metadata, 0)
    assert(foundSplits.length === 5)
    assert(metadata.numSplits(0) === 5)
    assert(metadata.numBins(0) === 6)
  }

  // Degenerate case: only three distinct values, so identical split
  // thresholds must not be returned. With too few split candidates,
  // the metadata's split and bin counts are reduced to match what was
  // actually found.
  {
    val metadata = new DecisionTreeMetadata(1, 0, 0, 0,
      Map(), Set(),
      Array(3), Gini, QuantileStrategy.Sort,
      0, 0, 0.0, 0, 0
    )
    val samples = Array(1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0)
    val foundSplits = DecisionTree.findSplitsForContinuousFeature(samples, metadata, 0)
    assert(foundSplits.length === 2)
    assert(metadata.numSplits(0) === 2)
    assert(metadata.numBins(0) === 3)
  }
}

test("Multiclass classification with unordered categorical features:" +
" split and bin calculations") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
Expand Down

0 comments on commit 8f46af6

Please sign in to comment.