
[SPARK-16957][MLlib] Use midpoints for split values. #17556

Conversation

facaiy
Contributor

@facaiy facaiy commented Apr 7, 2017

What changes were proposed in this pull request?

Use midpoints for split values for now, and maybe make them weighted later.
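
For illustration, a minimal sketch of the behavioural change (not the actual Spark code; the values are made up):

    // Two adjacent distinct values observed for a continuous feature.
    val previousValue = 1.0
    val currentValue = 2.0

    // Before this PR: the previous value itself was used as the split threshold.
    val oldThreshold = previousValue                          // 1.0

    // After this PR: the midpoint between the two values is used instead.
    val newThreshold = (previousValue + currentValue) / 2.0   // 1.5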

How was this patch tested?

  • Add a unit test.
  • Revise Split's unit test.

@@ -1009,10 +1009,24 @@ private[spark] object RandomForest extends Logging {
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray

def weightedMean(pre: (Double, Int), cru: (Double, Int)): Double = {
Member

Nit: cru -> cur? or current?

Contributor Author

Fixed.

@@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")

val splits = if (featureSamples.isEmpty) {
val splits: Array[Double] = if (featureSamples.isEmpty) {
Member

Was this needed?

Contributor Author

The code block is rather long and has 4 exit points. Making the type explicit perhaps makes it easier to understand, even though splits' type is already implied by the function's return type.

def weightedMean(pre: (Double, Int), cru: (Double, Int)): Double = {
val (preValue, preCount) = pre
val (curValue, curCount) = cru
(preValue * preCount + curValue * curCount) / (preCount + curCount)
Member

I'm probably over-thinking this, but do we have a possible overflow issue in the denominator, e.g. if both counts are near Int.MaxValue? One of them could be converted with .toDouble just to make sure.

Contributor Author

I agree with you. Fixed.
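
A sketch of the kind of fix being discussed here, keeping the signature shown in the diff (the .toDouble conversion is an assumption about how it was resolved):

    def weightedMean(pre: (Double, Int), cur: (Double, Int)): Double = {
      val (preValue, preCount) = pre
      val (curValue, curCount) = cur
      // preCount.toDouble promotes the sum of counts to Double arithmetic,
      // so two counts near Int.MaxValue no longer overflow an Int.
      (preValue * preCount + curValue * curCount) / (preCount.toDouble + curCount)
    }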

} else if (possibleSplits <= numSplits) {
valueCounts
.sliding(2)
.map{x => weightedMean(x(0), x(1))}
Member

Nit: use () instead of {}
There are more efficient ways of writing this but not as compact. I think it's OK unless someone suggests this is performance critical here

Contributor Author

Fixed.
Do you mean using scanLeft? It seems a little complicated and obscure.

Member

No, not scanLeft; just manually building the result array and iterating, because it's already known ahead of time how big it is.
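
A sketch of what manually building the result array could look like, assuming valueCounts: Array[(Double, Int)] and the weightedMean helper from the diff above (illustrative only, not the committed code):

    // One split between each pair of adjacent distinct values, so the size is known up front.
    val splits = new Array[Double](valueCounts.length - 1)
    var i = 0
    while (i < splits.length) {
      splits(i) = weightedMean(valueCounts(i), valueCounts(i + 1))
      i += 1
    }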

@srowen
Member

srowen commented Apr 7, 2017

It seems OK to me, but @sethah or @jkbradley might be good as a second set of eyes. It does slightly alter behavior, but it does seem like something that should work better in general.

@SparkQA

SparkQA commented Apr 9, 2017

Test build #3652 has started for PR 17556 at commit 9ca5750.

@facaiy
Contributor Author

facaiy commented Apr 10, 2017

Is there something wrong with the Spark CI?

@facaiy
Contributor Author

facaiy commented Apr 10, 2017

Test Result (1 failure / +1)
    org.apache.spark.storage.TopologyAwareBlockReplicationPolicyBehavior.Peers in 2 racks

Does anyone know what this is?

@srowen
Member

srowen commented Apr 10, 2017

Just a flaky test. It can't be related.

@SparkQA

SparkQA commented Apr 10, 2017

Test build #3654 has finished for PR 17556 at commit 9ca5750.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Apr 10, 2017

Test build #3655 has finished for PR 17556 at commit 9ca5750.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@facaiy
Contributor Author

facaiy commented Apr 11, 2017

@srowen Hi, I forgot the unit tests in Python and R. Where can I find documentation about setting up a development environment? Thanks.

@srowen
Member

srowen commented Apr 11, 2017

@facaiy
Contributor Author

facaiy commented Apr 13, 2017

I have run all the MLlib unit tests in Python. However, I am not familiar with R, and I don't want to waste too much time setting up an R environment.

Could CI retest the PR? Then we can check whether any unit tests are still broken. Thanks.

@SparkQA

SparkQA commented Apr 13, 2017

Test build #3662 has finished for PR 17556 at commit b74702a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen
Member

srowen commented Apr 13, 2017

It's looking good, and the R tests pass. I'll also ask @mengxr or maybe @dbtsai if they have any concerns about this change?

@facaiy
Contributor Author

facaiy commented Apr 13, 2017

Many thanks, @srowen.

val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(1.0, 2.0))
assert(splits === Array(1.8, 2.2))
Contributor

It's clearer IMO to do:

assert(splits === Array((2 * 8 + 1 * 2) / (8 + 2), (2 * 8 + 3 * 2) / (8 + 2)))

Contributor Author

Done.

val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(2.0, 3.0))
assert(splits === Array(2.0625, 3.5))
Contributor

ditto

Contributor Author

Done.

@sethah
Contributor

sethah commented Apr 13, 2017

If we are attempting to match R GBM, it would be great to show, at least on the PR, that we get the same results.

)
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(0.5))
Contributor

In this block, would you mind adding another test that exercises the possibleSplits > numSplits code path? It actually does get called below, but those tests are for other things and I think it's better to make it explicit what we are testing.

Contributor Author

Added a new case.
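
A sketch of the kind of extra case being requested, exercising the possibleSplits > numSplits path (the sample values and assertions are assumptions; fakeMetadata is assumed to be built as in the neighbouring tests, with numSplits smaller than the number of distinct values):

    // 5 distinct values but fewer allowed splits, so the stride-based path is taken
    // instead of simply returning every midpoint.
    val featureSamples = Array(1, 1, 2, 2, 2, 3, 4, 4, 5).map(_.toDouble)
    val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
    assert(splits.length <= 2)                       // assuming numSplits = 2 in fakeMetadata
    assert(splits.forall(s => s > 1.0 && s < 5.0))   // thresholds lie strictly between the extremes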

@sethah
Contributor

sethah commented Apr 13, 2017

Seems like a reasonable change. Just left some minor comments.

@sethah
Contributor

sethah commented Apr 28, 2017

I don't mind the weighted midpoints. However, if for a continuous feature many points have the exact same value, the weighted midpoint assumes we may find test-set points that are close to, but not exactly at, those values. Since our training data was clustered at these particular values, perhaps that's not a good assumption. I could live with either method, but maybe a slight preference to match the other libraries.

@srowen
Member

srowen commented Apr 28, 2017

@sethah what's the issue there ... train/test ought to be from the same distribution, in theory. The empirical distribution of the test data will of course be a little different, but what is the issue with that w.r.t. this change? From a theoretical perspective, picking the midpoint seems more justified than picking an endpoint, and a weighted mean moreso than a midpoint.

@srowen
Member

srowen commented Apr 28, 2017

Ah OK I should think about this more first. Say you have a continuous predictor x and binary output y. Say the optimal split is found to be between 0.1 and 0.2, with 1 observation of 0.1 and 99 of 0.2. Right now the algorithm would pick a split value of 0.2; it certainly can't be > 0.2 or < 0.1 but it's highly unlikely that 0.1 or 0.2 are the actual optimal split value.

A weighted mean says the best split is at 0.199, really. It makes sense if you're attempting to make sure that P(0.1 <= x < 0.199) ~= P(0.199 <= x <= 0.2) -- about half the cases in this critical range fall above and below the split. But really the goal is to find x such that P(y=1 | x) is about 0.5. It's not the same thing but it's also not knowable from the training data.

But 0.15 isn't obviously better either. It would mean that, probably, almost all test values in this critical range are classified as positive, not about half.
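
For concreteness, the arithmetic in that example (just Scala arithmetic, not Spark code):

    // 1 observation at 0.1 and 99 observations at 0.2 in the candidate range.
    val (loValue, loCount) = (0.1, 1)
    val (hiValue, hiCount) = (0.2, 99)

    val weighted = (loValue * loCount + hiValue * hiCount) / (loCount + hiCount)  // ~= 0.199
    val midpoint = (loValue + hiValue) / 2.0                                       // ~= 0.15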

@facaiy
Contributor Author

facaiy commented Apr 29, 2017

For a (train) sample of a continuous feature, say {x0, x1, x2, x3, ..., x100}, Spark currently selects a quantile as the split point.

Suppose 10-quantiles are used, x2 is the 1st quantile, and x10 is the 2nd quantile. It's assumed that P(x < x2) ~= P(x2 < x < x10). However, x2 is not perfect: since the data is continuous, there exists some real point z that satisfies P(x < z) == P(z < x < x10).

So, in my opinion, it's reasonable that the averaged midpoint between x2 and x3 is more appropriate.

@facaiy
Contributor Author

facaiy commented Apr 29, 2017

By the way, it's safe to use the plain mean value since it matches the other libraries. If requested, I'd be happy to modify the PR.

@srowen
Member

srowen commented Apr 29, 2017

The bucketing is not trying to bucket into buckets of equal P(x). It's a condition on P(y | x). That said, the right point isn't knowable from the training data, and splitting to balance P(x) on either side of the split within the bucket is perhaps the next-most principled thing to do.

To reach a conclusion though: if we have slightly more net preference for a simple average, we could merge that change for now and decide later to make it weighted.

@facaiy
Contributor Author

facaiy commented Apr 30, 2017

OK, the weighting has been removed from the calculation.

@facaiy facaiy changed the title [SPARK-16957][MLlib] Use weighted midpoints for split values. [SPARK-16957][MLlib] Use midpoints for split values. Apr 30, 2017
// if possible splits is not enough or just enough, just return all possible splits
// perhaps weighted mean is better in the future, see SPARK-16957 and Github PR 17556.
def mean(pre: (Double, Int), cur: (Double, Int)): Double = {
val (preValue, preCount) = pre
Member

Is it worth factoring out a method for this? You could just write (preValue, _) = here, but just dereferencing ._1 isn't so bad, and then I wonder whether it saves much to make a method.

Contributor

Yeah, we should get rid of this method.

Contributor Author

Removed.

Member

@srowen srowen left a comment

The change looks OK even as-is

// if possible splits is not enough or just enough, just return all possible splits
val splits = for {
i <- 0 until possibleSplits
} yield (valueCounts(i)._1 + valueCounts(i + 1)._1) / 2
Contributor Author

@srowen Is it more efficient than sliding?

Member

Good idea. Maybe even just (0 until possibleSplits).map(...).toArray, which is probably about the same thing anyway. You might write / 2.0 to be clear it's floating-point division.
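
A sketch of that variant, under the same assumptions about possibleSplits and valueCounts as in the diff above (illustrative, not necessarily the committed code):

    // Builds the Array[Double] directly; / 2.0 makes the floating-point division explicit.
    val splits: Array[Double] = (0 until possibleSplits)
      .map(i => (valueCounts(i)._1 + valueCounts(i + 1)._1) / 2.0)
      .toArray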

val pre = valueCounts(index - 1)
val cur = valueCounts(index)
// perhaps weighted mean will be used later, see SPARK-16957 and Github PR 17556.
splitsBuilder += (pre._1 + cur._1) / 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meh, could likewise be one line like above. No big deal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Revised.
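
For reference, the one-line form being suggested would look roughly like this (same names as in the diff; illustrative only):

    // Midpoint between the previous and current distinct values, added in one statement.
    splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2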

@facaiy
Contributor Author

facaiy commented May 2, 2017

How about testing the PR, @SparkQA?

@SparkQA

SparkQA commented May 2, 2017

Test build #3682 has finished for PR 17556 at commit 92df1c8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor

@sethah sethah left a comment

One small nit. Otherwise, LGTM. Thanks @facaiy!

@@ -1037,7 +1042,8 @@ private[spark] object RandomForest extends Logging {
// makes the gap between currentCount and targetCount smaller,
// previous value is a split threshold.
if (previousGap < currentGap) {
splitsBuilder += valueCounts(index - 1)._1
// perhaps weighted mean will be used later, see SPARK-16957 and Github PR 17556.
Contributor

@sethah sethah May 3, 2017

Comments like these tend to just get left around and sit there forever. Unless we file a new JIRA that intends to decide on future behavior, I would like to remove this comment altogether. I'd prefer to just remove it and not create a follow up.

Contributor Author

Removed. Thanks for your help, @sethah @srowen!

@srowen
Member

srowen commented May 3, 2017

Merged to master. I'm not against putting it into 2.2, but I'm conscious that we've already had an RC.

@asfgit asfgit closed this in 7f96f2d May 3, 2017
@facaiy facaiy deleted the ENH/decision_tree_overflow_and_precision_in_aggregation branch May 3, 2017 12:58
@sethah
Contributor

sethah commented May 3, 2017

Thanks @srowen!
