Skip to content

Commit

Permalink
[SPARK-30398][ML] PCA/RegressionMetrics/RowMatrix avoid unnecessary computation
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Use `.ml.Summarizer` instead of `.mllib.MultivariateOnlineSummarizer`, so that only the requested metrics are computed.

### Why are the changes needed?
To avoid computing metrics that are never used.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites.

Closes #27059 from zhengruifeng/pac_summarizer.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
  • Loading branch information
zhengruifeng authored and srowen committed Jan 4, 2020
1 parent 4a234dd commit c42fbc7
Show file tree
Hide file tree
Showing 10 changed files with 384 additions and 345 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.ml.optim.aggregator.HingeAggregator
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.SummaryBuilderImpl._
import org.apache.spark.ml.stat._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.sql.{Dataset, Row}
Expand Down Expand Up @@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth)

val (summarizer, labelSummarizer) = instances.treeAggregate(
(createSummarizerBuffer("mean", "std", "count"), new MultiClassSummarizer))(
(Summarizer.createSummarizerBuffer("mean", "std", "count"), new MultiClassSummarizer))(
seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.aggregator.LogisticAggregator
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.SummaryBuilderImpl._
import org.apache.spark.ml.stat._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
Expand Down Expand Up @@ -501,7 +501,7 @@ class LogisticRegression @Since("1.2.0") (
fitIntercept)

val (summarizer, labelSummarizer) = instances.treeAggregate(
(createSummarizerBuffer("mean", "std", "count"), new MultiClassSummarizer))(
(Summarizer.createSummarizerBuffer("mean", "std", "count"), new MultiClassSummarizer))(
seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.SummaryBuilderImpl._
import org.apache.spark.ml.stat._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.util.MLUtils
Expand Down Expand Up @@ -215,7 +215,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

val featuresSummarizer = instances.treeAggregate(
createSummarizerBuffer("mean", "std", "count"))(
Summarizer.createSummarizerBuffer("mean", "std", "count"))(
seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features),
combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2),
depth = $(aggregationDepth)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.ml.optim.aggregator.{HuberAggregator, LeastSquaresAggreg
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.stat.SummaryBuilderImpl._
import org.apache.spark.ml.stat._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.evaluation.RegressionMetrics
Expand Down Expand Up @@ -358,8 +358,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

val (featuresSummarizer, ySummarizer) = instances.treeAggregate(
(createSummarizerBuffer("mean", "std"),
createSummarizerBuffer("mean", "std", "count")))(
(Summarizer.createSummarizerBuffer("mean", "std"),
Summarizer.createSummarizerBuffer("mean", "std", "count")))(
seqOp = (c: (SummarizerBuffer, SummarizerBuffer), instance: Instance) =>
(c._1.add(instance.features, instance.weight),
c._2.add(Vectors.dense(instance.label), instance.weight)),
Expand Down
Loading

0 comments on commit c42fbc7

Please sign in to comment.