Skip to content

Commit

Permalink
Fix so not preparing data twice when calling model selector fit method (
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire authored Mar 28, 2019
1 parent 3e89a43 commit 3aa144a
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ E <: Estimator[_] with OpPipelineStage2[RealNN, OPVector, Prediction]]
protected[op] def findBestEstimator(data: Dataset[_], dag: StagesDAG, persistEveryKStages: Int = 0)
(implicit spark: SparkSession): Unit = {

splitter.foreach(_.preValidationPrepare(data.select(labelColName).toDF()))
val theBestEstimator = validator.validate(modelInfo = modelsUse, dataset = data,
label = in1.name, features = in2.name, dag = Option(dag), splitter = splitter,
stratifyCondition = validator.isClassification
Expand Down Expand Up @@ -146,20 +147,17 @@ E <: Estimator[_] with OpPipelineStage2[RealNN, OPVector, Prediction]]
}
require(!datasetWithID.isEmpty, "Dataset cannot be empty")

val ModelData(trainData, splitterSummary) = splitter match {
case Some(spltr) => spltr.prepare(datasetWithID)
case None => ModelData(datasetWithID, None)
}

val splitterSummary = splitter.flatMap(_.preValidationPrepare(datasetWithID))
val BestEstimator(name, estimator, summary) = bestEstimator.getOrElse {
setInputSchema(dataset.schema).transformSchema(dataset.schema)
val best = validator
.validate(modelInfo = modelsUse, dataset = trainData, label = in1.name, features = in2.name)
.validate(modelInfo = modelsUse, dataset = datasetWithID, label = in1.name, features = in2.name)
bestEstimator = Some(best)
best
}

val bestModel = estimator.fit(trainData).asInstanceOf[M]
val preparedData = splitter.map(_.validationPrepare(datasetWithID)).getOrElse(datasetWithID)
val bestModel = estimator.fit(preparedData).asInstanceOf[M]
val bestEst = bestModel.parent
log.info(s"Selected model : ${bestEst.getClass.getSimpleName}")
log.info(s"With parameters : ${bestEst.extractParamMap()}")
Expand All @@ -168,7 +166,7 @@ E <: Estimator[_] with OpPipelineStage2[RealNN, OPVector, Prediction]]
outputsColNamesMap.foreach { case (pname, pvalue) => bestModel.set(bestModel.getParam(pname), pvalue) }

// get eval results for metadata
val trainingEval = evaluate(bestModel.transform(trainData))
val trainingEval = evaluate(bestModel.transform(preparedData))

val metadataSummary = ModelSelectorSummary(
validationType = ValidationType.fromValidator(validator),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,26 +116,39 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)
}
}


/**
* Split into a training set and a test set and balance the training set
* Function to set parameters before passing into the validation step
* eg - do data balancing or dropping based on the labels
*
* @param data to prepare for model training. first column must be the label as a double
* @return balanced training set and a test set
* @param data dataset to examine; first column must be the label as a double
* @return summary of the parameters set while examining the data
*/
def prepare(data: Dataset[Row]): ModelData = {
override def preValidationPrepare(data: Dataset[Row]): Option[SplitterSummary] = {
val negativeData = data.filter(_.getDouble(0) == 0.0).persist()
val positiveData = data.filter(_.getDouble(0) == 1.0).persist()
val negativeCount = negativeData.count()
val positiveCount = positiveData.count()
val seed = getSeed

if (!(isSet(isPositiveSmall) ||
isSet(downSampleFraction) ||
isSet(upSampleFraction) ||
isSet(alreadyBalancedFraction))
) {
estimate(positiveCount = positiveCount, negativeCount = negativeCount, seed = seed)
}
estimate(positiveCount = positiveCount, negativeCount = negativeCount, seed = seed)

summary
}

/**
* Rebalance the training data within the validation step
*
* @param data to prepare for model training. first column must be the label as a double
* @return the rebalanced dataset (persisted)
*/
def validationPrepare(data: Dataset[Row]): Dataset[Row] = {

if (summary.isEmpty) throw new RuntimeException("Cannot call prepare until examine has been called")

val negativeData = data.filter(_.getDouble(0) == 0.0).persist()
val positiveData = data.filter(_.getDouble(0) == 1.0).persist()
val seed = getSeed

// If these conditions are met, that means that we have enough information to balance the data : upSample,
// downSample and which class is in minority
Expand Down Expand Up @@ -166,7 +179,7 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)
}
}

ModelData(balanced.persist(), summary)
balanced.persist()
}

override def copy(extra: ParamMap): DataBalancer = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,30 +75,42 @@ class DataCutter(uid: String = UID[DataCutter]) extends Splitter(uid = uid) with

@transient private lazy val log = LoggerFactory.getLogger(this.getClass)

@transient private[op] var summary: Option[DataCutterSummary] = None

/**
* function to use to prepare the dataset for modeling
* Function to set parameters before passing into the validation step
* eg - do data balancing or dropping based on the labels
*
* @param data first column must be the label as a double
* @return Training set test set
* @param data dataset to examine; first column must be the label as a double
* @return summary of the labels kept and dropped, as determined from the data
*/
def prepare(data: Dataset[Row]): ModelData = {
override def preValidationPrepare(data: Dataset[Row]): Option[SplitterSummary] = {

import data.sparkSession.implicits._

val keep: Set[Double] =
if (!isSet(labelsToKeep) || !isSet(labelsToDrop)) {
val labels = data.map(r => r.getDouble(0) -> 1L)
val labelCounts = labels.groupBy(labels.columns(0)).sum(labels.columns(1)).persist()
val (resKeep, resDrop) = estimate(labelCounts)
labelCounts.unpersist()
setLabels(resKeep, resDrop)
resKeep
} else getLabelsToKeep.toSet
val labels = data.map(r => r.getDouble(0) -> 1L)
val labelCounts = labels.groupBy(labels.columns(0)).sum(labels.columns(1)).persist()
val (resKeep, resDrop) = estimate(labelCounts)
labelCounts.unpersist()
setLabels(resKeep, resDrop)

summary = Option(DataCutterSummary(labelsKept = getLabelsToKeep, labelsDropped = getLabelsToDrop))
summary
}

/**
* Removes labels that should not be used in modeling
*
* @param data first column must be the label as a double
* @return dataset containing only the rows whose label is in the set of labels to keep
*/
def validationPrepare(data: Dataset[Row]): Dataset[Row] = {
if (summary.isEmpty) throw new RuntimeException("Cannot call prepare until examine has been called")

val keep: Set[Double] = getLabelsToKeep.toSet
val dataUse = data.filter(r => keep.contains(r.getDouble(0)))
val summary = DataCutterSummary(labelsKept = getLabelsToKeep, labelsDropped = getLabelsToDrop)

ModelData(dataUse, Some(summary))
dataUse
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,24 @@ case object DataSplitter {
*/
class DataSplitter(uid: String = UID[DataSplitter]) extends Splitter(uid = uid) {


/**
* Function to set parameters before passing into the validation step
* eg - do data balancing or dropping based on the labels
*
* @param data dataset to examine (ignored by this splitter)
* @return Parameters set in examining data
*/
override def preValidationPrepare(data: Dataset[Row]): Option[SplitterSummary] = Option(DataSplitterSummary())

/**
* Function to use to prepare the dataset for modeling
* eg - do data balancing or dropping based on the labels
*
* @param data
* @return the dataset to use for modeling (this splitter returns the input unchanged)
*/
def prepare(data: Dataset[Row]): ModelData = ModelData(data, Some(DataSplitterSummary()))
def validationPrepare(data: Dataset[Row]): Dataset[Row] = data

override def copy(extra: ParamMap): DataSplitter = {
val copy = new DataSplitter(uid)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ private[op] trait OpValidator[M <: Model[_], E <: Estimator[_]] extends Serializ
.withColumn(ModelSelectorNames.idColName, monotonically_increasing_id())

val (balancedTrain, balancedTest) = splitter.map(s => (
s.prepare(selectTrain).train,
s.prepare(selectTest).train)
s.validationPrepare(selectTrain),
s.validationPrepare(selectTest))
).getOrElse((selectTrain, selectTest))

(balancedTrain, balancedTest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,6 @@ import scala.util.Try



/**
* Case class for Training & test sets
*
* @param train training set is persisted at construction
* @param summary summary for building metadata
*/
case class ModelData(train: Dataset[Row], summary: Option[SplitterSummary])

/**
* Abstract class that will carry on the creation of training set + test set
Expand All @@ -68,13 +61,23 @@ abstract class Splitter(val uid: String) extends SplitterParams {
}

/**
* Function to use to prepare the dataset for modeling
* Function to use to prepare the dataset for modeling within the validation step
* eg - do data balancing or dropping based on the labels
*
* @param data
* @return the dataset prepared for modeling
*/
def prepare(data: Dataset[Row]): ModelData
def validationPrepare(data: Dataset[Row]): Dataset[Row]


/**
* Function to set parameters before passing into the validation step
* eg - do data balancing or dropping based on the labels
*
* @param data
* @return Parameters set in examining data
*/
def preValidationPrepare(data: Dataset[Row]): Option[SplitterSummary]

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ package com.salesforce.op.stages.impl.tuning
import com.salesforce.op.test.TestSparkContext
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.random.RandomRDDs
import org.apache.spark.sql.{Dataset, Row}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertion, FlatSpec}
Expand Down Expand Up @@ -105,57 +106,67 @@ class DataBalancerTest extends FlatSpec with TestSparkContext with SplitterSumma
val fraction = 0.4
val maxSize = 2000
val balancer = DataBalancer(sampleFraction = fraction, maxTrainingSample = maxSize, seed = 11L)
val res1 = balancer.prepare(data)
val s1 = balancer.preValidationPrepare(data)
val res1 = balancer.validationPrepare(data)
val (downSample, upSample) = balancer.getProportions(smallCount, bigCount, fraction, maxSize)

balancer.getUpSampleFraction shouldBe upSample
balancer.getDownSampleFraction shouldBe downSample
balancer.getIsPositiveSmall shouldBe false
checkRecurringPrepare(balancer, res1, DataBalancerSummary(800, 200, 0.4, 2.0, 0.75))
checkRecurringPrepare(balancer, res1, s1, DataBalancerSummary(800, 200, 0.4, 2.0, 0.75))
}

it should "throw an error if you try to prepare before examining" in {
val balancer = DataBalancer(sampleFraction = 0.1, maxTrainingSample = 2000, seed = 11L)
intercept[RuntimeException](balancer.validationPrepare(data)).getMessage shouldBe
"Cannot call prepare until examine has been called"
}

it should "remember that data is already balanced" in {
val fraction = 0.01
val maxSize = 20000
val balancer = DataBalancer(sampleFraction = fraction, maxTrainingSample = maxSize, seed = 11L)
val res1 = balancer.prepare(data)
val s1 = balancer.preValidationPrepare(data)
val res1 = balancer.validationPrepare(data)

balancer.getAlreadyBalancedFraction shouldBe 1.0
checkRecurringPrepare(balancer, res1, DataBalancerSummary(800, 200, 0.01, 0.0, 1.0))
checkRecurringPrepare(balancer, res1, s1, DataBalancerSummary(800, 200, 0.01, 0.0, 1.0))
}

it should "remember that data is already balanced, but needs to be sample because too big" in {
val fraction = 0.01
val maxSize = 100
val balancer = DataBalancer(sampleFraction = fraction, maxTrainingSample = maxSize, seed = 11L)
val res1 = balancer.prepare(data)
val s1 = balancer.preValidationPrepare(data)
val res1 = balancer.validationPrepare(data)

balancer.getAlreadyBalancedFraction shouldBe maxSize.toDouble / (smallCount + bigCount)
checkRecurringPrepare(balancer, res1, DataBalancerSummary(800, 200, 0.01, 0.0, 0.1))
checkRecurringPrepare(balancer, res1, s1, DataBalancerSummary(800, 200, 0.01, 0.0, 0.1))
}

private def checkRecurringPrepare(
balancer: DataBalancer,
previousResult: ModelData,
previousResult: Dataset[Row],
summary: Option[SplitterSummary],
expectedSummary: DataBalancerSummary
): Assertion = {
assertDataBalancerSummary(previousResult.summary) { s =>
assertDataBalancerSummary(summary) { s =>
s shouldBe expectedSummary
balancer.summary shouldBe Some(expectedSummary)
}

// Rerun balancer with set params
withClue("Data balancer should not update the summary") {
val ModelData(train, s) = balancer.prepare(spark.emptyDataFrame)
val train = balancer.validationPrepare(spark.emptyDataFrame)
train.count() shouldBe 0
s shouldBe Some(expectedSummary)
balancer.summary shouldBe Some(expectedSummary)
}

// Rerun balancer again and expect the same data & summary
val res2 = balancer.prepare(data)
res2.train.collect() shouldBe previousResult.train.collect()
res2.summary shouldBe previousResult.summary
val s2 = balancer.preValidationPrepare(data)
val res2 = balancer.validationPrepare(data)
res2.collect() shouldBe previousResult.collect()
s2 shouldBe summary
balancer.summary shouldBe Some(expectedSummary)
}

Expand Down
Loading

0 comments on commit 3aa144a

Please sign in to comment.