[SPARK-16957][MLlib] Use midpoints for split values. #17556
Conversation
@@ -1009,10 +1009,24 @@ private[spark] object RandomForest extends Logging {
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray

def weightedMean(pre: (Double, Int), cru: (Double, Int)): Double = {
Nit: cru -> cur? or current?
fixed.
@@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging {
require(metadata.isContinuous(featureIndex),
  "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

- val splits = if (featureSamples.isEmpty) {
+ val splits: Array[Double] = if (featureSamples.isEmpty) {
Was this needed?
The code block is long and has four exits. Making the type explicit perhaps makes it easier to understand, even though the type of splits is implied by the return type.
def weightedMean(pre: (Double, Int), cru: (Double, Int)): Double = {
  val (preValue, preCount) = pre
  val (curValue, curCount) = cru
  (preValue * preCount + curValue * curCount) / (preCount + curCount)
I'm probably over-thinking this, but do we have a possible overflow issue in the denominator? Like if both are near Int.MaxValue. One could be converted with .toDouble just to make sure.
I agree with you. fixed.
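For illustration, a minimal sketch of what the overflow-safe version might look like; the exact placement of the .toDouble conversion in the merged code may differ.

```scala
// Sketch only: weighted mean of two adjacent (value, count) pairs.
// Converting one count to Double first keeps the denominator in floating
// point, so it cannot overflow Int even if both counts are near Int.MaxValue.
def weightedMean(pre: (Double, Int), cur: (Double, Int)): Double = {
  val (preValue, preCount) = pre
  val (curValue, curCount) = cur
  (preValue * preCount + curValue * curCount) / (preCount.toDouble + curCount)
}
```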
} else if (possibleSplits <= numSplits) {
  valueCounts
    .sliding(2)
    .map{x => weightedMean(x(0), x(1))}
Nit: use () instead of {}
There are more efficient ways of writing this but not as compact. I think it's OK unless someone suggests this is performance critical here
fixed.
Do you mean use scanLeft? It's a little complicated and obscure.
No, not scanLeft; just manually building the result array and iterating, because it's already known ahead of time how big it is.
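As a rough illustration of that alternative (valueCounts and weightedMean are the names from the diff above; this is a sketch, not the code that was merged):

```scala
// The number of midpoints is known up front: one per adjacent pair of values.
val midpoints = new Array[Double](valueCounts.length - 1)
var i = 0
while (i < midpoints.length) {
  midpoints(i) = weightedMean(valueCounts(i), valueCounts(i + 1))
  i += 1
}
```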
It seems OK to me, but @sethah or @jkbradley might be good as a second set of eyes. It does slightly alter behavior, but it does seem like something that should work better in general.
Test build #3652 has started for PR 17556 at commit
Is there something wrong with Spark CI?
Does anyone know what this is?
Just a flaky test. Can't be related.
Test build #3654 has finished for PR 17556 at commit
Test build #3655 has finished for PR 17556 at commit
@srowen Hi, I forgot the unit tests in Python and R. Where can I find documentation about setting up a development environment? Thanks.
I have run all of MLlib's unit test cases in Python. However, I am not familiar with R, and I don't want to waste too much time setting up an R environment. Could CI retest the PR? We can check whether any unit tests are still broken. Thanks.
Test build #3662 has finished for PR 17556 at commit
Many thanks, @srowen
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits === Array(1.0, 2.0))
+ assert(splits === Array(1.8, 2.2))
It's clearer IMO to do:
assert(splits === Array((2.0 * 8 + 1 * 2) / (8 + 2), (2.0 * 8 + 3 * 2) / (8 + 2)))
done.
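For reference, the counts implied by the reviewer's formula are 2, 8, and 2 for the values 1, 2, and 3, so the expected splits work out as follows (a restatement of the suggestion above, not necessarily the exact test code):

```scala
// (1*2 + 2*8) / 10 = 1.8 and (2*8 + 3*2) / 10 = 2.2
assert(splits === Array((2.0 * 8 + 1 * 2) / (8 + 2), (2.0 * 8 + 3 * 2) / (8 + 2)))
```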
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits === Array(2.0, 3.0))
+ assert(splits === Array(2.0625, 3.5))
ditto
done.
If we are attempting to match R GBM, it would be great to show, at least on the PR, that we get the same results.
)
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(0.5))
In this block, would you mind adding another test that exercises the possibleSplits > numSplits code path? It actually does get called below, but those tests are for other things and I think it's better to make it explicit what we are testing.
Added a new case.
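A sketch of what such a test might look like; the feature values and the numSplits setting in fakeMetadata are assumptions here, and the case actually added to the suite may differ.

```scala
// Five distinct values give four candidate splits; if fakeMetadata allows
// only numSplits = 3, the quantile-based (possibleSplits > numSplits) path runs.
val featureSamples = Array(0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length <= 3)
assert(splits.forall(s => s > featureSamples.min && s < featureSamples.max))
```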
Seems like a reasonable change. Just left some minor comments.
I don't mind the weighted midpoints. However, if for a continuous feature we find that many points have the exact same value, we are assuming we may find data points in the test set that are close to, but not at, these same values. But since our training data was clustered at these particular values, perhaps that's not a good assumption. I could live with either method, but I have a slight preference to match the other libraries.
@sethah what's the issue there ... train/test ought to be from the same distribution, in theory. The empirical distribution of the test data will of course be a little different, but what is the issue with that w.r.t. this change? From a theoretical perspective, picking the midpoint seems more justified than picking an endpoint, and a weighted mean more so than a midpoint.
Ah OK, I should think about this more first. Say you have a continuous predictor x and binary output y. Say the optimal split is found to be between 0.1 and 0.2, with 1 observation of 0.1 and 99 of 0.2. Right now the algorithm would pick a split value of 0.2; it certainly can't be > 0.2 or < 0.1, but it's highly unlikely that 0.1 or 0.2 are the actual optimal split value. A weighted mean says the best split is at 0.199, really. It makes sense if you're attempting to make sure that P(0.1 <= x < 0.199) ~= P(0.199 <= x <= 0.2) -- about half the cases in this critical range fall above and below the split. But really the goal is to find x such that P(y=1 | x) is about 0.5. It's not the same thing, but it's also not knowable from the training data. But 0.15 isn't obviously better either. It would mean that, probably, almost all test values in this critical range are classified as positive, not about half.
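A tiny worked version of the numbers above (the variable names are illustrative, not from the patch):

```scala
val (preValue, preCount) = (0.1, 1)    // 1 observation at 0.1
val (curValue, curCount) = (0.2, 99)   // 99 observations at 0.2
val simpleMidpoint = (preValue + curValue) / 2.0                       // 0.15
val weightedMidpoint =
  (preValue * preCount + curValue * curCount) / (preCount + curCount)  // ~0.199
```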
Take a (training) sample of a continuous feature, say {x0, x1, x2, x3, ..., x100}. Spark currently selects quantiles as split points. Suppose 10-quantiles are used, with x2 as the 1st quantile and x10 as the 2nd quantile. The intent is that P(x < x2) ~= P(x2 < x < x10). However, x2 is not perfect: since the data is continuous, there exists some point z that actually satisfies P(x < z) == P(z < x < x10). So a weighted midpoint between x2 and x3 seems more appropriate, in my opinion.
By the way, it's safe to use the plain mean value, since that matches the other libraries. If requested, I'd be happy to modify the PR.
The bucketing is not trying to bucket into buckets of equal P(x); it's a condition on P(y | x). That said, the right point isn't knowable from the training data, and splitting to balance P(x) on either side of the split within the bucket is perhaps the next-most principled thing to do. To reach a conclusion, though: if we have slightly more net preference for a simple average, we could merge that change for now and decide later to make it weighted.
OK, the weighting has been removed from the calculation.
// if possible splits is not enough or just enough, just return all possible splits
// perhaps weighted mean is better in the future, see SPARK-16957 and Github PR 17556.
def mean(pre: (Double, Int), cur: (Double, Int)): Double = {
  val (preValue, preCount) = pre
Is it worth factoring out a method for this? You could just write (preValue, _) = here, but just dereferencing ._1 isn't so bad, and then I wonder whether it saves much to make a method.
Yeah, we should get rid of this method.
removed.
The change looks OK even as-is
// if possible splits is not enough or just enough, just return all possible splits
val splits = for {
  i <- 0 until possibleSplits
} yield (valueCounts(i)._1 + valueCounts(i + 1)._1) / 2
@srowen Is it more efficient than sliding?
Good idea. Maybe even just (0 until possibleSplits).map(...).toArray, which is probably about the same thing anyway. You might write / 2.0 to be clear it's floating-point division.
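A sketch of the suggested form (valueCounts and possibleSplits are the names from the diff above; this is an illustration, not the merged code):

```scala
val splits = (0 until possibleSplits)
  .map(i => (valueCounts(i)._1 + valueCounts(i + 1)._1) / 2.0)
  .toArray
```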
val pre = valueCounts(index - 1)
val cur = valueCounts(index)
// perhaps weighted mean will be used later, see SPARK-16957 and Github PR 17556.
splitsBuilder += (pre._1 + cur._1) / 2
Meh, could likewise be one line like above. No big deal.
Nice! revised.
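For reference, the one-liner suggested above might look like this (splitsBuilder, valueCounts, and index are the names from the diff; the merged code may differ slightly):

```scala
splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
```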
How about testing the PR, @SparkQA?
Test build #3682 has finished for PR 17556 at commit
One small nit. Otherwise, LGTM. Thanks @facaiy!
@@ -1037,7 +1042,8 @@ private[spark] object RandomForest extends Logging {
// makes the gap between currentCount and targetCount smaller,
// previous value is a split threshold.
if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
+ // perhaps weighted mean will be used later, see SPARK-16957 and Github PR 17556.
Comments like these tend to just get left around and sit there forever. Unless we file a new JIRA that intends to decide on future behavior, I would like to remove this comment altogether. I'd prefer to just remove it and not create a follow up.
Merged to master. I'm not against putting it into 2.2, but I'm conscious we already even had an RC.
Thanks @srowen!
What changes were proposed in this pull request?
Use midpoints for split values for now; the midpoint may be made weighted later.
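A minimal before/after illustration of the change (the values are made up):

```scala
// Two adjacent distinct values of a continuous feature around the chosen split:
val previousValue = 8.0
val currentValue = 9.0
val oldThreshold = previousValue                          // before: 8.0
val newThreshold = (previousValue + currentValue) / 2.0   // after: 8.5
```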
How was this patch tested?
Existing unit tests in RandomForestSuite were updated for the new split values and a new case was added; the MLlib Python unit tests were also run locally.