[SPARK-30660][ML][PYSPARK] LinearRegression blockify input vectors
### What changes were proposed in this pull request?
1. Use blocks of instances instead of individual vectors, for better performance.
2. Use Level-2 BLAS (matrix-vector operations) for the per-block computations; both points are sketched below.
3. Move standardization of the input vectors outside of the gradient computation.
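
A minimal, self-contained sketch of the idea behind points 1 and 2 (illustrative shapes and values, not the actual aggregator code):

```scala
import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}

// Three instances with two features each, stacked into one column-major block matrix.
val block = new DenseMatrix(3, 2, Array(
  1.0, 2.0, 3.0,   // feature 0 of instances 0..2
  4.0, 5.0, 6.0))  // feature 1 of instances 0..2
val coefficients = Vectors.dense(0.5, -0.25).toDense

// Vector-at-a-time: one dot product per instance.
val perInstanceMargins = (0 until block.numRows).map { i =>
  (0 until block.numCols).map(j => block(i, j) * coefficients(j)).sum
}

// Blockified: the margins of the whole block come from a single
// matrix-vector product (Level-2 BLAS; the aggregators below use BLAS.gemv).
val blockMargins = block.multiply(coefficients)
```

Persisting a few large blocks instead of many small vector objects is also where the RAM savings claimed below come from.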

### Why are the changes needed?
1. Less RAM is needed to persist the training data (roughly 40% savings).
2. Faster than the existing implementation (30% to 102% speedup).

### Does this PR introduce any user-facing change?
Adds a new expert param `blockSize` (usage sketched below).
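
A usage sketch, assuming the param follows Spark's usual setter naming (`setBlockSize`); the value shown is purely illustrative:

```scala
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(10)
  .setBlockSize(1024)  // expert knob: number of instances stacked per block
```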

### How was this patch tested?
Updated test suites.

Closes apache#27396 from zhengruifeng/blockify_lireg.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
zhengruifeng authored and srowen committed Feb 1, 2020
1 parent 2fd15a2 commit d0c3e9f
Showing 6 changed files with 289 additions and 82 deletions.
org/apache/spark/ml/optim/aggregator/HuberAggregator.scala
@@ -17,8 +17,8 @@
package org.apache.spark.ml.optim.aggregator

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg._

/**
* HuberAggregator computes the gradient and loss for a huber loss function,
@@ -62,27 +62,26 @@ import org.apache.spark.ml.linalg.Vector
*
* @param fitIntercept Whether to fit an intercept term.
* @param epsilon The shape parameter to control the amount of robustness.
* @param bcFeaturesStd The broadcast standard deviation values of the features.
* @param bcParameters including three parts: the regression coefficients corresponding
* to the features, the intercept (if fitIntercept is true)
* and the scale parameter (sigma).
*/
private[ml] class HuberAggregator(
numFeatures: Int,
fitIntercept: Boolean,
epsilon: Double,
bcFeaturesStd: Broadcast[Array[Double]])(bcParameters: Broadcast[Vector])
extends DifferentiableLossAggregator[Instance, HuberAggregator] {
epsilon: Double)(bcParameters: Broadcast[Vector])
extends DifferentiableLossAggregator[InstanceBlock, HuberAggregator] {

protected override val dim: Int = bcParameters.value.size
private val numFeatures: Int = if (fitIntercept) dim - 2 else dim - 1
private val sigma: Double = bcParameters.value(dim - 1)
private val intercept: Double = if (fitIntercept) {
bcParameters.value(dim - 2)
} else {
0.0
}
// make transient so we do not serialize between aggregation stages
@transient private lazy val coefficients = bcParameters.value.toArray.slice(0, numFeatures)
@transient private lazy val linear =
new DenseVector(bcParameters.value.toArray.take(numFeatures))

/**
* Add a new training instance to this HuberAggregator, and update the loss and gradient
@@ -98,16 +97,13 @@ private[ml] class HuberAggregator(
require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")

if (weight == 0.0) return this
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = coefficients
val localCoefficients = linear.values
val localGradientSumArray = gradientSumArray

val margin = {
var sum = 0.0
features.foreachNonZero { (index, value) =>
if (localFeaturesStd(index) != 0.0) {
sum += localCoefficients(index) * (value / localFeaturesStd(index))
}
sum += localCoefficients(index) * value
}
if (fitIntercept) sum += intercept
sum
@@ -119,10 +115,7 @@
val linearLossDivSigma = linearLoss / sigma

features.foreachNonZero { (index, value) =>
if (localFeaturesStd(index) != 0.0) {
localGradientSumArray(index) +=
-1.0 * weight * linearLossDivSigma * (value / localFeaturesStd(index))
}
localGradientSumArray(index) -= weight * linearLossDivSigma * value
}
if (fitIntercept) {
localGradientSumArray(dim - 2) += -1.0 * weight * linearLossDivSigma
@@ -134,10 +127,7 @@
(sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon)

features.foreachNonZero { (index, value) =>
if (localFeaturesStd(index) != 0.0) {
localGradientSumArray(index) +=
weight * sign * epsilon * (value / localFeaturesStd(index))
}
localGradientSumArray(index) += weight * sign * epsilon * value
}
if (fitIntercept) {
localGradientSumArray(dim - 2) += weight * sign * epsilon
@@ -149,4 +139,75 @@ private[ml] class HuberAggregator(
this
}
}

/**
* Add a new training instance block to this HuberAggregator, and update the loss and gradient
* of the objective function.
*
* @param block The instance block of data points to be added.
* @return This HuberAggregator object.
*/
def add(block: InstanceBlock): HuberAggregator = {
require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " +
s"instance. Expecting $numFeatures but got ${block.numFeatures}.")
require(block.weightIter.forall(_ >= 0),
s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0")

if (block.weightIter.forall(_ == 0)) return this
val size = block.size
val localGradientSumArray = gradientSumArray

// vec here represents margins or dotProducts
val vec = if (fitIntercept && intercept != 0) {
new DenseVector(Array.fill(size)(intercept))
} else {
new DenseVector(Array.ofDim[Double](size))
}

if (fitIntercept) {
BLAS.gemv(1.0, block.matrix, linear, 1.0, vec)
} else {
BLAS.gemv(1.0, block.matrix, linear, 0.0, vec)
}

// in-place convert margins to multipliers
// then, vec represents multipliers
var i = 0
while (i < size) {
val weight = block.getWeight(i)
if (weight > 0) {
weightSum += weight
val label = block.getLabel(i)
val margin = vec(i)
val linearLoss = label - margin

if (math.abs(linearLoss) <= sigma * epsilon) {
lossSum += 0.5 * weight * (sigma + math.pow(linearLoss, 2.0) / sigma)
val linearLossDivSigma = linearLoss / sigma
val multiplier = -1.0 * weight * linearLossDivSigma
vec.values(i) = multiplier
localGradientSumArray(dim - 1) += 0.5 * weight * (1.0 - math.pow(linearLossDivSigma, 2.0))
} else {
lossSum += 0.5 * weight *
(sigma + 2.0 * epsilon * math.abs(linearLoss) - sigma * epsilon * epsilon)
val sign = if (linearLoss >= 0) -1.0 else 1.0
val multiplier = weight * sign * epsilon
vec.values(i) = multiplier
localGradientSumArray(dim - 1) += 0.5 * weight * (1.0 - epsilon * epsilon)
}
} else {
vec.values(i) = 0.0
}
i += 1
}

val linearGradSumVec = new DenseVector(Array.ofDim[Double](numFeatures))
BLAS.gemv(1.0, block.matrix.transpose, vec, 0.0, linearGradSumVec)
linearGradSumVec.foreachNonZero { (i, v) => localGradientSumArray(i) += v }
if (fitIntercept) {
localGradientSumArray(dim - 2) += vec.values.sum
}

this
}
}
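
The blocked `add` above follows a forward/backward GEMV pattern: one matrix-vector product produces all margins of the block, a per-row pass turns the margin vector into a multiplier vector in place, and a product with the transposed block matrix accumulates the coefficient gradient. A toy sketch of that shape (plain squared-error multipliers for brevity, not the Huber cases handled above):

```scala
import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector}

val x = new DenseMatrix(3, 2, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) // 3 instances, 2 features
val beta = new DenseVector(Array(0.1, -0.2))
val labels = Array(1.0, 0.0, 2.0)

// Forward: margins for the whole block in one matrix-vector product.
val margins = x.multiply(beta)

// Per-row: convert margins to multipliers (here d(0.5*(m - y)^2)/dm = m - y).
val multipliers = new DenseVector(
  margins.values.zip(labels).map { case (m, y) => m - y })

// Backward: the block's gradient contribution is X^T * multipliers.
val gradContribution = x.transpose.multiply(multipliers)
```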
org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala
@@ -17,8 +17,8 @@
package org.apache.spark.ml.optim.aggregator

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg._

/**
* LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function,
@@ -157,26 +157,25 @@ private[ml] class LeastSquaresAggregator(
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
bcFeaturesMean: Broadcast[Array[Double]])(bcCoefficients: Broadcast[Vector])
extends DifferentiableLossAggregator[Instance, LeastSquaresAggregator] {
bcFeaturesStd: Broadcast[Vector],
bcFeaturesMean: Broadcast[Vector])(bcCoefficients: Broadcast[Vector])
extends DifferentiableLossAggregator[InstanceBlock, LeastSquaresAggregator] {
require(labelStd > 0.0, s"${this.getClass.getName} requires the label standard " +
s"deviation to be positive.")

private val numFeatures = bcFeaturesStd.value.length
private val numFeatures = bcFeaturesStd.value.size
protected override val dim: Int = numFeatures
// make transient so we do not serialize between aggregation stages
@transient private lazy val featuresStd = bcFeaturesStd.value
@transient private lazy val effectiveCoefAndOffset = {
val coefficientsArray = bcCoefficients.value.toArray.clone()
val featuresMean = bcFeaturesMean.value
val featuresStd = bcFeaturesStd.value
var sum = 0.0
var i = 0
val len = coefficientsArray.length
while (i < len) {
if (featuresStd(i) != 0.0) {
coefficientsArray(i) /= featuresStd(i)
sum += coefficientsArray(i) * featuresMean(i)
sum += coefficientsArray(i) / featuresStd(i) * featuresMean(i)
} else {
coefficientsArray(i) = 0.0
}
@@ -186,7 +185,7 @@
(Vectors.dense(coefficientsArray), offset)
}
// do not use tuple assignment above because it will circumvent the @transient tag
@transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1
@transient private lazy val effectiveCoefficientsVec = effectiveCoefAndOffset._1
@transient private lazy val offset = effectiveCoefAndOffset._2

/**
@@ -204,21 +203,64 @@

if (weight == 0.0) return this

val diff = BLAS.dot(features, effectiveCoefficientsVector) - label / labelStd + offset
val localEffectiveCoefficientsVec = effectiveCoefficientsVec

val diff = {
var dot = 0.0
features.foreachNonZero { (index, value) =>
dot += localEffectiveCoefficientsVec(index) * value
}
dot - label / labelStd + offset
}

if (diff != 0) {
val localGradientSumArray = gradientSumArray
val localFeaturesStd = featuresStd
features.foreachNonZero { (index, value) =>
val fStd = localFeaturesStd(index)
if (fStd != 0.0) {
localGradientSumArray(index) += weight * diff * value / fStd
}
localGradientSumArray(index) += weight * diff * value
}
lossSum += weight * diff * diff / 2.0
}
weightSum += weight
this
}
}

/**
* Add a new training instance block to this LeastSquaresAggregator, and update the loss
* and gradient of the objective function.
*
* @param block The instance block of data points to be added.
* @return This LeastSquaresAggregator object.
*/
def add(block: InstanceBlock): LeastSquaresAggregator = {
require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " +
s"instance. Expecting $numFeatures but got ${block.numFeatures}.")
require(block.weightIter.forall(_ >= 0),
s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0")

if (block.weightIter.forall(_ == 0)) return this
val size = block.size

// vec here represents diffs
val vec = new DenseVector(Array.tabulate(size)(i => offset - block.getLabel(i) / labelStd))
BLAS.gemv(1.0, block.matrix, effectiveCoefficientsVec, 1.0, vec)

// in-place convert diffs to multipliers
// then, vec represents multipliers
var i = 0
while (i < size) {
val weight = block.getWeight(i)
val diff = vec(i)
lossSum += weight * diff * diff / 2
weightSum += weight
val multiplier = weight * diff
vec.values(i) = multiplier
i += 1
}

val gradSumVec = new DenseVector(gradientSumArray)
BLAS.gemv(1.0, block.matrix.transpose, vec, 1.0, gradSumVec)

this
}
}
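
The `effectiveCoefAndOffset` computation above is where point 3 of the description lands for least squares: instead of dividing every feature by its standard deviation inside the hot loop, the division is folded into the coefficients once, and the centering terms collapse into a single scalar offset. A minimal numeric check of that identity (hypothetical values, not the full intercept/label handling):

```scala
// sum_i beta_i * (x_i - mean_i) / std_i  ==  sum_i (beta_i / std_i) * x_i  -  offset
val beta = Array(0.4, -1.2)
val std  = Array(2.0, 0.5)
val mean = Array(1.0, 3.0)
val x    = Array(5.0, 4.0)

// Naive: standardize each feature inside the dot product.
val naive = beta.indices.map(i => beta(i) * (x(i) - mean(i)) / std(i)).sum

// Folded: effective coefficients and a scalar offset computed once, up front.
val betaEff = beta.indices.map(i => if (std(i) != 0.0) beta(i) / std(i) else 0.0)
val offset  = betaEff.indices.map(i => betaEff(i) * mean(i)).sum
val folded  = beta.indices.map(i => betaEff(i) * x(i)).sum - offset

assert(math.abs(naive - folded) < 1e-12) // both give the same margin
```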
(The remaining four changed files are not shown here.)
