Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into mllib-stats-api-c…
Browse files Browse the repository at this point in the history
…heck
  • Loading branch information
jkbradley committed Aug 17, 2014
2 parents cf70b07 + 73ab7f1 commit c8c20dc
Show file tree
Hide file tree
Showing 9 changed files with 615 additions and 630 deletions.
878 changes: 386 additions & 492 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impl

import scala.collection.mutable

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.rdd.RDD


/**
 * Learning and dataset metadata for DecisionTree.
 *
 * Bundles the dataset dimensions and the learning settings that tree training
 * needs, and offers small predicates for querying the kind of a feature.
 *
 * @param numFeatures Number of features in each example.
 * @param numExamples Number of examples in the dataset.
 * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
 *                   For regression: fixed at 0 (no meaning).
 * @param maxBins Maximum number of bins.
 * @param featureArity Map: categorical feature index --> arity.
 *                     I.e., the feature takes values in {0, ..., arity - 1}.
 *                     A feature index absent from this map is a continuous feature.
 * @param unorderedFeatures Indices of the categorical features that are treated as unordered.
 * @param impurity Impurity measure.
 * @param quantileStrategy Quantile calculation strategy.
 */
private[tree] class DecisionTreeMetadata(
    val numFeatures: Int,
    val numExamples: Long,
    val numClasses: Int,
    val maxBins: Int,
    val featureArity: Map[Int, Int],
    val unorderedFeatures: Set[Int],
    val impurity: Impurity,
    val quantileStrategy: QuantileStrategy) extends Serializable {

  /** True iff the feature is one of the categorical features treated as unordered. */
  def isUnordered(featureIndex: Int): Boolean = unorderedFeatures(featureIndex)

  /** True iff there are at least two classes (regression uses numClasses = 0). */
  def isClassification: Boolean = numClasses >= 2

  /** True iff there are strictly more than two classes. */
  def isMulticlass: Boolean = numClasses > 2

  /** True iff the problem is multiclass and at least one feature is categorical. */
  def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && featureArity.nonEmpty

  /** True iff an arity is registered for the feature. */
  def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)

  /** True iff the feature is not categorical. */
  def isContinuous(featureIndex: Int): Boolean = !isCategorical(featureIndex)

}

private[tree] object DecisionTreeMetadata {

  /**
   * Build the [[DecisionTreeMetadata]] for a dataset and a training strategy.
   *
   * Runs two actions on `input`: `take(1)` to read the number of features from the
   * first example, and `count()` to get the number of examples.
   *
   * The effective number of bins is `min(strategy.maxBins, numExamples)`.
   * A categorical feature of arity k is treated as unordered only for multiclass
   * classification (numClasses > 2) and only if all of its 2^(k-1) - 1 subset splits
   * fit into the bins. Every other categorical feature must satisfy `k < maxBins`.
   *
   * @param input Training dataset; must contain at least one example.
   * @param strategy Training configuration (algo, number of classes, maxBins,
   *                 categorical feature arities, impurity, quantile strategy).
   * @return Metadata describing the dataset and the learning settings.
   * @throws IllegalArgumentException if `input` is empty, or if a categorical feature
   *                                  that is not treated as unordered has arity >= maxBins.
   */
  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {

    // Fail with an explicit message on an empty dataset instead of an
    // ArrayIndexOutOfBoundsException from indexing into an empty array.
    val numFeatures = input.take(1).headOption.map(_.features.size).getOrElse {
      throw new IllegalArgumentException(
        "DecisionTree requires a non-empty input RDD, but the given RDD has no examples.")
    }
    val numExamples = input.count()
    val numClasses = strategy.algo match {
      case Classification => strategy.numClassesForClassification
      case Regression => 0
    }

    val maxBins = math.min(strategy.maxBins, numExamples).toInt
    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)

    val isMulticlass = numClasses > 2

    // Note: arity - 1 < log2(maxBins + 1) is equivalent to checking:
    //   numUnorderedBins = (1 << (arity - 1)) - 1 < maxBins
    def fitsAsUnordered(arity: Int): Boolean = isMulticlass && (arity - 1 < log2MaxBinsp1)

    // TODO: Allow the case of too many categories, where we simply will know nothing
    //       about some categories?
    strategy.categoricalFeaturesInfo.foreach { case (_, arity) =>
      if (!fitsAsUnordered(arity)) {
        require(arity < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
          s"in categorical features (>= $arity)")
      }
    }

    val unorderedFeatures: Set[Int] = strategy.categoricalFeaturesInfo.collect {
      case (featureIndex, arity) if fitsAsUnordered(arity) => featureIndex
    }.toSet

    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
      strategy.categoricalFeaturesInfo, unorderedFeatures,
      strategy.impurity, strategy.quantileCalculationStrategy)
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.Bin
import org.apache.spark.rdd.RDD

Expand Down Expand Up @@ -48,50 +47,35 @@ private[tree] object TreePoint {
* Convert an input dataset into its TreePoint representation,
* binning feature values in preparation for DecisionTree training.
* @param input Input dataset.
* @param strategy DecisionTree training info, used for dataset metadata.
* @param bins Bins for features, of size (numFeatures, numBins).
* @param metadata Learning and dataset metadata
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
input: RDD[LabeledPoint],
strategy: Strategy,
bins: Array[Array[Bin]]): RDD[TreePoint] = {
bins: Array[Array[Bin]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
input.map { x =>
TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
strategy.categoricalFeaturesInfo)
TreePoint.labeledPointToTreePoint(x, bins, metadata)
}
}

/**
* Convert one LabeledPoint into its TreePoint representation.
* @param bins Bins for features, of size (numFeatures, numBins).
* @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
isMulticlassClassification: Boolean,
bins: Array[Array[Bin]],
categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
metadata: DecisionTreeMetadata): TreePoint = {

val numFeatures = labeledPoint.features.size
val numBins = bins(0).size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
val featureInfo = categoricalFeaturesInfo.get(featureIndex)
val isFeatureContinuous = featureInfo.isEmpty
if (isFeatureContinuous) {
arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
bins, categoricalFeaturesInfo)
} else {
val featureCategories = featureInfo.get
val isSpaceSufficientForAllCategoricalSplits
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
val isUnorderedFeature =
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
isUnorderedFeature, bins, categoricalFeaturesInfo)
}
arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
metadata.isUnordered(featureIndex), bins, metadata.featureArity)
featureIndex += 1
}

Expand Down
18 changes: 14 additions & 4 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
* Used for "binning" the features bins for faster best split calculation. For a continuous
* feature, a bin is determined by a low and a high "split". For a categorical feature,
* the a bin is determined using a single label value (category).
* Used for "binning" the features bins for faster best split calculation.
*
* For a continuous feature, the bin is determined by a low and a high split,
* where an example with featureValue falls into the bin s.t.
* lowSplit.threshold < featureValue <= highSplit.threshold.
*
* For ordered categorical features, there is a 1-1-1 correspondence between
* bins, splits, and feature values. The bin is determined by category/feature value.
* However, the bins are not necessarily ordered by feature value;
* they are ordered using impurity.
* For unordered categorical features, there is a 1-1 correspondence between bins and splits,
* where bins and splits correspond to subsets of feature values (in highSplit.categories).
*
* @param lowSplit signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin for binary classification
* @param category categorical label value accepted in the bin for ordered features
*/
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
topNode.predictIfLeaf(features)
topNode.predict(features)
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ class Node (

/**
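* Predict the value for the given features: this node's own prediction if it is a leaf,
* otherwise the prediction of the child subtree selected by this node's split.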
* @param feature feature value
* @param features feature value
* @return predicted value
*/
def predictIfLeaf(feature: Vector) : Double = {
def predict(features: Vector) : Double = {
if (isLeaf) {
predict
} else{
if (split.get.featureType == Continuous) {
if (feature(split.get.feature) <= split.get.threshold) {
leftNode.get.predictIfLeaf(feature)
if (features(split.get.feature) <= split.get.threshold) {
leftNode.get.predict(features)
} else {
rightNode.get.predictIfLeaf(feature)
rightNode.get.predict(features)
}
} else {
if (split.get.categories.contains(feature(split.get.feature))) {
leftNode.get.predictIfLeaf(feature)
if (split.get.categories.contains(features(split.get.feature))) {
leftNode.get.predict(features)
} else {
rightNode.get.predictIfLeaf(feature)
rightNode.get.predict(features)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
* :: DeveloperApi ::
* Split applied to a feature
* @param feature feature index
* @param threshold threshold for continuous feature
* @param threshold Threshold for continuous feature.
* Split left if feature <= threshold, else right.
* @param featureType type of feature -- categorical or continuous
* @param categories accepted values for categorical variables
* @param categories Split left if categorical feature value is in this set, else right.
*/
@DeveloperApi
case class Split(
Expand Down
Loading

0 comments on commit c8c20dc

Please sign in to comment.