Skip to content

Commit

Permalink
[SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed…
Browse files Browse the repository at this point in the history
… memory tradeoff for TrainValidationSplit

## What changes were proposed in this pull request?

Avoid holding all models in memory for `TrainValidationSplit`.

## How was this patch tested?

Existing tests.

Author: Bago Amirbekian <bago@databricks.com>

Closes #20143 from MrBago/trainValidMemoryFix.
  • Loading branch information
MrBago authored and jkbradley committed Jan 5, 2018
1 parent 52fc5c1 commit cf0aa65
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
} (executionContext)
}

// Wait for metrics to be calculated before unpersisting validation dataset
// Wait for metrics to be calculated
val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))

// Unpersist training & validation set once all metrics have been produced
trainingDataset.unpersist()
validationDataset.unpersist()
foldMetrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St

// Fit models in a Future for training in parallel
logDebug(s"Train split with multiple sets of parameters.")
val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Model[_]] {
val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]

if (collectSubModelsParam) {
subModels.get(paramIndex) = model
}
model
} (executionContext)
}

// Unpersist training data only when all models have trained
Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
.onComplete { _ => trainingDataset.unpersist() } (executionContext)

// Evaluate models in a Future that will calculate a metric and allow model to be cleaned up
val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
modelFuture.map { model =>
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
logDebug(s"Got metric $metric for model trained with $paramMap.")
Expand All @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
// Wait for all metrics to be calculated
val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))

// Unpersist validation set once all metrics have been produced
// Unpersist training & validation set once all metrics have been produced
trainingDataset.unpersist()
validationDataset.unpersist()

logInfo(s"Train validation split metrics: ${metrics.toSeq}")
Expand Down

0 comments on commit cf0aa65

Please sign in to comment.