diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fd0b9556c7d54..ba7ccd8ce4b8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -25,16 +25,14 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.stat.correlation.CorrelationNames @@ -523,17 +521,8 @@ class PythonMLLibAPI extends Serializable { val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val algo: Algo = algoStr match { - case "classification" => Classification - case "regression" => Regression - case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") - } - val impurity: Impurity = impurityStr match { - case "gini" => Gini - case "entropy" => Entropy - case "variance" => Variance - case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") - } + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) val strategy = new Strategy( algo = algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1d03e6e3b36cf..c8a865659682f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,14 +17,18 @@ package org.apache.spark.mllib.tree +import org.apache.spark.api.java.JavaRDD + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging { * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ def train( @@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainClassifier(input.rdd, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + private val InvalidBinIndex = -1 /** @@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging { * Categorical features: * For each feature, there is 1 bin per split. * Splits and bins are handled in 2 ways: - * (a) For multiclass classification with a low-arity feature + * (a) "unordered features" + * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are 2^(maxFeatureValue - 1) - 1 splits. - * (b) For regression and binary classification, + * There are math.pow(2, maxFeatureValue - 1) - 1 splits. + * (b) "ordered features" + * For regression and binary classification, * and for multiclass classification with a high-arity feature, - * there is one split per category. - - * Categorical case (a) features are called unordered features. - * Other cases are called ordered features. + * there is one bin per category. * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 79a01f58319e8..0ef9c6181a0a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value + + private[mllib] def fromString(name: String): Algo = name match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala new file mode 100644 index 0000000000000..9a6452aa13a61 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Factory for Impurity instances. + */ +private[mllib] object Impurities { + + def fromString(name: String): Impurity = name match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name") + } + +} diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 2518001ea0b93..e1a4671709b7d 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -131,7 +131,7 @@ class DecisionTree(object): """ @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + def trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for classification. @@ -150,12 +150,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "classification", numClasses, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, "classification", + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) @staticmethod - def trainRegressor(data, categoricalFeaturesInfo={}, + def trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for regression. @@ -173,42 +181,14 @@ def trainRegressor(data, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "regression", 0, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - - @staticmethod - def train(data, algo, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins=100): - """ - Train a DecisionTreeModel for classification or regression. - - :param data: Training data: RDD of LabeledPoint. - For classification, labels are integers - {0,1,...,numClasses}. - For regression, labels are real numbers. - :param algo: "classification" or "regression" - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: For classification: "entropy" or "gini". - For regression: "variance". - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :return: DecisionTreeModel - """ sc = data.context dataBytes = _get_unmangled_labeled_point_rdd(data) categoricalFeaturesInfoJMap = \ MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, algo, - numClasses, categoricalFeaturesInfoJMap, + dataBytes._jrdd, "regression", + 0, categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins) dataBytes.unpersist() return DecisionTreeModel(sc, model)