
[SPARK-9612][ML] Add instance weight support for GBTs #25926

Closed
wants to merge 9 commits

Conversation

@zhengruifeng (Contributor) commented Sep 25, 2019

What changes were proposed in this pull request?

Add instance weight support to GBTs by sampling the data before passing it to the trees, and then passing the instance weights through to the trees.

In summary:
1. Add setters for minWeightFractionPerNode and weightCol.
2. Update the input type of several private methods from RDD[LabeledPoint] to RDD[Instance]:
DecisionTreeRegressor.train, GradientBoostedTrees.run, GradientBoostedTrees.runWithValidation, GradientBoostedTrees.computeInitialPredictionAndError, GradientBoostedTrees.computeError,
GradientBoostedTrees.evaluateEachIteration, GradientBoostedTrees.boost, GradientBoostedTrees.updatePredictionError.
3. Add a new private method GradientBoostedTrees.computeError(data, predError) to compute the weighted average error, since the original predError.values.mean() does not take weights into account (see the sketch below).
4. Add new tests.
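
A minimal sketch of the weighted-mean error from point 3, assuming the Instance(label, weight, features) case class from org.apache.spark.ml.feature; names and signatures are illustrative, not the exact merged code:

```scala
import org.apache.spark.ml.feature.Instance
import org.apache.spark.rdd.RDD

// predError pairs up with data element-wise as (prediction, per-instance error).
def computeWeightedMeanError(
    data: RDD[Instance],
    predError: RDD[(Double, Double)]): Double = {
  val (errSum, weightSum) = data.zip(predError).map {
    case (Instance(_, weight, _), (_, error)) =>
      (error * weight, weight) // weight each instance's error
  }.treeReduce { case ((err1, w1), (err2, w2)) =>
    (err1 + err2, w1 + w2)
  }
  errSum / weightSum // weighted average, unlike predError.values.mean()
}
```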

Why are the changes needed?

GBTs should support sample weights like other algorithms.

Does this PR introduce any user-facing change?

Yes, new setters are added; see the usage sketch below.
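
For illustration, a minimal usage sketch of the new setters (the DataFrame and column names are hypothetical):

```scala
import org.apache.spark.ml.regression.GBTRegressor

// trainDF is assumed to have "label", "features", and a "weight" column.
val gbt = new GBTRegressor()
  .setWeightCol("weight")            // new in this PR
  .setMinWeightFractionPerNode(0.05) // new in this PR
val model = gbt.fit(trainDF)
```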

How was this patch tested?

Existing and newly added test suites.

@dongjoon-hyun (Member)

Wow, @zhengruifeng. This is a really long-standing JIRA. 👍

@SparkQA commented Sep 25, 2019

Test build #111345 has finished for PR 25926 at commit e1b3aa2.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Sep 26, 2019

Test build #111380 has finished for PR 25926 at commit 9ea6e00.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


MLTestingUtils.testArbitrarilyScaledWeights[GBTRegressionModel,
GBTRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95))
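
For context, this helper asserts weight-scaling invariance: multiplying every instance weight by the same positive constant should leave the fitted model's predictions essentially unchanged. A rough sketch of the property being checked, with hypothetical names:

```scala
import org.apache.spark.ml.regression.GBTRegressor
import org.apache.spark.sql.functions.col

// Train with the original weights and with all weights scaled by a constant;
// modelPredictionEquals then requires 95% of rows to agree within relTol 0.1.
val m1 = new GBTRegressor().setWeightCol("weight").fit(df)
val m2 = new GBTRegressor().setWeightCol("weight")
  .fit(df.withColumn("weight", col("weight") * 1000.0))
```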
Contributor Author

Compared to DecisionTreeRegressorSuite, I had to limit the number of trees and loosen the tolerance eps (0.99 -> 0.95) to make the cases pass.
I wonder if this is due to errors accumulating across the trees.

Contributor

interesting, will need to take a closer look...

@zhengruifeng (Contributor Author)

cc @srowen @imatiach-msft

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
Contributor

nit: update comment to

By default the weightCol is not set, so all instances have weight 1.0.
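
In context, the setter with the suggested wording might look like this (a sketch following the standard Param setter pattern, not necessarily the exact diff):

```scala
/**
 * Sets the value of param [[weightCol]].
 * If this is not set or empty, we treat all instance weights as 1.0.
 * By default the weightCol is not set, so all instances have weight 1.0.
 */
def setWeightCol(value: String): this.type = set(weightCol, value)
```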


val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty

// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
val convert2LabeledPoint = (dataset: Dataset[_]) => {
Contributor

the error message here was much nicer:

GBTClassifier currently only supports binary classification.

than the new one in extractInstances. Perhaps it would be nicer to keep this custom error message, or pass some part of the message to the extractInstances method.
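
One way to keep the message, sketched with a hypothetical validateLabel parameter (extractInstances has no such parameter; this only illustrates the suggestion):

```scala
// Hypothetical: pass the precise message into the extraction step.
val instances = extractInstances(dataset, validateLabel = (label: Double) =>
  require(label == 0 || label == 1,
    s"GBTClassifier was given dataset with invalid label $label. " +
      "GBTClassifier currently only supports binary classification."))
```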

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
Contributor

nit: update comment, same as above

@@ -68,7 +68,7 @@ class GradientBoostedTrees private[spark] (
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
val (trees, treeWeights) = NewGBT.run(input.map { point =>
NewLabeledPoint(point.label, point.features.asML)
NewLabeledPoint(point.label, point.features.asML).toInstance
Contributor

can we create an Instance directly from the label and features? It seems a bit wasteful to create a temporary LabeledPoint object that is immediately converted and discarded.

Contributor Author

Yes, it is better to create the Instance directly.
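
A sketch of the direct construction (Instance takes label, weight, features; the weight here defaults to 1.0):

```scala
import org.apache.spark.ml.feature.Instance

// No temporary LabeledPoint: build the Instance in one step.
val instances = input.map { point =>
  Instance(point.label, 1.0, point.features.asML)
}
```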

@imatiach-msft (Contributor) left a comment

added some initial comments; will need to take a closer look at how the weights are used in the boosting

override protected def train(
dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
Contributor

nit: can this line be moved above where it is used:

val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

Contributor

probably good to move boostingStrategy below as well

val error = loss.computeError(pred, lp.label)
data.map { case Instance(label, _, features) =>
val pred = updatePrediction(features, 0.0, initTree, initTreeWeight)
val error = loss.computeError(pred, label)
Contributor

hmm shouldn't the loss be weighted by the weight column value here? seems a bit strange to ignore the weight column here

Contributor

oh, reading some of the other code this looks like unweighted error. That seems very confusing. I think we could improve this code structure a bit more.

Contributor

what would be the problem with this returning weighted error and getting rid of the computeError function?

val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
val newError = loss.computeError(newPred, lp.label)
data.zip(predictionAndError).map {
case (Instance(label, _, features), (pred, _)) =>
Contributor

same thing here - it seems like we are ignoring the weight column but intuitively it seems like it should be included, could you explain the reasoning behind it being excluded here?

Contributor

oh, reading some of the other code this looks like unweighted error. That seems very confusing. I think we could improve this code structure a bit more.

loss.computeError(predicted, lp.label)
}.mean()
(loss.computeError(predicted, label) * weight, weight)
}.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
Contributor

nit: spacing of { after treeReduce
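
The reformatted call, continuing the quoted snippet (the reduce body is a sketch of the obvious combine step):

```scala
}.treeReduce { case ((err1, weight1), (err2, weight2)) =>
  (err1 + err2, weight1 + weight2)
}
```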

prediction * localTreeWeights(idx)
val numTrees = trees.length

val (errSum, weightSum) = remappedData.mapPartitions { iter =>
Contributor

trying to understand this code - why are the trees broadcast here but the treeWeights are not?

Contributor

just to be clear the previous code is doing this as well, I just don't understand why the treeWeights aren't broadcast either

Contributor Author

Here I just followed the previous implementation. I am neutral on it.
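
For what it's worth, treeWeights is a small Array[Double], so capturing it in the task closure is cheap; the trees can be large, which is presumably why only they are broadcast. Broadcasting both would be a small change (sketch):

```scala
val bcTrees = sc.broadcast(trees)             // potentially large tree models
val bcTreeWeights = sc.broadcast(treeWeights) // small; closure capture is usually fine
```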

Contributor

sounds good

@@ -299,26 +317,25 @@ private[spark] object GradientBoostedTrees extends Logging {
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight

var predError: RDD[(Double, Double)] =
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
predErrorCheckpointer.update(predError)
Contributor

it would be nice if we could checkpoint the weighted instead of unweighted prediction error, which ties into the earlier comment on why methods like computeInitialPredictionAndError can't return the weighted prediction error

@zhengruifeng (Contributor Author)

@imatiach-msft Thanks for reviewing!
On the points about weighted prediction error:
Following the earlier discussions, we should sample the data without weights and pass the weights into the base model (the decision tree).
So the input passed to a decision tree should contain the label (unweighted prediction error) and the instance weights (which are also used in minWeightFractionPerNode). This way, I think we do not need to cache the weighted error.
Moreover, predError.values.mean() over a weighted predError is not equal to the average weighted error computed in this PR.

PS: If I recall correctly, XGBoost passes weighted gradients and hessians into the base learner. It uses a minimum hessian (min_child_weight) to limit tree growth, which is quite different from MLlib.
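
To make the point about mean() concrete, a toy example (values are hypothetical):

```scala
val pairs = Seq((2.0, 0.5), (1.0, 1.0)) // (weight, error)
val meanOfWeighted = pairs.map { case (w, e) => w * e }.sum / pairs.size
// (1.0 + 1.0) / 2 = 1.0
val weightedMean = pairs.map { case (w, e) => w * e }.sum / pairs.map(_._1).sum
// 2.0 / 3.0 ≈ 0.667; the two agree only when the weights average to 1.0
```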

@SparkQA commented Oct 8, 2019

Test build #111878 has finished for PR 25926 at commit f0d890a.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Oct 8, 2019

Test build #111887 has finished for PR 25926 at commit 3c07243.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

* @param predError Prediction and error.
* @return Measure of model error on data
*/
def computeError(
Contributor

maybe change the name to computeWeightedError to make that clear, since the above methods are also computing error but unweighted

@@ -299,26 +317,25 @@ private[spark] object GradientBoostedTrees extends Logging {
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight

var predError: RDD[(Double, Double)] =
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
predErrorCheckpointer.update(predError)
Contributor

if we are going to keep the checkpointing of unweighted error as opposed to weighted error, then it would be nice to specify that in the name of the checkpointer:
predUnweightedErrorCheckpointer
or alternatively add a comment to make that clear:

    // Note: this is checkpointing the unweighted error
    predErrorCheckpointer.update(predError) 

@imatiach-msft (Contributor)

@zhengruifeng I see. It seems a bit confusing to have many references to error, both weighted and unweighted, inside the same function. For example, I would prefer to checkpoint only the weighted error, and nothing in the name of the computeError function suggests that it specifically computes the weighted error. But as long as we have good documentation and variable names in the code to help distinguish which variable is which, I think it should be fine.

@imatiach-msft (Contributor) left a comment

great feature! LGTM!

@SparkQA commented Oct 16, 2019

Test build #112147 has finished for PR 25926 at commit 02457a7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng (Contributor Author)

@imatiach-msft Thanks for reviewing, and for your previous work on sample-weight support in decision trees!

@srowen (Member) commented Oct 16, 2019

@zhengruifeng @imatiach-msft did you have any other changes to make? There may still be some open comments; I'm not sure whether they were all addressed.

@zhengruifeng (Contributor Author)

@srowen I think the only place that may need more discussion is that I had to loosen the tolerance in the test suites (compared with the DecisionTree suites).

@imatiach-msft What do you think?

@zhengruifeng (Contributor Author)

I manually tested this PR in the REPL over the past few days, with some datasets from /data/mllib, setting the relevant params to normal ranges (for example, weights in [1.0, 10.0] rather than the extreme values (0.01, 1000) used in the test suites), and the results looked fine.

I think the error accumulates across the decision trees, which is why I had to loosen the tolerance in the test suites (compared with the DecisionTree suites).

I will merge this PR this week if there are no more comments.

@SparkQA commented Oct 23, 2019

Test build #112539 has finished for PR 25926 at commit 3000397.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng (Contributor Author)

Merged to master, thanks @imatiach-msft @srowen for reviewing!

@zhengruifeng deleted the gbt_add_weight branch October 25, 2019 05:51
srowen pushed a commit that referenced this pull request Dec 9, 2019
### What changes were proposed in this pull request?
add ```setWeightCol``` and ```setMinWeightFractionPerNode``` to the Python side of ```GBTClassifier``` and ```GBTRegressor```

### Why are the changes needed?
#25926 added ```setWeightCol``` and ```setMinWeightFractionPerNode``` to GBTs on the Scala side. This PR adds the same setters on the Python side.

### Does this PR introduce any user-facing change?
Yes

### How was this patch tested?
doc tests

Closes #26774 from huaxingao/spark-30146.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
zhengruifeng added a commit that referenced this pull request Jan 6, 2020
### What changes were proposed in this pull request?
1. Fix `BaggedPoint.convertToBaggedRDD` when `subsamplingRate < 1.0`.
2. Reorganize `RandomForest.runWithMetadata` along the way.

### Why are the changes needed?
In GBT, instance weights are discarded if `subsamplingRate < 1`:

1. `baggedPoint: BaggedPoint[TreePoint]` is used during tree growth to find the best split.
2. `BaggedPoint[TreePoint]` contains two weights:
```scala
class BaggedPoint[Datum](val datum: Datum, val subsampleCounts: Array[Int], val sampleWeight: Double = 1.0)
class TreePoint(val label: Double, val binnedFeatures: Array[Int], val weight: Double)
```
3. Only the `sampleWeight` in `BaggedPoint` is used; the `weight` in `TreePoint` is never used when finding splits.
4. The method `BaggedPoint.convertToBaggedRDD` was changed in #21632; it was only used by decision trees at the time, so only the following code path was changed:
```
if (numSubsamples == 1 && subsamplingRate == 1.0) {
  convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
}
```
5. In #25926 I made GBT support weights, but only tested it with the default `subsamplingRate == 1`.
GBT with `subsamplingRate < 1` converts treePoints to baggedPoints via
```scala
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
```
in which the original weights from `weightCol` are discarded and every `sampleWeight` is assigned the default 1.0; a sketch of the fix follows below.
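
A sketch of the fix (the sampling helper here is hypothetical; the real change carries the extracted weight into each `BaggedPoint` in the sampling path instead of defaulting it to 1.0):

```scala
// Before: the sampling path ignored extractSampleWeight, so sampleWeight was 1.0.
// After (sketch): keep the instance weight alongside the subsample counts.
input.map { datum =>
  val subsampleCounts = drawSubsampleCounts(datum) // hypothetical helper
  new BaggedPoint(datum, subsampleCounts, extractSampleWeight(datum))
}
```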

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
updated test suites

Closes #27070 from zhengruifeng/gbt_sampling.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>