Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-19591][ML][MLlib] Add sample weights to decision trees #21632

Closed
wants to merge 13 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ object TestingUtils {
/**
* Private helper function for comparing two values using absolute tolerance.
*/
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
// Special case for NaNs
if (x.isNaN && y.isNaN) {
return true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,37 @@ abstract class Classifier[
* @note Throws `SparkException` if any label is a non-integer or is negative
*/
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
// NOTE(review): this text is a diff view without +/- markers. Each inline `require(...)`
// below is directly followed by a helper call (validateNumClasses / validateLabel) that
// checks the same kind of condition; presumably the inline require is the removed line and
// the helper call is its replacement — confirm against the actual file.
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
validateNumClasses(numClasses)
// Select only the label and features columns and turn every row into a LabeledPoint.
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
// Only rows shaped as (Double label, Vector features) are matched; there is no fallback case.
case Row(label: Double, features: Vector) =>
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
validateLabel(label, numClasses)
LabeledPoint(label, features)
}
}

/**
* Validates that the number of classes is greater than zero.
*
* @param numClasses Number of classes the label can take.
*/
protected def validateNumClasses(numClasses: Int): Unit = {
  val hasClasses = numClasses > 0
  // `require` takes its message by name, so this text is only assembled when the check fails.
  def failureMessage: String =
    s"Classifier (in extractLabeledPoints) found numClasses =" +
      s" $numClasses, but requires numClasses > 0."
  // Throws IllegalArgumentException when numClasses is zero or negative.
  require(hasClasses, failureMessage)
}

/**
* Validates that a label given to the classifier is an integer in the range [0, numClasses).
*
* @param label The label to validate.
* @param numClasses Number of classes the label can take. Labels must be integers in the range
* [0, numClasses).
*/
protected def validateLabel(label: Double, numClasses: Int): Unit = {
  // A label is integral when a round trip through Long leaves it unchanged; this is false
  // for NaN and for any value with a fractional part.
  val isIntegral = label.toLong == label
  val isWithinRange = label >= 0 && label < numClasses
  // `require` takes its message by name, so this text is only assembled when the check fails.
  def failureMessage: String =
    s"Classifier was given" +
      s" dataset with invalid label $label. Labels must be integers in range" +
      s" [0, $numClasses)."
  // Throws IllegalArgumentException for non-integral or out-of-range labels.
  require(isIntegral && isWithinRange, failureMessage)
}

/**
* Get the number of classes. This looks in column metadata first, and if that is missing,
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
Expand Down Expand Up @@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
def setMinInstancesPerNode(value: Int): this.type = {
  // Stores `value` under the minInstancesPerNode param; the result of `set` is returned.
  set(minInstancesPerNode, value)
}

/** @group setParam */
@Since("3.0.0")
def setMinWeightFractionPerNode(value: Double): this.type = {
  // Stores `value` under the minWeightFractionPerNode param; the result of `set` is returned.
  set(minWeightFractionPerNode, value)
}

/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

Expand Down Expand Up @@ -97,29 +103,44 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
def setSeed(value: Long): this.type = {
  // Stores `value` under the seed param; the result of `set` is returned.
  set(seed, value)
}

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = {
  // Stores the column name under the weightCol param; the result of `set` is returned.
  set(weightCol, value)
}

override protected def train(
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
instr.logNumClasses(numClasses)

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
validateNumClasses(numClasses)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
validateLabel(label, numClasses)
Instance(label, weight, features)
}
val strategy = getOldStrategy(categoricalFeatures, numClasses)

instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)

val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeClassificationModel]
Expand All @@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
// Convert every LabeledPoint into an Instance through its no-argument toInstance.
val instances = data.map(_.toInstance)
instr.logPipelineStage(this)
// NOTE(review): this text is a diff view without +/- markers. The two logDataset calls and
// the two `val trees` definitions below look like removed/added pairs (first = old input
// `data`, second = new input `instances`) — confirm against the actual file.
instr.logDataset(data)
instr.logDataset(instances)
instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)

// A single tree is trained over all features; the seed is the literal 0L here, not $(seed).
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, instr = Some(instr), parentUID = Some(uid))
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))

// numTrees = 1, so the first element is the trained tree; it is cast to the model type.
trees.head.asInstanceOf[DecisionTreeClassificationModel]
}
Expand Down Expand Up @@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] (

/**
* Construct a decision tree classification model.
*
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._

import org.apache.spark.sql.functions.{col, udf}

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
Expand Down Expand Up @@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

Expand All @@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)

val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])

val numFeatures = trees.head.numFeatures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
// Renders the point as "(label,features)".
override def toString: String = s"($label,$features)"

// Builds an Instance from this point's label and features with the caller-supplied weight.
private[spark] def toInstance(weight: Double): Instance =
  Instance(label, weight, features)

// Builds an Instance from this point's label and features with a fixed weight of 1.0.
private[spark] def toInstance: Instance =
  Instance(label, 1.0, features)

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
Expand All @@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType


/**
Expand Down Expand Up @@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
def setMinInstancesPerNode(value: Int): this.type = {
  // Stores `value` under the minInstancesPerNode param; the result of `set` is returned.
  set(minInstancesPerNode, value)
}

/** @group setParam */
@Since("3.0.0")
def setMinWeightFractionPerNode(value: Double): this.type = {
  // Stores `value` under the minWeightFractionPerNode param; the result of `set` is returned.
  set(minWeightFractionPerNode, value)
}

/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

Expand Down Expand Up @@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("2.0.0")
def setVarianceCol(value: String): this.type = {
  // Stores the column name under the varianceCol param; the result of `set` is returned.
  set(varianceCol, value)
}

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = {
  // Stores the column name under the weightCol param; the result of `set` is returned.
  set(weightCol, value)
}

override protected def train(
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val strategy = getOldStrategy(categoricalFeatures)

instr.logPipelineStage(this)
instr.logDataset(oldDataset)
instr.logDataset(instances)
instr.logParams(this, params: _*)

val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeRegressionModel]
Expand All @@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
instr.logDataset(data)
instr.logParams(this, params: _*)

val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val instances = data.map(_.toInstance)
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
Expand Down Expand Up @@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
* <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
* Decision tree (Wikipedia)</a> model for regression.
* It supports both continuous and categorical features.
*
* @param rootNode Root of the decision tree
*/
@Since("1.4.0")
Expand All @@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] (

/**
* Construct a decision tree regression model.
*
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
Expand All @@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._

import org.apache.spark.sql.functions.{col, udf}

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
Expand Down Expand Up @@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)

val instances = extractLabeledPoints(dataset).map(_.toInstance)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)

instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logDataset(instances)
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)

val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])

val numFeatures = trees.head.numFeatures
Expand Down
Loading