Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-6685][MLLIB]Use DSYRK to compute AtA in ALS #13891

Closed
wants to merge 12 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
val atb = new Array[Double](k)

private val da = new Array[Double](k)
private val ata2 = new Array[Double](k * k)
private val upper = "U"

private def copyToDouble(a: Array[Float]): Unit = {
Expand All @@ -635,6 +636,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}

private def copyToTri(): Unit = {
var ii = 0
for(i <- 0 until k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might fail the style check for missing space before the paren.
Also, I think the received wisdom is that for loops are slow in Scala? if this is performance-critical, you may convert to while loops. Also you can cache i*k from the outer loop below in the inner loop

for(j <- 0 to i) {
ata(ii) += ata2(i * k + j)
ii += 1
}
}

/** Adds an observation. */
def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
require(c >= 0.0)
Expand All @@ -647,6 +657,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
this
}

/** Adds a stack of observations. */
def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
require(a.length == n * k)
blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
copyToTri()
blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
this
}

/** Merges another normal equation object. */
def merge(other: NormalEquation): this.type = {
require(other.k == k)
Expand All @@ -658,6 +677,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/** Resets everything to zero, which should be called after each solve. */
def reset(): Unit = {
ju.Arrays.fill(ata, 0.0)
ju.Arrays.fill(ata2, 0.0)
ju.Arrays.fill(atb, 0.0)
}
}
Expand Down Expand Up @@ -1296,6 +1316,9 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
var i = srcPtrs(j)
var numExplicits = 0
val doStack = if (srcPtrs(j + 1) - srcPtrs(j) > 10) true else false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (...) true else false is redundant

val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
val bBuffer = mutable.ArrayBuilder.make[Double]
while (i < srcPtrs(j + 1)) {
val encoded = srcEncodedIndices(i)
val blockId = srcEncoder.blockId(encoded)
Expand All @@ -1313,11 +1336,23 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
}
} else {
ls.add(srcFactor, rating)
numExplicits += 1
if (doStack) {
bBuffer += rating
var ii = 0
while(ii < srcFactor.length) {
srcFactorBuffer += srcFactor(ii)
ii += 1
}
} else {
ls.add(srcFactor, rating)
}
}
i += 1
}
if (!implicitPrefs && doStack) {
ls.addStack(srcFactorBuffer.result(), bBuffer.result(), numExplicits)
}
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
dstFactors(j) = solver.solve(ls, numExplicits * regParam)
j += 1
Expand Down