diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 4dfd1f0ab8134..8c4c5db5258d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] ( require(mean.size == variance.size) - private lazy val factor: BDV[Double] = { - val f = BDV.zeros[Double](variance.size) + private lazy val factor: Array[Double] = { + val f = Array.ofDim[Double](variance.size) var i = 0 while (i < f.size) { f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 @@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] ( f } + // Since `shift` will be only used in `withMean` branch, we have it as + // `lazy val` so it will be evaluated in that branch. Note that we don't + // want to create this array multiple times in `transform` function. + private lazy val shift: Array[Double] = mean.toArray + /** * Applies standardization transformation on a vector. * @@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] ( override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { - vector.toBreeze match { - case dv: BDV[Double] => - val output = vector.toBreeze.copy - var i = 0 - while (i < output.length) { - output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) - i += 1 + // By default, Scala generates Java methods for member variables. So every time when + // the member variables are accessed, `invokespecial` will be called which is expensive. + // This can be avoid by having a local reference of `shift`. + val localShift = shift + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + if (withStd) { + // Having a local reference of `factor` to avoid overhead as the comment before. + val localFactor = factor + var i = 0 + while (i < size) { + values(i) = (values(i) - localShift(i)) * localFactor(i) + i += 1 + } + } else { + var i = 0 + while (i < size) { + values(i) -= localShift(i) + i += 1 + } } - Vectors.fromBreeze(output) + Vectors.dense(values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { - vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) - case sv: BSV[Double] => + // Having a local reference of `factor` to avoid overhead as the comment before. + val localFactor = factor + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + var i = 0 + while(i < size) { + values(i) *= localFactor(i) + i += 1 + } + Vectors.dense(values) + case sv: SparseVector => // For sparse vector, the `index` array inside sparse vector object will not be changed, // so we can re-use it to save memory. - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + val indices = sv.indices + val values = sv.values.clone() + val nnz = values.size var i = 0 - while (i < output.data.length) { - output.data(i) *= factor(output.index(i)) + while (i < nnz) { + values(i) *= localFactor(indices(i)) i += 1 } - Vectors.fromBreeze(output) + Vectors.sparse(sv.size, indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else {