
[SPARK-9612][ML][FOLLOWUP] fix GBT support weights if subsamplingRate<1 #27070

Closed
wants to merge 4 commits

Conversation

@zhengruifeng (Contributor) commented Jan 2, 2020

What changes were proposed in this pull request?

1. Fix BaggedPoint.convertToBaggedRDD when subsamplingRate < 1.0.
2. Reorganize RandomForest.runWithMetadata along the way.

Why are the changes needed?

In GBT, instance weights are discarded if subsamplingRate < 1:

1. baggedPoint: BaggedPoint[TreePoint] is used during tree growth to find the best split;
2. BaggedPoint[TreePoint] carries two weights:

class BaggedPoint[Datum](val datum: Datum, val subsampleCounts: Array[Int], val sampleWeight: Double = 1.0)
class TreePoint(val label: Double, val binnedFeatures: Array[Int], val weight: Double)

3. Only the sampleWeight in BaggedPoint is used; the weight in TreePoint is never read when finding splits;
4. BaggedPoint.convertToBaggedRDD was changed in #21632, which targeted only decision trees, so only the following code path was changed:

if (numSubsamples == 1 && subsamplingRate == 1.0) {
        convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
      }

5. In #25926 I made GBT support weights, but only tested it with the default subsamplingRate == 1.
GBT with subsamplingRate < 1 converts treePoints to baggedPoints via

convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)

in which the original weights from weightCol are discarded and every sampleWeight is assigned the default 1.0.
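The bug can be sketched outside Spark as follows. This is a minimal, hypothetical model of the classes quoted above (the real convertToBaggedRDDSamplingWithoutReplacement works over RDD partitions with XORShiftRandom; the names and signature here are simplified assumptions). The point of the fix is the last constructor argument: sampleWeight must come from extractSampleWeight(datum) rather than being left at the hard-coded default 1.0:

```scala
import scala.util.Random

// Hypothetical, simplified stand-ins for the Spark classes quoted above.
case class TreePoint(label: Double, binnedFeatures: Array[Int], weight: Double)
case class BaggedPoint[D](datum: D, subsampleCounts: Array[Int], sampleWeight: Double)

// Sketch of subsampling without replacement: each point is kept in each of
// the numSubsamples bags with probability subsamplingRate. Before the fix,
// the sampleWeight here was hard-coded to 1.0, which discarded weightCol.
def baggedSamplingWithoutReplacement[D](
    input: Seq[D],
    subsamplingRate: Double,
    numSubsamples: Int,
    extractSampleWeight: D => Double,
    seed: Long): Seq[BaggedPoint[D]] = {
  val rng = new Random(seed)
  input.map { datum =>
    val counts = Array.fill(numSubsamples) {
      if (rng.nextDouble() < subsamplingRate) 1 else 0
    }
    // The fix: propagate the original instance weight into the bagged point.
    BaggedPoint(datum, counts, extractSampleWeight(datum))
  }
}
```

With this change a weighted point such as TreePoint(1.0, Array(0), 2.5) always yields a BaggedPoint whose sampleWeight is 2.5, however the subsample counts fall.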

Does this PR introduce any user-facing change?

No

How was this patch tested?

Updated test suites.

init
@zhengruifeng (Contributor, Author) commented Jan 2, 2020

ping @imatiach-msft
Would you mind helping check the logic in BaggedPoint.convertToBaggedRDD? Thanks!

@zhengruifeng added the ML label Jan 2, 2020
@SparkQA commented Jan 2, 2020

Test build #116016 has finished for PR 27070 at commit dac8f69.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Jan 2, 2020

Test build #116025 has finished for PR 27070 at commit a41bcba.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen (Member) left a comment:
Looks reasonable to me

val retaggedInput = input.retag(classOf[Instance])
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
}

/**
* Train a random forest.
Review comment (Contributor):
Update the doc, e.g.:

Train a random forest with metadata.

Also add a description for the metadata param.

@@ -91,8 +91,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
- // should ignore weight function for now
- assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
+ assert(baggedRDD.collect().forall(_.sampleWeight === 2.0))
Review comment (Contributor):

Just trying to understand: why did the sample weight change in this test?

@zhengruifeng (Contributor, Author) replied:
Because this test suite meets the conditions withReplacement = false and numSubsamples != 1,
it calls the modified convertToBaggedRDDSamplingWithoutReplacement,

and the extractSampleWeight here is (_: LabeledPoint) => 2.0, so the output baggedPoints have sampleWeight == 2.0.
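A self-contained illustration of that expectation. LabeledPoint here is a hypothetical stand-in for the MLlib class the suite actually uses; the only real point is that a constant weight extractor must produce sampleWeight == 2.0 on every bagged point once it is propagated:

```scala
// Hypothetical stand-in for org.apache.spark.mllib.regression.LabeledPoint.
case class LabeledPoint(label: Double, features: Array[Double])

// The suite's weight extractor: a constant 2.0 for every point. After the
// fix, convertToBaggedRDDSamplingWithoutReplacement propagates this value,
// so every bagged point carries sampleWeight == 2.0 (not the old default 1.0).
val extractSampleWeight: LabeledPoint => Double = (_: LabeledPoint) => 2.0

val points = Seq(LabeledPoint(1.0, Array(0.0)), LabeledPoint(0.0, Array(1.0)))
val sampleWeights = points.map(extractSampleWeight)
assert(sampleWeights.forall(_ == 2.0))
```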

@imatiach-msft (Contributor) left a comment:

Great work!

@@ -577,7 +592,7 @@ private[spark] object RandomForest extends Logging with Serializable {

// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
- nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
@zhengruifeng (Contributor, Author) replied:
The IDEA editor always shows warnings on these two lines; I changed them to avoid the warnings.
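For reference, the two forms are behaviorally equivalent; a quick check with plain Scala collections standing in for the aggregator array (the stand-in data is an assumption for illustration):

```scala
// Stand-in data for nodeStatsAggregators; any Seq works for this check.
val nodeStats = Seq("agg0", "agg1", "agg2")

// Old form: build a view, pair each element with its index, swap each pair
// to (index, element), then take an iterator over the view.
val oldForm = nodeStats.view.zipWithIndex.map(_.swap).iterator.toList

// New form: iterate directly and swap. Same (index, element) pairs, same
// laziness, no view in the pipeline to trigger the IDE warning.
val newForm = nodeStats.iterator.zipWithIndex.map(_.swap).toList

assert(oldForm == newForm)
```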

@SparkQA commented Jan 3, 2020

Test build #116063 has finished for PR 27070 at commit e3d9200.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Jan 3, 2020

Test build #116072 has finished for PR 27070 at commit 53750d4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng (Contributor, Author) commented:

@imatiach-msft Thanks very much for reviewing!
BTW, since the param weightCol in RF has already been exposed to end users on both the Scala and Python sides, do you have any plan to support weights in RF?
I am afraid the code will be frozen soon.

@zhengruifeng zhengruifeng deleted the gbt_sampling branch January 6, 2020 02:06
@zhengruifeng (Contributor, Author) commented:

Merged to master! Thanks all for reviewing!
@imatiach-msft If you do not have spare time to work on RF weight support, I can have a try.

4 participants