Regression training limit #413
Changes from all commits
e4b8a92
2170254
dff09b9
14c6b42
722341b
34d5bf1
8e2778d
433d483
0932810
2b02f8a
0521a37
962e06f
80a80d5
0ab4d9a
ef4327c
8ca0e78
8e67f27
009706d
cfbe22f
@@ -129,6 +129,24 @@ trait SplitterParams extends Params {
   def setReserveTestFraction(value: Double): this.type = set(reserveTestFraction, value)
   def getReserveTestFraction: Double = $(reserveTestFraction)
 
+  /**
+   * Maximum size of the dataset to train on.
+   * Value should be > 0.
+   * Default is 1000000.
+   *
+   * @group param
+   */
+  final val maxTrainingSample = new IntParam(this, "maxTrainingSample",
+    "maximum size of the dataset to train on", ParamValidators.inRange(
+      lowerBound = 0, upperBound = 1 << 30, lowerInclusive = false, upperInclusive = true
+    )
+  )
+  setDefault(maxTrainingSample, SplitterParamsDefault.MaxTrainingSampleDefault)
+
+  def setMaxTrainingSample(value: Int): this.type = set(maxTrainingSample, value)
+
+  def getMaxTrainingSample: Int = $(maxTrainingSample)
+
   final val labelColumnName = new Param[String](this, "labelColumnName",
     "label column name, column 0 if not specified")
   private[op] def getLabelColumnName = $(labelColumnName)
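Taken together, the new param gives every splitter a settable cap on training-set size. A quick usage sketch (the values are illustrative only; `DataSplitter()` and the setters are the ones added in this diff and exercised in the tests below):

```scala
// Illustrative sketch: cap training at 100k rows, reserve 10% for test.
val splitter = DataSplitter()
  .setSeed(1234L)
  .setReserveTestFraction(0.1)
  .setMaxTrainingSample(100000)

assert(splitter.getMaxTrainingSample == 100000)
```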
@@ -143,6 +161,7 @@ object SplitterParamsDefault {
   val MaxTrainingSampleDefault = 1E6.toInt
   val MaxLabelCategoriesDefault = 100
   val MinLabelFractionDefault = 0.0
+  val DownSampleFractionDefault = 1.0
 }
 
 trait SplitterSummary extends MetadataLike
@@ -152,7 +171,10 @@ private[op] object SplitterSummary { | |
|
||
def fromMetadata(metadata: Metadata): Try[SplitterSummary] = Try { | ||
metadata.getString(ClassName) match { | ||
case s if s == classOf[DataSplitterSummary].getName => DataSplitterSummary() | ||
case s if s == classOf[DataSplitterSummary].getName => DataSplitterSummary( | ||
preSplitterDataCount = metadata.getLong(ModelSelectorNames.PreSplitterDataCount), | ||
downSamplingFraction = metadata.getDouble(ModelSelectorNames.DownSample) | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add the downsample fraction to the datacutter params as well... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added downsample fraction into the datacutter params as part of the multi class classification training limit changes. I'll create the PR for it today. |
||
case s if s == classOf[DataBalancerSummary].getName => DataBalancerSummary( | ||
positiveLabels = metadata.getLong(ModelSelectorNames.Positive), | ||
negativeLabels = metadata.getLong(ModelSelectorNames.Negative), | ||
|
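For context, the read side above implies a matching write side. Below is a hedged sketch of what that serialization plausibly looks like, not the PR's actual code: it assumes `DataSplitterSummary` is a case class with the two constructor fields named in `fromMetadata`, and it reuses the repo's `SplitterSummary.ClassName` and `ModelSelectorNames` keys with Spark's `MetadataBuilder`:

```scala
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

// Sketch only: the real DataSplitterSummary also extends SplitterSummary;
// the key names mirror the fromMetadata code in this diff.
case class DataSplitterSummary(preSplitterDataCount: Long, downSamplingFraction: Double) {
  def toMetadata(): Metadata = new MetadataBuilder()
    .putString(SplitterSummary.ClassName, this.getClass.getName)
    .putLong(ModelSelectorNames.PreSplitterDataCount, preSplitterDataCount)
    .putDouble(ModelSelectorNames.DownSample, downSamplingFraction)
    .build()
}
```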
@@ -43,6 +43,7 @@ class DataSplitterTest extends FlatSpec with TestSparkContext with SplitterSumma
 
   val seed = 1234L
   val dataCount = 1000
+  val trainingLimitDefault = 1E6.toLong
 
   val data =
     RandomRDDs.normalVectorRDD(sc, 1000, 3, seed = seed)
@@ -56,6 +57,37 @@ class DataSplitterTest extends FlatSpec with TestSparkContext with SplitterSumma
     train.count() shouldBe dataCount
   }
 
+  it should "down-sample when the data count is above the default training limit" in {
+    val numRows = trainingLimitDefault * 2
+    val data =
+      RandomRDDs.normalVectorRDD(sc, numRows, 3, seed = seed)
+        .map(v => (1.0, Vectors.dense(v.toArray), "A")).toDF()
+    dataSplitter.preValidationPrepare(data)
+
+    val dataBalanced = dataSplitter.validationPrepare(data)
+    // validationPrepare calls the data sampling method, which only approximates
+    // the target ratio, so the resulting count is checked to within that epsilon
+    val samplingErrorEpsilon = (0.1 * trainingLimitDefault).toLong
+
+    dataBalanced.count() shouldBe trainingLimitDefault +- samplingErrorEpsilon
+  }
+
+  it should "set and get all data splitter params" in {
+    val maxRows = dataCount / 2
+    val downSampleFraction = maxRows / dataCount.toDouble
+
+    val dataSplitter = DataSplitter()
+      .setReserveTestFraction(0.0)
+      .setSeed(seed)
+      .setMaxTrainingSample(maxRows)
+      .setDownSampleFraction(downSampleFraction)
+
+    dataSplitter.getReserveTestFraction shouldBe 0.0
+    dataSplitter.getDownSampleFraction shouldBe downSampleFraction
+    dataSplitter.getSeed shouldBe seed
+    dataSplitter.getMaxTrainingSample shouldBe maxRows
+  }
+
   it should "split the data in the appropriate proportion - 0.2" in {
     val (train, test) = dataSplitter.setReserveTestFraction(0.2).split(data)
     math.abs(test.count() - 200) < 30 shouldBe true

Review comment: It's probably worth checking

Reply: Made changes here: 433d483.
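The epsilon in the down-sampling test exists because fraction-based sampling in Spark is probabilistic: each row is retained independently with the given probability, so the surviving count only approximates fraction × total. A standalone illustration with plain Spark (not the PR's code; names and values are made up):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("sampling-demo").getOrCreate()

val df = spark.range(1000000L).toDF("id")
// Each row is kept with probability 0.5, so the count lands near
// 500000 but is rarely exact -- hence the +- epsilon assertion above.
val sampled = df.sample(withReplacement = false, fraction = 0.5, seed = 1234L)
println(sampled.count())
```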
@@ -69,10 +101,13 @@ class DataSplitterTest extends FlatSpec with TestSparkContext with SplitterSumma
   }
 
   it should "keep the data unchanged when prepare is called" in {
+    val dataCount = data.count()
     val summary = dataSplitter.preValidationPrepare(data)
     val train = dataSplitter.validationPrepare(data)
+    val sampleF = trainingLimitDefault / dataCount.toDouble
+    val downSampleFraction = math.min(sampleF, 1.0)
     train.collect().zip(data.collect()).foreach { case (a, b) => a shouldBe b }
-    assertDataSplitterSummary(summary.summaryOpt) { s => s shouldBe DataSplitterSummary() }
+    assertDataSplitterSummary(summary.summaryOpt) { s => s shouldBe DataSplitterSummary(dataCount, downSampleFraction) }
   }
 
 }
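The expected summary here follows from arithmetic on the test's constants: the fixture has 1,000 rows and the default limit is 1,000,000, so the computed fraction clamps to 1.0 and no rows are dropped:

```scala
val dataCount = 1000L
val trainingLimitDefault = 1E6.toLong
// min(1000000 / 1000.0, 1.0) == 1.0 -> validationPrepare is a no-op sample
val downSampleFraction = math.min(trainingLimitDefault / dataCount.toDouble, 1.0)
```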
Review comment: Is this also exposed for the DataCutter class?

Reply: Yes, I added it to SplitterParams, which DataCutter has access to: e4b8a92. That way I can use the same set/get functions across DataBalancer, DataCutter, and DataSplitter.
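The reply describes the standard shared-params pattern: declare the param and its accessors once in a trait and mix it into each splitter. A simplified, self-contained sketch of the idea (plain Scala, not the actual TransmogrifAI classes, which extend Spark ML's `Params`):

```scala
// Simplified illustration of sharing a param via a trait mixin; the real
// SplitterParams uses IntParam, set(...) and $(...) as shown in the diff.
trait MaxTrainingSampleParam {
  private var maxTrainingSample: Int = 1000000

  def setMaxTrainingSample(value: Int): this.type = {
    require(value > 0, "maxTrainingSample must be > 0")
    maxTrainingSample = value
    this
  }

  def getMaxTrainingSample: Int = maxTrainingSample
}

// All three splitters pick up identical set/get behavior from the trait.
class DataSplitter extends MaxTrainingSampleParam
class DataBalancer extends MaxTrainingSampleParam
class DataCutter extends MaxTrainingSampleParam
```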