From e981396be53e9a00552a137e9a47c52536922dbb Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Wed, 26 Mar 2014 19:00:08 -0700
Subject: [PATCH] use axpy in Updater

---
 .../mllib/optimization/GradientDescent.scala   | 17 ++++++++++++-----
 .../spark/mllib/optimization/Updater.scala     | 14 +++++++++-----
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 41752142247fc..8131925cfc87d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.optimization
 
 import scala.collection.mutable.ArrayBuffer
 
+import breeze.linalg.{Vector => BV, DenseVector => BDV}
+
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
@@ -157,11 +159,16 @@ object GradientDescent extends Logging {
     for (i <- 1 to numIterations) {
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map {
-        case (y, features) =>
-          val (grad, loss) = gradient.compute(features, y, weights)
-          (grad.toBreeze, loss)
-      }.reduce((a, b) => (a._1 += b._1, a._2 + b._2))
+      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
+        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
+            val (g, l) = gradient.compute(features, label, weights)
+            (grad += g.toBreeze, loss + l)
+          },
+          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+            (grad1 += grad2, loss1 + loss2)
+          }
+        )
 
       /**
        * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 2766c8dbb42a0..3b7754cd7ac28 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization
 
 import scala.math._
 
-import breeze.linalg.{norm => brzNorm}
+import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV}
 
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 
@@ -70,7 +70,9 @@ class SimpleUpdater extends Updater {
       iter: Int,
       regParam: Double): (Vector, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
-    val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
+
     (Vectors.fromBreeze(brzWeights), 0)
   }
 }
@@ -102,7 +104,8 @@ class L1Updater extends Updater {
       regParam: Double): (Vector, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     // Take gradient step
-    val brzWeights = weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
     // Apply proximal operator (soft thresholding)
     val shrinkageVal = regParam * thisIterStepSize
     var i = 0
@@ -133,8 +136,9 @@ class SquaredL2Updater extends Updater {
     // w' = w - thisIterStepSize * (gradient + regParam * w)
     // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
     val thisIterStepSize = stepSize / math.sqrt(iter)
-    val brzWeights = weightsOld.toBreeze * (1.0 - thisIterStepSize * regParam) -
-      (gradient.toBreeze * thisIterStepSize)
+    val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+    brzWeights :*= (1.0 - thisIterStepSize * regParam)
+    brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
     val norm = brzNorm(brzWeights, 2.0)
 
     (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
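
For reviewers who want to try the GradientDescent change outside MLlib: below is a minimal, self-contained sketch of the `map { ... }.reduce(...)` to `aggregate` rewrite the first hunk makes. The old form built a fresh Breeze vector per sampled record; `aggregate` folds records into one mutable accumulator per partition via `seqOp` and merges the per-partition accumulators via `combOp`. `AggregateDemo`, the toy data, and the local `SparkContext` setup are illustrative, not part of the patch.

```scala
// Hypothetical standalone sketch of the map/reduce -> aggregate rewrite;
// only the pattern mirrors the patch, the names and data are made up.
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.{SparkConf, SparkContext}

object AggregateDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local").setAppName("AggregateDemo"))
    // Toy (label, features) records standing in for the sampled mini-batch.
    val data = sc.parallelize(Seq(
      (1.0, BDV(1.0, 0.0)),
      (0.0, BDV(0.0, 2.0))))

    // One (vectorSum, scalarSum) accumulator per partition; seqOp folds each
    // record into it in place (+=), combOp merges partition accumulators.
    val (vecSum, lossSum) = data.aggregate((BDV.zeros[Double](2), 0.0))(
      seqOp = (c, v) => (c, v) match { case ((sum, loss), (label, features)) =>
        (sum += features, loss + label)  // in-place add, no new vector per record
      },
      combOp = (c1, c2) => (c1, c2) match { case ((s1, l1), (s2, l2)) =>
        (s1 += s2, l1 + l2)
      })

    println(s"vecSum = $vecSum, lossSum = $lossSum")  // DenseVector(1.0, 2.0), 1.0
    sc.stop()
  }
}
```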
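And a sketch of why the updaters move to `axpy`: `brzAxpy(a, x, y)` computes `y += a * x` in place, whereas the old expression `weightsOld.toBreeze - gradient.toBreeze * thisIterStepSize` materializes a temporary vector for the scaling and another for the subtraction on every call. Plain Breeze, no Spark needed; `AxpyDemo` and its variable names are made up for illustration.

```scala
// Hypothetical sketch of the axpy step used by the updaters; names are made up.
import breeze.linalg.{axpy => brzAxpy, DenseVector => BDV}

object AxpyDemo {
  def main(args: Array[String]): Unit = {
    val weights = BDV(1.0, 2.0, 3.0)
    val gradient = BDV(0.5, 0.5, 0.5)
    val stepSize = 0.1

    // Allocating form: `gradient * stepSize` and the subtraction each build a
    // temporary vector.
    val viaTemporaries = weights - (gradient * stepSize)

    // axpy form: weights = weights + (-stepSize) * gradient, mutated in place.
    brzAxpy(-stepSize, gradient, weights)

    println(viaTemporaries)  // DenseVector(0.95, 1.95, 2.95)
    println(weights)         // same values, computed without temporaries
  }
}
```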