[SPARK-9612][ML] Add instance weight support for GBTs #25926
Conversation
Wow, @zhengruifeng. This is really a long-standing JIRA. 👍
Test build #111345 has finished for PR 25926 at commit
Test build #111380 has finished for PR 25926 at commit
```scala
MLTestingUtils.testArbitrarilyScaledWeights[GBTRegressionModel,
  GBTRegressor](df.as[LabeledPoint], estimator,
  MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.95))
```
Compared to DecisionTreeRegressorSuite, I need to limit the number of trees and loosen the tolerance eps (0.99 -> 0.95) to pass the cases.
I wonder if it is due to accumulated errors among the trees.
interesting, will need to take a closer look...
```scala
/**
 * Sets the value of param [[weightCol]].
 * If this is not set or empty, we treat all instance weights as 1.0.
 * Default is not set, so all instances have weight one.
```
nit: update comment to:
> By default the weightCol is not set, so all instances have weight 1.0.
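For reference, a minimal sketch of the setter with the suggested wording applied (shown only to illustrate the doc change; the body follows the standard Spark param-setter pattern):

```scala
/**
 * Sets the value of param [[weightCol]].
 * If this is not set or empty, we treat all instance weights as 1.0.
 * By default the weightCol is not set, so all instances have weight 1.0.
 */
def setWeightCol(value: String): this.type = set(weightCol, value)
```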
```scala
val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
```

```scala
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
val convert2LabeledPoint = (dataset: Dataset[_]) => {
```
The error message here was much nicer:
> GBTClassifier currently only supports binary classification.

than the new one in extractInstances. Perhaps it would be nicer to keep this custom error message, or pass some part of the message to the extractInstances method.
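One possible shape, sketched under the assumption that extractInstances could accept a per-instance validation hook (the hook and its signature here are illustrative, not confirmed API):

```scala
// Hypothetical: pass the binary-only wording into extraction so the
// precise GBTClassifier message is preserved.
val instances = extractInstances(dataset, instance =>
  require(instance.label == 0.0 || instance.label == 1.0,
    s"GBTClassifier was given a dataset with an invalid label ${instance.label}." +
      " GBTClassifier currently only supports binary classification."))
```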
```scala
/**
 * Sets the value of param [[weightCol]].
 * If this is not set or empty, we treat all instance weights as 1.0.
 * Default is not set, so all instances have weight one.
```
nit: update comment, same as above
```diff
@@ -68,7 +68,7 @@ class GradientBoostedTrees private[spark] (
   def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
     val algo = boostingStrategy.treeStrategy.algo
     val (trees, treeWeights) = NewGBT.run(input.map { point =>
-      NewLabeledPoint(point.label, point.features.asML)
+      NewLabeledPoint(point.label, point.features.asML).toInstance
```
Can we create an Instance directly from label and features? It seems a bit much to create a temporary LabeledPoint object that's unused.
Yes, it is better to create the Instance directly.
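For example (a minimal sketch; weight is fixed at 1.0, since the old RDD API carries no per-row weights):

```scala
import org.apache.spark.ml.feature.Instance

// Map the old-API points straight to ml Instances, avoiding the
// intermediate NewLabeledPoint allocation.
val instances = input.map { point =>
  Instance(point.label, 1.0, point.features.asML)
}
```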
Added some initial comments; will need to take a closer look at how the weights are used in the boosting.
```diff
   override protected def train(
       dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
-    val categoricalFeatures: Map[Int, Int] =
-      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+    val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
```
nit: can this line be moved to just above where it is used:

```scala
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
```
probably good to move boostingStrategy below as well
```diff
-      val error = loss.computeError(pred, lp.label)
+    data.map { case Instance(label, _, features) =>
+      val pred = updatePrediction(features, 0.0, initTree, initTreeWeight)
+      val error = loss.computeError(pred, label)
```
Hmm, shouldn't the loss be weighted by the weight column value here? It seems a bit strange to ignore the weight column here.
oh, reading some of the other code this looks like unweighted error. That seems very confusing. I think we could improve this code structure a bit more.
what would be the problem with this returning weighted error and getting rid of the computeError function?
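A sketch of what that might look like, assuming the per-point pair becomes (prediction, weightedError):

```scala
// Sketch: fold the weight in when the error is first computed, so the
// cached predError RDD already holds weighted error.
data.map { case Instance(label, weight, features) =>
  val pred = updatePrediction(features, 0.0, initTree, initTreeWeight)
  val error = loss.computeError(pred, label) * weight
  (pred, error)
}
```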
```diff
-        val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
-        val newError = loss.computeError(newPred, lp.label)
+    data.zip(predictionAndError).map {
+      case (Instance(label, _, features), (pred, _)) =>
```
Same thing here: it seems like we are ignoring the weight column, but intuitively it should be included. Could you explain the reasoning behind excluding it here?
oh, reading some of the other code this looks like unweighted error. That seems very confusing. I think we could improve this code structure a bit more.
```diff
-      loss.computeError(predicted, lp.label)
-    }.mean()
+      (loss.computeError(predicted, label) * weight, weight)
+    }.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
```
nit: spacing of { after treeReduce
```scala
prediction * localTreeWeights(idx)
val numTrees = trees.length

val (errSum, weightSum) = remappedData.mapPartitions { iter =>
```
trying to understand this code - why are the trees broadcast here but the treeWeights are not?
Just to be clear, the previous code is doing this as well; I just don't understand why the treeWeights aren't broadcast either.
In this place, I just followed the previous impl. I am neutral on it.
sounds good
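For what it's worth, a sketch of broadcasting both (variable names follow the excerpt above; note that treeWeights is just an Array[Double] of length numTrees, so capturing it in the task closure is already cheap, which may be why only the trees are broadcast):

```scala
// Sketch: broadcast the tree weights alongside the trees; executors then
// read bcTreeWeights.value instead of capturing the array in the closure.
val sc = remappedData.sparkContext
val bcTrees = sc.broadcast(trees)
val bcTreeWeights = sc.broadcast(treeWeights)
```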
```diff
@@ -299,26 +317,25 @@ private[spark] object GradientBoostedTrees extends Logging {
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight

-    var predError: RDD[(Double, Double)] =
-      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
     predErrorCheckpointer.update(predError)
```
it would be nice if we could checkpoint the weighted instead of unweighted prediction error, which ties into the earlier comment on why methods like computeInitialPredictionAndError can't return the weighted prediction error
@imatiach-msft Thanks for reviewing! PS: If I recall correctly, XGBoost passes weighted gradients and hessians into the base learner. It uses minimum hessian (
force-pushed from 9ea6e00 to 038ff58
Test build #111878 has finished for PR 25926 at commit
Test build #111887 has finished for PR 25926 at commit
```scala
 * @param predError Prediction and error.
 * @return Measure of model error on data
 */
def computeError(
```
maybe change the name to computeWeightedError to make that clear, since the above methods are also computing error but unweighted
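A sketch of the renamed method with the weighting made explicit (the body is reconstructed from the weighted-sum pattern used elsewhere in this PR, not copied from the patch):

```scala
import org.apache.spark.ml.feature.Instance
import org.apache.spark.rdd.RDD

/**
 * Weighted mean of the per-point errors.
 * @param data Training points with weights.
 * @param predError Per-point (prediction, unweighted error) pairs.
 */
def computeWeightedError(
    data: RDD[Instance],
    predError: RDD[(Double, Double)]): Double = {
  val (errSum, weightSum) = data.zip(predError).map {
    case (Instance(_, weight, _), (_, err)) => (err * weight, weight)
  }.treeReduce { case ((e1, w1), (e2, w2)) => (e1 + e2, w1 + w2) }
  errSum / weightSum
}
```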
```diff
@@ -299,26 +317,25 @@ private[spark] object GradientBoostedTrees extends Logging {
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight

-    var predError: RDD[(Double, Double)] =
-      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
     predErrorCheckpointer.update(predError)
```
if we are going to keep the checkpointing of unweighted error as opposed to weighted error, then it would be nice to specify that in the name of the checkpointer:

```scala
predUnweightedErrorCheckpointer
```

or alternatively add a comment to make that clear:

```scala
// Note: this is checkpointing the unweighted error
predErrorCheckpointer.update(predError)
```
@zhengruifeng I see. It seems a bit confusing to have a lot of references to error as both weighted and unweighted inside the same function; for example, I would prefer to only checkpoint the weighted error, and the computeError function's name doesn't suggest anything about it specifically making the error weighted. But as long as we have good documentation and variable names in the code to help distinguish which variable is for what, I think it should be fine.
great feature! LGTM!
Test build #112147 has finished for PR 25926 at commit
@imatiach-msft Thanks for reviewing, and for your previous work on decision tree sample-weight support!
@zhengruifeng @imatiach-msft Did you have any other changes to make? There may still be some open comments; not sure if they were addressed.
@srowen I think the only place that may need more discussion is that I need to loosen the tolerance in the test suites (compared with DecisionTreeSuites). @imatiach-msft What do you think about it?
force-pushed from 02457a7 to 3000397
I manually tested this PR in the repl in the past days, with some datasets. I think the error is accumulated among the trees. I will merge this PR this week if there are no more comments.
Test build #112539 has finished for PR 25926 at commit
Merged to master, thanks @imatiach-msft @srowen for reviewing!
### What changes were proposed in this pull request?
Add `setWeightCol` and `setMinWeightFractionPerNode` on the Python side of `GBTClassifier` and `GBTRegressor`.

### Why are the changes needed?
#25926 added `setWeightCol` and `setMinWeightFractionPerNode` in GBTs on the Scala side. This PR adds them in GBTs on the Python side.

### Does this PR introduce any user-facing change?
Yes

### How was this patch tested?
doc test

Closes #26774 from huaxingao/spark-30146.
Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
### What changes were proposed in this pull request?
1. fix `BaggedPoint.convertToBaggedRDD` when `subsamplingRate < 1.0`
2. reorg `RandomForest.runWithMetadata` btw

### Why are the changes needed?
In GBT, instance weights will be discarded if `subsamplingRate < 1`:
1. `baggedPoint: BaggedPoint[TreePoint]` is used in the tree growth to find the best split;
2. `BaggedPoint[TreePoint]` contains two weights:
```scala
class BaggedPoint[Datum](val datum: Datum, val subsampleCounts: Array[Int], val sampleWeight: Double = 1.0)
class TreePoint(val label: Double, val binnedFeatures: Array[Int], val weight: Double)
```
3. only the var `sampleWeight` in `BaggedPoint` is used; the var `weight` in `TreePoint` is never used in finding splits;
4. the method `BaggedPoint.convertToBaggedRDD` was changed in #21632. It was only for decision trees, so only the following code path was changed:
```scala
if (numSubsamples == 1 && subsamplingRate == 1.0) {
  convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
}
```
5. in #25926, I made GBT support weights, but only tested it with the default `subsamplingRate == 1`. GBT with `subsamplingRate < 1` will convert treePoints to baggedPoints via
```scala
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
```
in which the original weights from `weightCol` will be discarded and all `sampleWeight` are assigned the default 1.0.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
updated test suites

Closes #27070 from zhengruifeng/gbt_sampling.
Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
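A hedged sketch of the shape of such a fix, assuming the `BaggedPoint`/`TreePoint` definitions quoted above (illustrative only, not the exact patch):

```scala
// Illustrative: carry each TreePoint's weight into the BaggedPoint's
// sampleWeight on the subsampling path too, rather than leaving it at
// the default of 1.0.
val bagged = convertToBaggedRDDSamplingWithoutReplacement(
    input, subsamplingRate, numSubsamples, seed)
  .map(bp => new BaggedPoint(bp.datum, bp.subsampleCounts, bp.datum.weight))
```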
What changes were proposed in this pull request?
add weight support for GBTs by sampling data before passing it to trees and then passing weights to trees
in summary:
1. add setters of `minWeightFractionPerNode` & `weightCol`
2. update input types in private methods from `RDD[LabeledPoint]` to `RDD[Instance]`: `DecisionTreeRegressor.train`, `GradientBoostedTrees.run`, `GradientBoostedTrees.runWithValidation`, `GradientBoostedTrees.computeInitialPredictionAndError`, `GradientBoostedTrees.computeError`, `GradientBoostedTrees.evaluateEachIteration`, `GradientBoostedTrees.boost`, `GradientBoostedTrees.updatePredictionError`
3. add a new private method `GradientBoostedTrees.computeError(data, predError)` to compute the weighted average error, since the original `predError.values.mean()` does not take weights into account
4. add new tests
Why are the changes needed?
GBTs should support sample weights like other algorithms
Does this PR introduce any user-facing change?
yes, new setters are added (see the usage sketch at the end of this description)
How was this patch tested?
existing & added test suites
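A minimal end-to-end usage sketch of the new setters (the dataset and column names are illustrative assumptions):

```scala
import org.apache.spark.ml.regression.GBTRegressor

// trainingDF is assumed to have "label", "features", and "weight" columns.
val gbt = new GBTRegressor()
  .setWeightCol("weight")               // new in this PR
  .setMinWeightFractionPerNode(0.05)    // new in this PR
  .setMaxIter(20)
val model = gbt.fit(trainingDF)
```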