[SPARK-2851] [mllib] DecisionTree Python consistency update
Added 6 static train methods to match the Python API, but without default arguments (the Python default args are noted in the docs). A short usage sketch follows the commit list below.

Added factory methods for Algo and Impurity, but made them private[mllib].

CC: mengxr dorx. Please let me know if there are other changes that would help with API consistency---thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes apache#1798 from jkbradley/dt-python-consistency and squashes the following commits:

6f7edf8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
a0d7dbe [Joseph K. Bradley] DecisionTree: In Java-friendly train* methods, changed to use JavaRDD instead of RDD.
ee1d236 [Joseph K. Bradley] DecisionTree API updates: * Removed train() function in Python API (tree.py) ** Removed corresponding function in Scala/Java API (the ones taking basic types)
00f820e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-python-consistency
fe6dbfa [Joseph K. Bradley] removed unnecessary imports
e358661 [Joseph K. Bradley] DecisionTree API change: * Added 6 static train methods to match Python API, but without default arguments (but with Python default args noted in docs).
c699850 [Joseph K. Bradley] a few doc comments
eaf84c0 [Joseph K. Bradley] Added DecisionTree static train() methods API to match Python, but without default parameters
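
For reference, a minimal sketch of the consolidated Python API described above (editorial illustration, not part of the commit; `sc` and `data` are assumptions):

    from pyspark.mllib.tree import DecisionTree

    # `sc` is an existing SparkContext; `data` is an existing RDD of
    # LabeledPoint with labels in {0, 1}.
    # categoricalFeaturesInfo is now a required argument ({} = all continuous);
    # the remaining values are the Python defaults noted in the Scala docs.
    model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
                                         impurity="gini", maxDepth=4, maxBins=100)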
jkbradley authored and mengxr committed Aug 7, 2014
1 parent ffd1f59 commit 47ccd5e
Showing 5 changed files with 181 additions and 77 deletions.
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -25,16 +25,14 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.correlation.CorrelationNames
@@ -523,17 +521,8 @@ class PythonMLLibAPI extends Serializable {

val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)

val algo: Algo = algoStr match {
case "classification" => Classification
case "regression" => Regression
case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
}
val impurity: Impurity = impurityStr match {
case "gini" => Gini
case "entropy" => Entropy
case "variance" => Variance
case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
}
val algo = Algo.fromString(algoStr)
val impurity = Impurities.fromString(impurityStr)

val strategy = new Strategy(
algo = algo,
151 changes: 124 additions & 27 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,14 +17,18 @@

package org.apache.spark.mllib.tree

import org.apache.spark.api.java.JavaRDD

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
@@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging {
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
* For regression, labels are real numbers.
@@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging {
}
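
The Strategy-based entry point above remains the most flexible form. A hedged sketch of calling it directly (editorial, not part of this commit; the Strategy argument names mirror the train() parameters documented below in this file, and the values are illustrative):

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // `input: RDD[LabeledPoint]` is assumed to be defined.
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
      numClassesForClassification = 3, maxBins = 100)
    val model = DecisionTree.train(input, strategy)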

/**
* Method to train a decision tree model where the instances are represented as an RDD of
* (label, features) pairs. The method supports binary classification and regression. For the
* binary classification, the label for each instance should either be 0 or 1 to denote the two
* classes.
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging {
}

/**
* Method to train a decision tree model where the instances are represented as an RDD of
* (label, features) pairs. The method supports binary classification and regression. For the
* binary classification, the label for each instance should either be 0 or 1 to denote the two
* classes.
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging {
}

/**
* Method to train a decision tree model where the instances are represented as an RDD of
* (label, features) pairs. The decision tree method supports binary classification and
* regression. For the binary classification, the label for each instance should either be 0 or
* 1 to denote the two classes. The method also supports categorical features inputs where the
* number of categories can specified using the categoricalFeaturesInfo option.
* Method to train a decision tree model.
* The method supports binary and multiclass classification and regression.
*
* Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
* and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
* is recommended to clearly separate classification and regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging {
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
* the number of discrete values they take. For example,
* an entry (n -> k) implies the feature n is categorical with k
* categories 0, 1, 2, ... , k-1. It's important to note that
* features are zero-indexed.
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @return DecisionTreeModel that can be used for prediction
*/
def train(
@@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging {
new DecisionTree(strategy).train(input)
}

/**
* Method to train a decision tree model for binary or multiclass classification.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels should take values {0, 1, ..., numClasses-1}.
* @param numClassesForClassification number of classes for classification.
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @param impurity Criterion used for information gain calculation.
* Supported values: "gini" (recommended) or "entropy".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* (suggested value: 4)
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @return DecisionTreeModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
numClassesForClassification: Int,
categoricalFeaturesInfo: Map[Int, Int],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val impurityType = Impurities.fromString(impurity)
train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
categoricalFeaturesInfo)
}
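
A minimal Scala sketch of the method above (editorial; `points` and the feature arity are assumptions, not part of this commit):

    import org.apache.spark.mllib.tree.DecisionTree

    // `points: RDD[LabeledPoint]` with labels in {0, 1} is assumed to be loaded.
    // Feature 0 is categorical with 3 categories {0, 1, 2}; the rest are continuous.
    val model = DecisionTree.trainClassifier(points, numClassesForClassification = 2,
      categoricalFeaturesInfo = Map(0 -> 3), impurity = "gini", maxDepth = 4, maxBins = 100)
    val firstPrediction = model.predict(points.first().features)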

/**
* Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
*/
def trainClassifier(
input: JavaRDD[LabeledPoint],
numClassesForClassification: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
impurity, maxDepth, maxBins)
}

/**
* Method to train a decision tree model for regression.
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels are real numbers.
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @param impurity Criterion used for information gain calculation.
* Supported values: "variance".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* (suggested value: 4)
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @return DecisionTreeModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
categoricalFeaturesInfo: Map[Int, Int],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val impurityType = Impurities.fromString(impurity)
train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
}
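
Similarly for regression (editorial sketch; `points` with real-valued labels is assumed):

    // `points: RDD[LabeledPoint]` with real-valued labels is assumed to be loaded.
    val model = DecisionTree.trainRegressor(points, categoricalFeaturesInfo = Map.empty[Int, Int],
      impurity = "variance", maxDepth = 4, maxBins = 100)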

/**
* Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
*/
def trainRegressor(
input: JavaRDD[LabeledPoint],
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
impurity, maxDepth, maxBins)
}


private val InvalidBinIndex = -1

/**
@@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging {
* Categorical features:
* For each feature, there is 1 bin per split.
* Splits and bins are handled in 2 ways:
* (a) For multiclass classification with a low-arity feature
* (a) "unordered features"
* For multiclass classification with a low-arity feature
* (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
* the feature is split based on subsets of categories.
* There are 2^(maxFeatureValue - 1) - 1 splits.
* (b) For regression and binary classification,
* There are math.pow(2, maxFeatureValue - 1) - 1 splits.
* (b) "ordered features"
* For regression and binary classification,
* and for multiclass classification with a high-arity feature,
* there is one split per category.
* Categorical case (a) features are called unordered features.
* Other cases are called ordered features.
* there is one bin per category.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
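
To make the arithmetic in the doc comment above concrete (editorial note): a 4-category feature in the multiclass, low-arity case yields 2^(4-1) - 1 = 7 candidate subset splits, while the ordered case uses one bin per category. A tiny Scala sketch of the counts (illustrative helpers, not part of mllib):

    // Candidate split/bin counts for one categorical feature of the given arity.
    def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1 // 2^(k-1) - 1 subset splits
    def numOrderedBins(arity: Int): Int = arity                      // one bin per category
    assert(numUnorderedSplits(4) == 7)
    assert(numOrderedBins(4) == 4)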
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental
object Algo extends Enumeration {
type Algo = Value
val Classification, Regression = Value

private[mllib] def fromString(name: String): Algo = name match {
case "classification" => Classification
case "regression" => Regression
case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name")
}
}
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impurity

/**
* Factory for Impurity instances.
*/
private[mllib] object Impurities {

def fromString(name: String): Impurity = name match {
case "gini" => Gini
case "entropy" => Entropy
case "variance" => Variance
case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
}

}
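
A quick illustration of how these private[mllib] factories behave (usable only from mllib-internal code such as PythonMLLibAPI above; editorial sketch):

    import org.apache.spark.mllib.tree.configuration.Algo
    import org.apache.spark.mllib.tree.impurity.Impurities

    val algo = Algo.fromString("classification")   // Algo.Classification
    val impurity = Impurities.fromString("gini")   // Gini
    // Impurities.fromString("giny") throws IllegalArgumentException("Did not recognize ...")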
50 changes: 15 additions & 35 deletions python/pyspark/mllib/tree.py
@@ -131,7 +131,7 @@ class DecisionTree(object):
"""

@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for classification.
@@ -150,12 +150,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
return DecisionTree.train(data, "classification", numClasses,
categoricalFeaturesInfo,
impurity, maxDepth, maxBins)
sc = data.context
dataBytes = _get_unmangled_labeled_point_rdd(data)
categoricalFeaturesInfoJMap = \
MapConverter().convert(categoricalFeaturesInfo,
sc._gateway._gateway_client)
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, "classification",
numClasses, categoricalFeaturesInfoJMap,
impurity, maxDepth, maxBins)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)

@staticmethod
def trainRegressor(data, categoricalFeaturesInfo={},
def trainRegressor(data, categoricalFeaturesInfo,
impurity="variance", maxDepth=4, maxBins=100):
"""
Train a DecisionTreeModel for regression.
@@ -173,42 +181,14 @@ def trainRegressor(data, categoricalFeaturesInfo={},
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
return DecisionTree.train(data, "regression", 0,
categoricalFeaturesInfo,
impurity, maxDepth, maxBins)

@staticmethod
def train(data, algo, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins=100):
"""
Train a DecisionTreeModel for classification or regression.
:param data: Training data: RDD of LabeledPoint.
For classification, labels are integers
{0,1,...,numClasses}.
For regression, labels are real numbers.
:param algo: "classification" or "regression"
:param numClasses: Number of classes for classification.
:param categoricalFeaturesInfo: Map from categorical feature index
to number of categories.
Any feature not in this map
is treated as continuous.
:param impurity: For classification: "entropy" or "gini".
For regression: "variance".
:param maxDepth: Max depth of tree.
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
:return: DecisionTreeModel
"""
sc = data.context
dataBytes = _get_unmangled_labeled_point_rdd(data)
categoricalFeaturesInfoJMap = \
MapConverter().convert(categoricalFeaturesInfo,
sc._gateway._gateway_client)
model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
dataBytes._jrdd, algo,
numClasses, categoricalFeaturesInfoJMap,
dataBytes._jrdd, "regression",
0, categoricalFeaturesInfoJMap,
impurity, maxDepth, maxBins)
dataBytes.unpersist()
return DecisionTreeModel(sc, model)
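
The resulting Python call, for comparison (editorial sketch; `sc` and `data` are assumptions):

    from pyspark.mllib.tree import DecisionTree

    # `data` is an existing RDD of LabeledPoint with real-valued labels.
    # categoricalFeaturesInfo is now required, matching trainClassifier.
    model = DecisionTree.trainRegressor(data, categoricalFeaturesInfo={},
                                        impurity="variance", maxDepth=4, maxBins=100)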
