Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9478] [ml] Add class weights to Random Forest #13851

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
override def setSeed(value: Long): this.type = super.setSeed(value)

@Since("2.0.0")
override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)

override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
Expand Down Expand Up @@ -119,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
categoricalFeatures: Map[Int, Int],
numClasses: Int): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
subsamplingRate = 1.0)
subsamplingRate = 1.0, getClassWeights)
}

@Since("1.4.1")
Expand All @@ -129,7 +132,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] {
/** Accessor for supported impurities: entropy, gini */
/** Accessor for supported impurities: entropy, gini, weightedgini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities

Expand Down Expand Up @@ -168,7 +171,7 @@ class DecisionTreeClassificationModel private[ml] (
}

override protected def predictRaw(features: Vector): Vector = {
Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
Vectors.dense(rootNode.predictImpl(features).impurityStats.weightedStats.clone())
}

override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,17 @@ class RandomForestClassifier @Since("1.4.0") (
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)

@Since("2.0.0")
override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)

override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
getOldImpurity, getSubsamplingRate, getClassWeights)

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down Expand Up @@ -195,7 +199,8 @@ class RandomForestClassificationModel private[ml] (
// Ignore the tree weights since all are 1.0 for now.
val votes = Array.fill[Double](numClasses)(0.0)
_trees.view.foreach { tree =>
val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
val classCounts: Array[Double] =
tree.rootNode.predictImpl(features).impurityStats.weightedStats
val total = classCounts.sum
if (total != 0) {
var i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
subsamplingRate = 1.0)
subsamplingRate = 1.0, classWeights = Array())
}

@Since("1.4.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression,
getOldImpurity, getSubsamplingRate, classWeights = Array())

val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.mllib.tree.impurity._



/**
* DecisionTree statistics aggregator for a node.
* This holds a flat array of statistics for a set of (features, bins)
Expand All @@ -38,6 +37,7 @@ private[spark] class DTStatsAggregator(
case Gini => new GiniAggregator(metadata.numClasses)
case Entropy => new EntropyAggregator(metadata.numClasses)
case Variance => new VarianceAggregator()
case WeightedGini => new WeightedGiniAggregator(metadata.numClasses, metadata.classWeights)
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ private[spark] class DecisionTreeMetadata(
val minInstancesPerNode: Int,
val minInfoGain: Double,
val numTrees: Int,
val numFeaturesPerNode: Int) extends Serializable {
val numFeaturesPerNode: Int,
val classWeights: Array[Double]) extends Serializable {

def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)

Expand Down Expand Up @@ -207,7 +208,8 @@ private[spark] object DecisionTreeMetadata extends Logging {
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode,
strategy.classWeights)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,15 @@ private[spark] object RandomForest extends Logging {
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()

val leftWeight = leftCount / totalCount.toDouble
val rightWeight = rightCount / totalCount.toDouble
// With Gini or Entropy impurity the weighted count equals the plain instance count,
// since those impurities treat every class weight as uniform (1.0).
val leftWeightedCount = leftImpurityCalculator.weightedCount
val rightWeightedCount = rightImpurityCalculator.weightedCount

val totalWeightedCount = leftWeightedCount + rightWeightedCount

val leftWeight = leftWeightedCount / totalWeightedCount.toDouble
val rightWeight = rightWeightedCount / totalWeightedCount.toDouble

val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity

Expand Down
28 changes: 24 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,17 @@ private[ml] object DecisionTreeModelReadWrite {
Param.jsonDecode[String](compact(render(impurityJson)))
}

// Get class weights to construct ImpurityCalculator. This value
// is ignored unless the impurity is WeightedGini
val classWeights: Array[Double] = {
val classWeightsJson: JValue = metadata.getParamValue("classWeights")
compact(render(classWeightsJson)).split("\\[|,|\\]")
.filter((s: String) => s.length() != 0).map((s: String) => s.toDouble)
}

val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).as[NodeData]
buildTreeFromNodes(data.collect(), impurityType)
buildTreeFromNodes(data.collect(), impurityType, classWeights)
}

/**
Expand All @@ -353,7 +361,8 @@ private[ml] object DecisionTreeModelReadWrite {
* @param impurityType Impurity type for this tree
* @return Root node of reconstructed tree
*/
def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
classWeights: Array[Double]): Node = {
// Load all nodes, sorted by ID.
val nodes = data.sortBy(_.id)
// Sanity checks; could remove
Expand All @@ -365,7 +374,8 @@ private[ml] object DecisionTreeModelReadWrite {
// traversal, this guarantees that child nodes will be built before parent nodes.
val finalNodes = new Array[Node](nodes.length)
nodes.reverseIterator.foreach { case n: NodeData =>
val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
val impurityStats = ImpurityCalculator.getCalculator(impurityType,
n.impurityStats, classWeights)
val node = if (n.leftChild != -1) {
val leftChild = finalNodes(n.leftChild)
val rightChild = finalNodes(n.rightChild)
Expand Down Expand Up @@ -437,6 +447,15 @@ private[ml] object EnsembleModelReadWrite {
Param.jsonDecode[String](compact(render(impurityJson)))
}

// Get class weights to construct ImpurityCalculator. This value
// is ignored unless the impurity is WeightedGini
val classWeights: Array[Double] = {
val classWeightsJson: JValue = metadata.getParamValue("classWeights")
val classWeightsArray = compact(render(classWeightsJson)).split("\\[|,|\\]")
.filter((s: String) => s.length() != 0).map((s: String) => s.toDouble)
classWeightsArray
}

val treesMetadataPath = new Path(path, "treesMetadata").toString
val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
.select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
Expand All @@ -454,7 +473,8 @@ private[ml] object EnsembleModelReadWrite {
val rootNodesRDD: RDD[(Int, Node)] =
nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
case (treeID: Int, nodeData: Iterable[NodeData]) =>
treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray,
impurityType, classWeights)
}
val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
(metadata, treesMetadata.zip(rootNodes), treesWeights)
Expand Down
87 changes: 78 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance, WeightedGini}
import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

Expand Down Expand Up @@ -155,7 +155,31 @@ private[ml] trait DecisionTreeParams extends PredictorParams
*/
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

/** (private[ml]) Create a Strategy instance to use with the old API. */
/**
 * (private[ml]) Builds an old-API `Strategy` from this estimator's params.
 *
 * @param categoricalFeatures map from categorical feature index to its arity
 * @param numClasses number of label classes (0 for regression)
 * @param oldAlgo Classification or Regression
 * @param oldImpurity impurity measure from the old API
 * @param subsamplingRate fraction of the training data used per tree
 * @param classWeights per-class weights carried through to the strategy
 */
private[ml] def getOldStrategy(
    categoricalFeatures: Map[Int, Int],
    numClasses: Int,
    oldAlgo: OldAlgo.Algo,
    oldImpurity: OldImpurity,
    subsamplingRate: Double,
    classWeights: Array[Double]): OldStrategy = {
  val s = OldStrategy.defaultStrategy(oldAlgo)
  // Problem definition.
  s.numClasses = numClasses
  s.categoricalFeaturesInfo = categoricalFeatures
  s.impurity = oldImpurity
  s.classWeights = classWeights
  // Tree-growth limits.
  s.maxDepth = getMaxDepth
  s.maxBins = getMaxBins
  s.minInstancesPerNode = getMinInstancesPerNode
  s.minInfoGain = getMinInfoGain
  // Execution tuning.
  s.subsamplingRate = subsamplingRate
  s.maxMemoryInMB = getMaxMemoryInMB
  s.useNodeIdCache = getCacheNodeIds
  s.checkpointInterval = getCheckpointInterval
  s
}

/** (private[ml]) Create a Strategy whose interface is compatible with the old API. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
numClasses: Int,
Expand All @@ -174,6 +198,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
strategy.subsamplingRate = subsamplingRate
strategy.classWeights = Array(1.0, 1.0)
strategy
}
}
Expand All @@ -185,7 +210,7 @@ private[ml] trait TreeClassifierParams extends Params {

/**
* Criterion used for information gain calculation (case-insensitive).
* Supported: "entropy" and "gini".
* Supported: "entropy", "gini" and "weightedgini".
* (default = gini)
* @group param
*/
Expand All @@ -194,19 +219,34 @@ private[ml] trait TreeClassifierParams extends Params {
s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
(value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))

setDefault(impurity -> "gini")
/**
 * An array that stores the weights of class labels. All elements must be non-negative.
 * Ignored unless the impurity is "weightedgini".
 * (default = Array(1.0, 1.0))
 * @group param
 */
final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
" that stores the weights of class labels. All elements must be non-negative.")

setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0))

/** @group setParam */
def setImpurity(value: String): this.type = set(impurity, value)

/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase

/** @group setParam */
def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)

/** @group getParam */
final def getClassWeights: Array[Double] = $(classWeights)

/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "entropy" => OldEntropy
case "gini" => OldGini
case "weightedgini" => WeightedGini
case _ =>
// Should never happen because of check in setter method.
throw new RuntimeException(
Expand All @@ -217,7 +257,8 @@ private[ml] trait TreeClassifierParams extends Params {

private[ml] object TreeClassifierParams {
// These options should be lowercase.
final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
final val supportedImpurities: Array[String] = Array("entropy", "gini", "weightedgini")
.map(_.toLowerCase)
}

private[ml] trait DecisionTreeClassifierParams
Expand All @@ -239,14 +280,29 @@ private[ml] trait TreeRegressorParams extends Params {
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
(value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))

setDefault(impurity -> "variance")
/**
 * An array that stores the weights of class labels. This parameter will be ignored in
 * regression trees.
 * (default = Array())
 * @group param
 */
final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
" that stores the weights of class labels. All elements must be non-negative.")

setDefault(impurity -> "variance", classWeights -> Array())

/** @group setParam */
def setImpurity(value: String): this.type = set(impurity, value)

/** @group getParam */
final def getImpurity: String = $(impurity).toLowerCase

/** @group setParam */
def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)

/** @group getParam */
final def getClassWeights: Array[Double] = $(classWeights)

/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
Expand Down Expand Up @@ -312,8 +368,19 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
categoricalFeatures: Map[Int, Int],
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
oldImpurity: OldImpurity,
classWeights: Array[Double]): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
oldImpurity, getSubsamplingRate, classWeights)
}

/**
 * (private[ml]) Create a Strategy instance with uniform class weights, preserving the
 * pre-classWeights signature for existing callers.
 * NOTE(review): hard-codes Array(1.0, 1.0), i.e. two classes; the weights are ignored
 * unless the impurity is WeightedGini — confirm multiclass behavior with that impurity.
 */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
numClasses: Int,
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity): OldStrategy = {
super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
oldImpurity, getSubsamplingRate, Array(1.0, 1.0))
}
}

Expand Down Expand Up @@ -455,7 +522,9 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
private[ml] def getOldBoostingStrategy(
categoricalFeatures: Map[Int, Int],
oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2,
oldAlgo, OldVariance, Array(1.0, 1.0))

// NOTE: The old API does not support "seed" so we ignore it.
new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
}
Expand Down
Loading