-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-20587][ML] Improve performance of ML ALS recommendForAll #17845
Conversation
cc @mpjlu Also @srowen @sethah @jkbradley |
Some quick perf numbers: Using
So 23-37x improvement. |
Test build #76424 has finished for PR 17845 at commit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
first pass on style.
score += srcFactor(k) * dstFactor(k) | ||
k += 1 | ||
} | ||
pq += { (dstId, score) } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pq += dstId -> score
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure
@@ -389,6 +436,17 @@ class ALSModel private[ml] ( | |||
) | |||
recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is discouraged within Spark: https://github.com/databricks/scala-style-guide#infix-methods
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair point - may as well fix it while here
*/ | ||
private def blockify( | ||
factors: Dataset[(Int, Array[Float])], | ||
/* TODO make blockSize a param? */blockSize: Int = 4096): Dataset[Seq[(Int, Array[Float])]] = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just put the comment in the doc and reference a JIRA.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure
* relatively efficient, the approach implemented here is significantly more efficient. | ||
* | ||
* This approach groups factors into blocks and computes the top-k elements per block, | ||
* using Level 1 BLAS (dot) and an efficient [[BoundedPriorityQueue]]. It then computes the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
below we say that blas is not used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about "... using dot product instead of gemm and an efficient ..."
val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2)) | ||
srcIter.foreach { case (srcId, srcFactor) => | ||
dstIter.foreach { case (dstId, dstFactor) => | ||
/** |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't use doc notation. Maybe we can reduce it to:
/*
* The below code is equivalent to
* `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)`
* The handwritten version is as or more efficient as BLAS calls in this case.
*/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good
Thanks @sethah will update shortly |
Test build #76447 has finished for PR 17845 at commit
|
jenkins retest this please |
Test build #76571 has finished for PR 17845 at commit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM; thanks for doing this! Feel free to merge or address my 1 comment
val m = srcIter.size | ||
val n = math.min(dstIter.size, num) | ||
val output = new Array[(Int, Int, Float)](m * n) | ||
var j = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: You could combine j and i; you really just need 1 counter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
j
iterates through src
ids while i
iterates through dst
ids in the queue for each src
id. So I don't think they can be combined.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anyway the iter.next()
code is a bit ugly and since it's at most k
elements it's not really performance critical, so could just use foreach
I think
One more comment I'll copy from the other PR: I'm not a fan of custom BLAS implementations scattered throughout MLlib. Could you please follow up by putting the dot as a private API in BLAS.scala and adding unit tests? |
Merged to master/branch-2.2 Thanks @mpjlu for the original work on the approach! |
This PR is a `DataFrame` version of #17742 for [SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968), for improving the performance of `recommendAll` methods. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath <nickp@za.ibm.com> Closes #17845 from MLnick/ml-als-perf. (cherry picked from commit 10b00ab) Signed-off-by: Nick Pentreath <nickp@za.ibm.com>
Small clean ups from #17742 and #17845. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath <nickp@za.ibm.com> Closes #17919 from MLnick/SPARK-20677-als-perf-followup. (cherry picked from commit 25b4f41) Signed-off-by: Nick Pentreath <nickp@za.ibm.com>
Small clean ups from apache#17742 and apache#17845. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath <nickp@za.ibm.com> Closes apache#17919 from MLnick/SPARK-20677-als-perf-followup.
Small clean ups from apache#17742 and apache#17845. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath <nickp@za.ibm.com> Closes apache#17919 from MLnick/SPARK-20677-als-perf-followup.
This PR is a `DataFrame` version of apache#17742 for [SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968), for improving the performance of `recommendAll` methods. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath <nickp@za.ibm.com> Closes apache#17845 from MLnick/ml-als-perf.
Small clean ups from apache#17742 and apache#17845. ## How was this patch tested? Existing unit tests. Author: Nick Pentreath <nickp@za.ibm.com> Closes apache#17919 from MLnick/SPARK-20677-als-perf-followup.
This PR is a
DataFrame
version of #17742 for SPARK-11968, for improving the performance ofrecommendAll
methods.How was this patch tested?
Existing unit tests.