Refactored splitter tests (#176)

tovbinm authored Nov 7, 2018
1 parent 1883b5e commit 71e20a3
Showing 12 changed files with 326 additions and 278 deletions.
@@ -146,7 +146,7 @@ E <: Estimator[_] with OpPipelineStage2[RealNN, OPVector, Prediction]]
}
require(!datasetWithID.isEmpty, "Dataset cannot be empty")

- val ModelData(trainData, met) = splitter match {
+ val ModelData(trainData, splitterSummary) = splitter match {
case Some(spltr) => spltr.prepare(datasetWithID)
case None => ModelData(datasetWithID, None)
}
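For orientation, the shapes this pattern match relies on can be sketched as follows. These are simplified, hypothetical stand-ins for illustration, not the repository's exact definitions:

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.Metadata

// Hypothetical, simplified stand-ins for the real com.salesforce.op types:
trait SplitterSummary { def toMetadata(): Metadata }
case class ModelData(train: Dataset[Row], summary: Option[SplitterSummary])
abstract class Splitter { def prepare(data: Dataset[Row]): ModelData }

With these shapes, the match reads: if a splitter is configured, let it prepare (and possibly resample) the data and report a summary; otherwise pass the data through with no summary.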
@@ -174,7 +174,7 @@ E <: Estimator[_] with OpPipelineStage2[RealNN, OPVector, Prediction]]
validationType = ValidationType.fromValidator(validator),
validationParameters = validator.getParams(),
dataPrepParameters = splitter.map(_.extractParamMap().getAsMap()).getOrElse(Map()),
- dataPrepResults = met,
+ dataPrepResults = splitterSummary,
evaluationMetric = validator.evaluator.name,
problemType = ProblemType.fromEvalMetrics(trainingEval),
bestModelUID = estimator.uid,
@@ -33,7 +33,7 @@ package com.salesforce.op.stages.impl.tuning
import com.salesforce.op.UID
import com.salesforce.op.stages.impl.selector.ModelSelectorNames
import org.apache.spark.ml.param._
- import org.apache.spark.sql.{Dataset, Row}
+ import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
import org.slf4j.LoggerFactory

@@ -123,45 +123,66 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)
* @return balanced training set and a test set
*/
def prepare(data: Dataset[Row]): ModelData = {
- val ds = data.persist()
-
- val Array(negativeData, positiveData) = Array(0.0, 1.0).map(label => ds.filter(_.getDouble(0) == label).persist())
- val balancerSeed = getSeed
-
- prepareData(
-   data = ds,
-   positiveData = positiveData,
-   negativeData = negativeData,
-   seed = balancerSeed
- )
+ val negativeData = data.filter(_.getDouble(0) == 0.0).persist()
+ val positiveData = data.filter(_.getDouble(0) == 1.0).persist()
+ val negativeCount = negativeData.count()
+ val positiveCount = positiveData.count()
+ val seed = getSeed
+
+ if (!(isSet(isPositiveSmall) ||
+   isSet(downSampleFraction) ||
+   isSet(upSampleFraction) ||
+   isSet(alreadyBalancedFraction))
+ ) {
+   estimate(positiveCount = positiveCount, negativeCount = negativeCount, seed = seed)
+ }
+
+ // If these conditions are met, that means that we have enough information to balance the data : upSample,
+ // downSample and which class is in minority
+ val balanced: DataFrame = {
+   if (isSet(isPositiveSmall) && isSet(downSampleFraction) && isSet(upSampleFraction)) {
+     val (down, up) = ($(downSampleFraction), $(upSampleFraction))
+     log.info(s"Sample fractions: downSample of $down, upSample of $up")
+     val (smallData, bigData) =
+       if ($(isPositiveSmall)) (positiveData, negativeData) else (negativeData, positiveData)
+     rebalance(
+       smallData = smallData,
+       upSampleFraction = up,
+       bigData = bigData,
+       downSampleFraction = down,
+       seed = seed
+     )
+   } else {
+     // Data is already balanced, but need to be sampled
+     val fraction = $(alreadyBalancedFraction)
+     log.info(s"Data is already balanced, yet it will be sampled by a fraction of $fraction")
+     sampleBalancedData(
+       fraction = fraction,
+       seed = seed,
+       data = data,
+       positiveData = positiveData,
+       negativeData = negativeData
+     )
+   }
+ }
+
+ ModelData(balanced.persist(), summary)
}
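A note on the arithmetic behind the estimate call above: the fractions are chosen so that the minority class ends up at roughly the requested sample fraction of the training set. A rough sketch of one way to derive the upsampling factor, assuming the majority class is left untouched (an illustration, not the library's exact formula):

// Solve smallCount * up / (smallCount * up + bigCount) == sampleF for `up`,
// holding the majority class fixed (hypothetical simplification).
def upSampleFactor(smallCount: Long, bigCount: Long, sampleF: Double): Double =
  (sampleF / (1.0 - sampleF)) * (bigCount.toDouble / smallCount.toDouble)

// Example: 1,000 positives vs 99,000 negatives with sampleF = 0.1
// gives (0.1 / 0.9) * 99 = 11.0, i.e. positives are sampled at 11x.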

override def copy(extra: ParamMap): DataBalancer = {
val copy = new DataBalancer(uid)
copyValues(copy, extra)
}



/**
* Estimate if data needs to be balanced or not. If so, computes sample fractions and sets the appropriate params
*
- * @param data input data
- * @param positiveData data with positives only
- * @param negativeData data with negatives only
+ * @param positiveCount number of positives
+ * @param negativeCount number of negatives
* @param seed seed
* @return balanced data
*/
- private[op] def estimate[T](
-   data: Dataset[T],
-   positiveData: Dataset[T],
-   negativeData: Dataset[T],
-   seed: Long
- ): Unit = {
-   val positiveCount = positiveData.count()
-   val negativeCount = negativeData.count()
+ private[op] def estimate[T](positiveCount: Long, negativeCount: Long, seed: Long): Unit = {
val totalCount = positiveCount + negativeCount
val sampleF = getSampleFraction
log.info(s"Data has $positiveCount positive and $negativeCount negative.")
@@ -180,8 +201,7 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)

if (smallCount.toDouble / totalCount.toDouble >= sampleF) {
log.info(
-   s"Not resampling data: $smallCount small count and $bigCount big count is greater than" +
-     s" requested $sampleF"
+   s"Not resampling data: $smallCount small count and $bigCount big count is greater than requested $sampleF"
)
// if data is too big downsample
val fraction = if (maxTrainSample < totalCount) maxTrainSample / totalCount.toDouble else 1.0
@@ -202,8 +222,7 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)
desiredFraction = sampleF, upSamplingFraction = upSample, downSamplingFraction = downSample))

val (posFraction, negFraction) =
-   if (positiveCount < negativeCount) (upSample, downSample)
-   else (downSample, upSample)
+   if (positiveCount < negativeCount) (upSample, downSample) else (downSample, upSample)

val newPositiveCount = math.rint(positiveCount * posFraction)
val newNegativeCount = math.rint(negativeCount * negFraction)
@@ -223,43 +242,6 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)

}
}
- /**
-  * Preparing data
-  *
-  * @param data input data
-  * @param positiveData data with positives only
-  * @param negativeData data with negatives only
-  * @param seed seed
-  * @return balanced data
-  */
- private[op] def prepareData[T](
-   data: Dataset[T],
-   positiveData: Dataset[T],
-   negativeData: Dataset[T],
-   seed: Long
- ): ModelData = {
-
-   if (!(isSet(isPositiveSmall) || isSet(downSampleFraction) ||
-     isSet(upSampleFraction) || isSet(alreadyBalancedFraction))) {
-     estimate(data = data, positiveData = positiveData, negativeData = negativeData, seed = seed)
-   }
-
-   // If these conditions are met, that means that we have enough information to balance the data : upSample,
-   // downSample and which class is in minority
-   if (isSet(isPositiveSmall) && isSet(downSampleFraction) && isSet(upSampleFraction)) {
-     val (down, up) = ($(downSampleFraction), $(upSampleFraction))
-     log.info(s"Sample fractions: downSample of $down, upSample of $up")
-     val (smallData, bigData) = if ($(isPositiveSmall)) (positiveData, negativeData) else (negativeData, positiveData)
-     new ModelData(rebalance(smallData, up, bigData, down, seed).toDF().persist(), summary)
-   } else { // Data is already balanced, but need to be sampled
-     val fraction = $(alreadyBalancedFraction)
-     log.info(s"Data is already balanced, yet it will be sampled by a fraction of $fraction")
-     val balanced = sampleBalancedData(fraction = fraction, seed = seed,
-       data = data, positiveData = positiveData, negativeData = negativeData).toDF()
-     new ModelData(balanced.persist(), summary)
-   }
- }


/**
*
@@ -284,7 +266,6 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)
case 1.0 => smallData // if upSample == 1.0, no need to upSample
case u => smallData.sample(withReplacement = false, u, seed = seed) // downsample instead
}

smallDataTrain.union(bigDataTrain)
}
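The rebalance helper above leans on Spark's Dataset.sample semantics: upsampling needs withReplacement = true (which permits fractions above 1.0 and repeated rows), while downsampling samples without replacement. A minimal, self-contained sketch of those semantics; names and counts are illustrative, and sampled counts are approximate because sampling is probabilistic:

import org.apache.spark.sql.SparkSession

object SamplingSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("sampling-sketch").getOrCreate()

  val small = spark.range(1000)   // stand-in for the minority class
  val big   = spark.range(100000) // stand-in for the majority class

  // fraction > 1.0 requires withReplacement = true; rows may then repeat
  val up = small.sample(withReplacement = true, fraction = 3.0, seed = 42L)
  // keep a random ~25% of the majority class
  val down = big.sample(withReplacement = false, fraction = 0.25, seed = 42L)

  println(s"up ~ ${up.count()}, down ~ ${down.count()}") // roughly 3000 and 25000
  spark.stop()
}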

Expand Down Expand Up @@ -330,9 +311,7 @@ trait DataBalancerParams extends Params {
)
setDefault(sampleFraction, SplitterParamsDefault.SampleFractionDefault)

- def setSampleFraction(value: Double): this.type = {
-   set(sampleFraction, value)
- }
+ def setSampleFraction(value: Double): this.type = set(sampleFraction, value)

def getSampleFraction: Double = $(sampleFraction)

@@ -350,9 +329,7 @@ trait DataBalancerParams extends Params {
)
setDefault(maxTrainingSample, SplitterParamsDefault.MaxTrainingSampleDefault)

- def setMaxTrainingSample(value: Int): this.type = {
-   set(maxTrainingSample, value)
- }
+ def setMaxTrainingSample(value: Int): this.type = set(maxTrainingSample, value)

def getMaxTrainingSample: Int = $(maxTrainingSample)

@@ -366,9 +343,7 @@ trait DataBalancerParams extends Params {
"fraction to sample minority data", ParamValidators.gt(0.0) // it can be a downSample fraction
)

- private[op] def setUpSampleFraction(value: Double): this.type = {
-   set(upSampleFraction, value)
- }
+ private[op] def setUpSampleFraction(value: Double): this.type = set(upSampleFraction, value)

private[op] def getUpSampleFraction: Double = $(upSampleFraction)

@@ -385,9 +360,7 @@ trait DataBalancerParams extends Params {
)
)

- private[op] def setDownSampleFraction(value: Double): this.type = {
-   set(downSampleFraction, value)
- }
+ private[op] def setDownSampleFraction(value: Double): this.type = set(downSampleFraction, value)

private[op] def getDownSampleFraction: Double = $(downSampleFraction)

@@ -400,9 +373,7 @@ trait DataBalancerParams extends Params {
private[op] final val isPositiveSmall = new BooleanParam(this, "isPositiveSmall",
"whether or not positive data is in minority")

- private[op] def setIsPositiveSmall(value: Boolean): this.type = {
-   set(isPositiveSmall, value)
- }
+ private[op] def setIsPositiveSmall(value: Boolean): this.type = set(isPositiveSmall, value)

private[op] def getIsPositiveSmall: Boolean = $(isPositiveSmall)

@@ -417,9 +388,7 @@ trait DataBalancerParams extends Params {
ParamValidators.inRange(lowerBound = 0.0, upperBound = 1.0, lowerInclusive = false, upperInclusive = true)
)

- private[op] def setAlreadyBalancedFraction(value: Double): this.type = {
-   set(alreadyBalancedFraction, value)
- }
+ private[op] def setAlreadyBalancedFraction(value: Double): this.type = set(alreadyBalancedFraction, value)

private[op] def getAlreadyBalancedFraction: Double = $(alreadyBalancedFraction)
}
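Every setter collapsed in this hunk follows the same Spark ML Param pattern: a validated Param field, a default, and one-line set/get helpers. A minimal, hypothetical holder showing the pattern in isolation (class name, bounds, and default are illustrative, not taken from the repository):

import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators, Params}

// Illustrative only: a bare Params holder with one validated, defaulted param.
class BalancerConf(override val uid: String = "balancerConf") extends Params {
  final val sampleFraction = new DoubleParam(this, "sampleFraction",
    "target proportion of the minority class",
    ParamValidators.inRange(0.0, 0.5, lowerInclusive = false, upperInclusive = true))
  setDefault(sampleFraction, 0.1)

  def setSampleFraction(value: Double): this.type = set(sampleFraction, value)
  def getSampleFraction: Double = $(sampleFraction)

  // defaultCopy re-instantiates via the single String (uid) constructor
  override def copy(extra: ParamMap): BalancerConf = defaultCopy(extra)
}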
@@ -85,7 +85,7 @@ class DataCutter(uid: String = UID[DataCutter]) extends Splitter(uid = uid) with
def prepare(data: Dataset[Row]): ModelData = {
import data.sparkSession.implicits._

- val keep =
+ val keep: Set[Double] =
if (!isSet(labelsToKeep) || !isSet(labelsToDrop)) {
val labels = data.map(r => r.getDouble(0) -> 1L)
val labelCounts = labels.groupBy(labels.columns(0)).sum(labels.columns(1)).persist()
@@ -96,9 +96,9 @@ class DataCutter(uid: String = UID[DataCutter]) extends Splitter(uid = uid) with
} else getLabelsToKeep.toSet

val dataUse = data.filter(r => keep.contains(r.getDouble(0)))
+ val summary = DataCutterSummary(labelsKept = getLabelsToKeep, labelsDropped = getLabelsToDrop)

- val labelsMeta = DataCutterSummary(labelsKept = getLabelsToKeep, labelsDropped = getLabelsToDrop)
- new ModelData(dataUse, Option(labelsMeta))
+ ModelData(dataUse, Some(summary))
}
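The keep-set computation above maps each row to (label, 1L) and sums per label. The same counting step can be written directly against the DataFrame API; a minimal sketch assuming, as the code above does, that the label is the first column:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count}

object LabelCountSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("label-counts").getOrCreate()
  import spark.implicits._

  val df = Seq(0.0, 0.0, 1.0, 2.0).toDF("label")
  // equivalent of mapping rows to (label, 1L) and summing per label
  val labelCounts = df.groupBy(col("label")).agg(count("*").as("count"))
  labelCounts.show()
  spark.stop()
}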

/**
@@ -146,9 +146,7 @@ private[impl] trait DataCutterParams extends Params {
)
setDefault(maxLabelCategories, SplitterParamsDefault.MaxLabelCategoriesDefault)

- def setMaxLabelCategories(value: Int): this.type = {
-   set(maxLabelCategories, value)
- }
+ def setMaxLabelCategories(value: Int): this.type = set(maxLabelCategories, value)

def getMaxLabelCategories: Int = $(maxLabelCategories)

@@ -159,9 +157,7 @@ private[impl] trait DataCutterParams extends Params {
)
setDefault(minLabelFraction, SplitterParamsDefault.MinLabelFractionDefault)

- def setMinLabelFraction(value: Double): this.type = {
-   set(minLabelFraction, value)
- }
+ def setMinLabelFraction(value: Double): this.type = set(minLabelFraction, value)

def getMinLabelFraction: Double = $(minLabelFraction)

@@ -185,18 +181,18 @@ private[impl] trait DataCutterParams extends Params {
/**
* Summary of results for data cutter
* @param labelsKept labels retained
- * @param labelsDropped labels dropped by datacutter
+ * @param labelsDropped labels dropped by data cutter
*/
case class DataCutterSummary
(
- labelsKept: Array[Double],
- labelsDropped: Array[Double]
+ labelsKept: Seq[Double],
+ labelsDropped: Seq[Double]
) extends SplitterSummary {
override def toMetadata(): Metadata = {
new MetadataBuilder()
.putString(SplitterSummary.ClassName, this.getClass.getName)
- .putDoubleArray(ModelSelectorNames.LabelsKept, labelsKept)
- .putDoubleArray(ModelSelectorNames.LabelsDropped, labelsDropped)
+ .putDoubleArray(ModelSelectorNames.LabelsKept, labelsKept.toArray)
+ .putDoubleArray(ModelSelectorNames.LabelsDropped, labelsDropped.toArray)
.build()
}
}
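The toMetadata above serializes the summary through Spark's MetadataBuilder, which is why the Seq values are converted with .toArray. A small round-trip sketch of that API; the string keys here stand in for the ModelSelectorNames constants used above:

import org.apache.spark.sql.types.MetadataBuilder

object MetadataSketch extends App {
  val meta = new MetadataBuilder()
    .putString("className", "DataCutterSummary") // stand-in key
    .putDoubleArray("labelsKept", Array(0.0, 1.0)) // stand-in key
    .putDoubleArray("labelsDropped", Array(2.0)) // stand-in key
    .build()

  // read back with the matching typed getter
  val kept: Array[Double] = meta.getDoubleArray("labelsKept")
  println(kept.mkString(", ")) // 0.0, 1.0
}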
@@ -68,8 +68,7 @@ class DataSplitter(uid: String = UID[DataSplitter]) extends Splitter(uid = uid)
* @param data
* @return Training set test set
*/
- def prepare(data: Dataset[Row]): ModelData =
-   new ModelData(data, Option(DataSplitterSummary()))
+ def prepare(data: Dataset[Row]): ModelData = ModelData(data, Some(DataSplitterSummary()))

override def copy(extra: ParamMap): DataSplitter = {
val copy = new DataSplitter(uid)
@@ -78,7 +77,7 @@ class DataSplitter(uid: String = UID[DataSplitter]) extends Splitter(uid = uid)
}

/**
- * Empty class because no summary information for a datasplitter
+ * Empty class because no summary information for a data splitter
*/
case class DataSplitterSummary() extends SplitterSummary {
override def toMetadata(): Metadata = new MetadataBuilder()
@@ -136,10 +136,14 @@ private[op] class OpCrossValidation[M <: Model[_], E <: Estimator[_]]
* @param splitter used to estimate splitter params prior to cv
* @return Array((TrainRDD, ValidationRDD), Index)
*/
- private[op] override def createTrainValidationSplits[T](stratifyCondition: Boolean,
-   dataset: Dataset[T], label: String, splitter: Option[Splitter] = None): Array[(RDD[Row], RDD[Row])] = {
+ private[op] override def createTrainValidationSplits[T](
+   stratifyCondition: Boolean,
+   dataset: Dataset[T],
+   label: String,
+   splitter: Option[Splitter]
+ ): Array[(RDD[Row], RDD[Row])] = {

- // TODO : Implement our own kFold method for better performance in a separate PR
+ // TODO: Implement our own kFold method for better performance in a separate PR

// get param that stores the label column
val labelCol = evaluator.getParam(ValidatorParamDefaults.LabelCol)
@@ -160,13 +164,12 @@ private[op] class OpCrossValidation[M <: Model[_], E <: Estimator[_]]
}
}


private def stratifyKFolds(rddsByClass: Array[RDD[Row]]): Array[(RDD[Row], RDD[Row])] = {
// Cross Validation's Train/Validation data for each class
val foldsByClass = rddsByClass.map(rdd => MLUtils.kFold(rdd, numFolds, seed)).toSeq

if (foldsByClass.isEmpty) {
-      throw new RuntimeException("Dataset is too small for CV forlds selected some empty datasets are created")
+      throw new RuntimeException("Dataset is too small for CV folds selected some empty datasets are created")
}
// Merging Train/Validation data one by one
foldsByClass.reduce[Array[(RDD[Row], RDD[Row])]] {
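The reduce this hunk truncates merges per-class folds pairwise: fold i of each class is unioned so every fold preserves the class ratio. A compact sketch of that merge, assuming MLUtils.kFold is applied to each per-class RDD (a simplified stand-in, not the file's exact code):

import scala.reflect.ClassTag
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

// Union fold i of every per-class split so each fold keeps the class ratio.
def stratifiedKFold[T: ClassTag](byClass: Seq[RDD[T]], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] =
  byClass.map(rdd => MLUtils.kFold(rdd, numFolds, seed))
    .reduce { (a, b) =>
      a.zip(b).map { case ((trainA, valA), (trainB, valB)) =>
        (trainA.union(trainB), valA.union(valB))
      }
    }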
@@ -91,16 +91,16 @@ private[op] class OpTrainValidationSplit[M <: Model[_], E <: Estimator[_]]
* Creates Train Validation Splits For TS
*
* @param stratifyCondition condition to do stratify ts
- * @param dataset dataset to split
- * @param label name of label in dataset
- * @param splitter used to estimate splitter params prior to ts
+ * @param dataset           dataset to split
+ * @param label             name of label in dataset
+ * @param splitter          used to estimate splitter params prior to ts
* @return Array[(Train, Test)]
*/
private[op] override def createTrainValidationSplits[T](
stratifyCondition: Boolean,
dataset: Dataset[T],
label: String,
-   splitter: Option[Splitter] = None
+   splitter: Option[Splitter]
): Array[(RDD[Row], RDD[Row])] = {

// get param that stores the label column