From dbb7ac13d28fba0848062a7bea40c617cb5f2c80 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 22 Jan 2014 20:44:23 -0800 Subject: [PATCH] categorical feature support Signed-off-by: Manish Amde --- .../spark/mllib/tree/DecisionTree.scala | 127 +++++++++++++----- .../apache/spark/mllib/tree/model/Bin.scala | 2 +- .../apache/spark/mllib/tree/model/Node.scala | 15 ++- .../apache/spark/mllib/tree/model/Split.scala | 11 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 83 ++++++++++-- 5 files changed, 185 insertions(+), 53 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index e8adef377481c..f89c53a7ad70d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -37,7 +37,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { //Cache input RDD for speedup during multiple passes input.cache() - val (splits, bins) = DecisionTree.find_splits_bins(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) logDebug("numSplits = " + bins(0).length) strategy.numBins = bins(0).length @@ -54,8 +54,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { logDebug("algo = " + strategy.algo) - - breakable { for (level <- 0 until maxDepth){ @@ -185,10 +183,21 @@ object DecisionTree extends Serializable with Logging { val featureIndex = filter.split.feature val threshold = filter.split.threshold val comparison = filter.comparison - comparison match { - case(-1) => if (features(featureIndex) > threshold) return false - case(0) => if (features(featureIndex) != threshold) return false - case(1) => if (features(featureIndex) <= threshold) return false + val categories = filter.split.categories + val isFeatureContinuous = filter.split.featureType == Continuous + val feature = features(featureIndex) + if (isFeatureContinuous){ + comparison match { + case(-1) => if (feature > threshold) return false + case(1) => if (feature <= threshold) return false + } + } else { + val containsFeature = categories.contains(feature) + comparison match { + case(-1) => if (!containsFeature) return false + case(1) => if (containsFeature) return false + } + } } true @@ -197,18 +206,34 @@ object DecisionTree extends Serializable with Logging { /*Finds the right bin for the given feature*/ def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { //logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex)) - //TODO: Do binary search - for (binIndex <- 0 until strategy.numBins) { - val bin = bins(featureIndex)(binIndex) - //TODO: Remove this requirement post basic functional - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - val features = labeledPoint.features - if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) { - return binIndex + + val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinous){ + //TODO: Do binary search + for (binIndex <- 0 until strategy.numBins) { + val bin = bins(featureIndex)(binIndex) + //TODO: Remove this requirement post basic functional + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + val features = labeledPoint.features + if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) { + return binIndex + } + } + throw new UnknownError("no bin was found for continuous variable.") + } else { + for (binIndex <- 0 until strategy.numBins) { + val bin = bins(featureIndex)(binIndex) + //TODO: Remove this requirement post basic functional + val category = bin.category + val features = labeledPoint.features + if (category == features(featureIndex)) { + return binIndex + } } + throw new UnknownError("no bin was found for categorical variable.") + } - throw new UnknownError("no bin was found.") } @@ -565,7 +590,7 @@ object DecisionTree extends Serializable with Logging { @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an Array[Array[Bin]] of size (numFeatures,numSplits1) */ - def find_splits_bins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { + def findSplitsBins(input : RDD[LabeledPoint], strategy : Strategy) : (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() @@ -603,31 +628,71 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) for (index <- 0 until numBins-1) { val sampleIndex = (index+1)*stride.toInt - val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous) + val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List()) splits(featureIndex)(index) = split } } else { val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - for (index <- 0 until maxFeatureValue){ - //TODO: Sort by centriod - val split = new Split(featureIndex,index,Categorical) - splits(featureIndex)(index) = split + + require(maxFeatureValue < numBins, "number of categories should be less than number of bins") + + val centriodForCategories + = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) + .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + //Checking for missing categorical variables + val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until maxFeatureValue){ + if (centriodForCategories.contains(i)){ + fullCentriodForCategories(i) = centriodForCategories(i) + } else { + fullCentriodForCategories(i) = Double.MaxValue + } + } + + val categoriesSortedByCentriod + = fullCentriodForCategories.toList sortBy {_._2} + + logDebug("centriod for categorical variable = " + categoriesSortedByCentriod) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentriod.iterator.zipWithIndex foreach { + case((key, value), index) => { + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex,Double.MinValue,Categorical,categoriesForSplit) + bins(featureIndex)(index) = { + if(index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex,Categorical),splits(featureIndex)(0),Categorical,key) + } + else { + new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Categorical,key) + } + } + } } } } //Find all bins for (featureIndex <- 0 until numFeatures){ - bins(featureIndex)(0) - = new Bin(new DummyLowSplit(Continuous),splits(featureIndex)(0),Continuous) - for (index <- 1 until numBins - 1){ - val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous) - bins(featureIndex)(index) = bin + val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinous) { //bins for categorical variables are already assigned + bins(featureIndex)(0) + = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue) + for (index <- 1 until numBins - 1){ + val bin = new Bin(splits(featureIndex)(index-1),splits(featureIndex)(index),Continuous,Double.MinValue) + bins(featureIndex)(index) = bin + } + bins(featureIndex)(numBins-1) + = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous),Continuous,Double.MinValue) + } else { + val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + for (i <- maxFeatureValue until numBins){ + bins(featureIndex)(i) + = new Bin(new DummyCategoricalSplit(featureIndex,Categorical),new DummyCategoricalSplit(featureIndex,Categorical),Categorical,Double.MaxValue) + } } - bins(featureIndex)(numBins-1) - = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(Continuous),Continuous) } - (splits,bins) } case MinMax => { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 13191851956ad..6664f084a7d8d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -18,6 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ -case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType) { +case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index a9210e10ae48b..fb63743848cc9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.FeatureType._ class Node ( val id : Int, val predict : Double, @@ -49,10 +50,18 @@ class Node ( val id : Int, if (isLeaf) { predict } else{ - if (feature(split.get.feature) <= split.get.threshold) { - leftNode.get.predictIfLeaf(feature) + if (split.get.featureType == Continuous) { + if (feature(split.get.feature) <= split.get.threshold) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } } else { - rightNode.get.predictIfLeaf(feature) + if (split.get.categories.contains(feature(split.get.feature))) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 01aa349115302..97f16e67c55b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -18,11 +18,14 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType -case class Split(feature: Int, threshold : Double, featureType : FeatureType){ - override def toString = "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +case class Split(feature: Int, threshold : Double, featureType : FeatureType, categories : List[Double]){ + override def toString = + "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + ", categories = " + categories } -class DummyLowSplit(kind : FeatureType) extends Split(Int.MinValue, Double.MinValue, kind) +class DummyLowSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MinValue, kind, List()) -class DummyHighSplit(kind : FeatureType) extends Split(Int.MaxValue, Double.MaxValue, kind) +class DummyHighSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List()) + +class DummyCategoricalSplit(feature: Int, kind : FeatureType) extends Split(feature, Double.MaxValue, kind, List()) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5f9aad0de2f65..4e68611d2be9e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ +import scala.collection.mutable class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ -50,7 +51,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) assert(splits.length==2) assert(bins.length==2) assert(splits(0).length==99) @@ -58,12 +59,58 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { //println(splits(1)(98)) } + test("split and bin calculation for categorical variables"){ + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + assert(splits.length==2) + assert(bins.length==2) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(0)(0)) + println(splits(0)(1)) + println(bins(0)(0)) + println(splits(1)(0)) + println(splits(1)(1)) + println(bins(1)(0)) + } + + test("split and bin calculations for categorical variables with no sample for one category"){ + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + assert(splits.length==2) + assert(bins.length==2) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(0)(0)) + println(splits(0)(1)) + println(splits(0)(2)) + println(bins(0)(0)) + println(bins(0)(1)) + println(bins(0)(2)) + println(splits(1)(0)) + println(splits(1)(1)) + println(splits(1)(2)) + println(bins(1)(0)) + println(bins(1)(1)) + println(bins(0)(2)) + println(bins(0)(3)) + } + + //TODO: Test max feature value > num bins + + test("stump with fixed label 0 for Gini"){ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) assert(bins.length==2) @@ -73,15 +120,13 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { strategy.numBins = 100 val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) - println("here") assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) assert(0==bestSplits(0)._2.gain) - assert(10==bestSplits(0)._2.leftSamples) assert(0==bestSplits(0)._2.leftImpurity) - assert(990==bestSplits(0)._2.rightSamples) assert(0==bestSplits(0)._2.rightImpurity) + assert(0.01==bestSplits(0)._2.predict) } test("stump with fixed label 1 for Gini"){ @@ -89,7 +134,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification,Gini,3,100) - val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) assert(bins.length==2) @@ -103,10 +148,10 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) assert(0==bestSplits(0)._2.gain) - assert(10==bestSplits(0)._2.leftSamples) assert(0==bestSplits(0)._2.leftImpurity) - assert(990==bestSplits(0)._2.rightSamples) assert(0==bestSplits(0)._2.rightImpurity) + assert(0.01==bestSplits(0)._2.predict) + } @@ -115,7 +160,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification,Entropy,3,100) - val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) assert(bins.length==2) @@ -129,10 +174,9 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) assert(0==bestSplits(0)._2.gain) - assert(10==bestSplits(0)._2.leftSamples) assert(0==bestSplits(0)._2.leftImpurity) - assert(990==bestSplits(0)._2.rightSamples) assert(0==bestSplits(0)._2.rightImpurity) + assert(0.01==bestSplits(0)._2.predict) } test("stump with fixed label 1 for Entropy"){ @@ -140,7 +184,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification,Entropy,3,100) - val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) assert(splits.length==2) assert(splits(0).length==99) assert(bins.length==2) @@ -154,10 +198,9 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) assert(0==bestSplits(0)._2.gain) - assert(10==bestSplits(0)._2.leftSamples) assert(0==bestSplits(0)._2.leftImpurity) - assert(990==bestSplits(0)._2.rightSamples) assert(0==bestSplits(0)._2.rightImpurity) + assert(0.01==bestSplits(0)._2.predict) } @@ -184,4 +227,16 @@ object DecisionTreeSuite { arr } + def generateCategoricalDataPoints() : Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + if (i < 600){ + arr(i) = new LabeledPoint(1.0,Array(0.0,1.0)) + } else { + arr(i) = new LabeledPoint(0.0,Array(1.0,0.0)) + } + } + arr + } + }