[SPARK-21389][ML][MLLIB] Optimize ALS recommendForAll by gemm #18624
Conversation
val m = srcIds.length
val n = dstIds.length
val dstIdMatrix = new Array[Int](m * num)
val scoreMatrix = Array.fill[Double](m * num)(Double.MinValue)
By the way, MinValue is not the most negative value, but the smallest positive value. Is that what you want here?
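As a side note on the question above (a quick check, not part of the PR): in Scala, `Double.MinValue` is defined as `-Double.MaxValue`, i.e. the most negative finite value, so it does work as a "no score yet" sentinel; it is `java.lang.Double.MIN_VALUE` that denotes the smallest positive value.

```scala
// Scala's Double.MinValue is the most negative finite Double,
// unlike java.lang.Double.MIN_VALUE (smallest positive value).
val sentinel = Double.MinValue
assert(sentinel < 0.0 && sentinel == -Double.MaxValue)
// Java's MIN_VALUE is the smallest positive Double instead.
assert(java.lang.Double.MIN_VALUE > 0.0)
```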
var size = pq.size
while(size > 0) {
  size -= 1
  val factor = pq.poll
poll() because it has side effects
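To illustrate the point (a minimal sketch using a plain `java.util.PriorityQueue` of Strings as a stand-in for the PR's queue, which is an assumption on my part): `poll()` both returns and removes the head, which is why the explicit empty-parens call is requested.

```scala
import java.util.PriorityQueue

// Drain pattern from the snippet above: poll() mutates the queue,
// removing and returning the head each time.
def drain(pq: PriorityQueue[String]): List[String] = {
  var out = List.empty[String]
  var size = pq.size
  while (size > 0) {
    size -= 1
    out = pq.poll() :: out // poll() has a side effect: it removes the head
  }
  out.reverse // min-heap: elements come out in ascending order
}
```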
i += 1
// pq.size maybe less than num, corner case
j += num
pq.clear
clear()
k += 1
}
var size = pq.size
while(size > 0) {
You'll need to fix up a few style things like a space after while
Why not add a nonEmpty / isEmpty method for this?
Do you mean adding an isEmpty method to PriorityQueue? Thanks.
(index -> (srcIds, dstIdMatrix, new DenseMatrix(m, num, scoreMatrix)))
}
ratings.aggregateByKey(null: Array[Int], null: Array[Int], null: DenseMatrix)(
  (rateSum, rate) => {
Braces aren't needed in these args, just put them on one line
  (rateSum1, rateSum2) => {
    mergeFunc(rateSum1, rateSum2, num)
  }
).flatMap(value => {
Use .flatMap { value => ... } to avoid redundant parens.
Also, use case (...) instead of value to name its elements. The ._2, ._3 below are hard to understand.
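A hedged sketch of the suggested style (plain Scala collections, no Spark; the names index, srcIds, dstIds, scores are illustrative, mirroring the PR's aggregation value rather than its exact types): destructuring with `case` makes the `._2` / `._3` accesses readable.

```scala
// A stand-in for the aggregated (blockIndex, (ids, ids, scores)) values.
val aggregated = Seq(
  (0, (Array(1, 2), Array(10, 20), Array(0.5, 0.9)))
)

// Pattern-match in flatMap instead of writing value._2 / value._3.
val flattened = aggregated.flatMap { case (index, (srcIds, dstIds, scores)) =>
  srcIds.zip(dstIds).zip(scores).map { case ((src, dst), s) => (src, dst, s) }
}
```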
var rate_index = 0
while (j < num) {
  if (rate._3(i, rate_index) > rateSum._3(i, sum_index)) {
    tempIdMatrix(i * num + j) = rate._2(i * num + rate_index)
Might be worth storing i * num in a local to avoid recomputing it.
Test build #79584 has finished for PR 18624 at commit
Test build #79587 has finished for PR 18624 at commit
Test build #79597 has finished for PR 18624 at commit
}
(index, (srcIds, dstIdMatrix, new DenseMatrix(m, num, scoreMatrix)))
}
ratings.aggregateByKey(null: Array[Int], null: Array[Int], null: DenseMatrix)(
This is aggregating by key, which in this case appears to be the "block index". What is the benefit then? Since each block will have a unique index, there would be no intermediate aggregation.
A user block, after the Cartesian product, will generate many blocks (one per item block); all these blocks should be aggregated. Thanks.
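A tiny model of the exchange above (plain Scala, no Spark; block counts are illustrative): after the cartesian product of user blocks with item blocks, each user-block key appears once per item block, so there really are multiple partial results per key to merge.

```scala
val userBlocks = Seq(0, 1)          // P = 2 user blocks
val itemBlocks = Seq("a", "b", "c") // Q = 3 item blocks

// Cartesian product keyed by user-block index, as after the join.
val cartesian = for (u <- userBlocks; i <- itemBlocks) yield (u, i)

// Each user-block key carries Q = 3 partial results to aggregate.
val perKey = cartesian.groupBy(_._1).map { case (k, vs) => (k, vs.size) }
```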
var size = pq.size
while (size > 0) {
  size -= 1
  val factor = pq.poll()
Is it really necessary to add poll? For size of k (which is usually very small), the approach of pq.foreach should suffice and is simpler.
The queue is length num, which is typically 10, 20, or perhaps in extreme cases in the low 100's. So is there really any performance benefit here? Even if so it would be marginal, and I believe it's cleaner to just use foreach and sorted, and not worth adding the poll method.
When num = 20, using sorted here the prediction time is about 31s; using poll, it is about 26s. I think this difference is large. I have tested many times, and the result is about the same.
Hi @MLnick, thanks for your review.
My original test of sorted used pq.toArray.sorted with an explicit Ordering over (Int, Double) by the Double score, because pq.toArray.sorted(-_._2) does not compile. Maybe due to boxing/unboxing, the performance was very bad.
Now I use pq.toArray.sortBy(-_._2), and its performance is slightly better than poll: 25s vs 26s for poll.
Thanks.
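A sketch of the two orderings being compared above (illustrative `(id, score)` pairs, not the PR's exact element type): draining a score-ordered min-heap with `poll()` yields ascending scores, while `sortBy(-_._2)` on the collected candidates yields descending scores using a primitive sort key, which avoids the boxing a generic Ordering-based `sorted` can incur.

```scala
import java.util.PriorityQueue
import scala.collection.mutable.ArrayBuffer

val candidates = Seq((1, 0.3), (2, 0.9), (3, 0.1))

// Min-heap ordered by score: poll() returns the worst remaining score first.
val pq = new PriorityQueue[(Int, Double)](Ordering.by[(Int, Double), Double](_._2))
candidates.foreach(pq.add)
val ascending = ArrayBuffer.empty[(Int, Double)]
while (!pq.isEmpty) ascending += pq.poll()

// sortBy with a negated primitive key: best score first, no custom Ordering.
val descending = candidates.toArray.sortBy(-_._2)
```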
We need the values to be in order here.
Without poll, we have to use toArray.sorted, whose performance is bad.
I have checked the results against the master method; the recommendation results are correct.
Test build #79655 has finished for PR 18624 at commit
Test build #79664 has finished for PR 18624 at commit
@mpjlu sorry for the delay on this. I think the idea of the change is good, but I still need to review it in detail. One concern I have is that it now looks quite convoluted, so I want to see if we can simplify the implementation somehow.
@mpjlu also feel free to look into enhancing the tests. One issue, however, is that we don't want to add too much to the run time, as the ALS suite is already very heavy on time.
Thanks @MLnick. I think the ML ALS suite is OK; it's just the MLLIB ALS suite that is too simple. One possible enhancement is to add the same test cases as the ML ALS suite. What do you think?
ping @WeichenXu123
One of my concerns is that your change increases the memory cost here:
L301 val ratings = srcFactors.transpose.multiply(dstFactors)
Suppose we tune blockSize to some large value; the matrix multiplication here may OOM.
That said, I appreciate your optimization of the aggregate merge stage. The sort & merge will be more efficient than merging via a priority queue.
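To make the OOM concern concrete (a back-of-envelope sketch, not the PR's code; the quadratic model is my assumption from the blockSize x blockSize product matrix of Doubles): memory for the per-task product grows with the square of blockSize.

```scala
// Rough size of the per-task (blockSize x blockSize) product matrix,
// at 8 bytes per Double. Purely illustrative arithmetic.
def productMatrixBytes(blockSize: Int): Long =
  blockSize.toLong * blockSize * 8L

// blockSize 2000 -> 32 MB; blockSize 8000 -> 512 MB per task.
```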
def blockify(
    rank: Int,
    features: RDD[(Int, Array[Double])]): RDD[(Array[Int], DenseMatrix)] = {
  val blockSize = 2000 // TODO: tune the block size
So will you add a parameter for this?
Yes, we have another PR to make this value configurable: SPARK-20443.
If blockSize is large enough, it is possible to OOM. In my tests, with blockSize set from 1000 to 8000, the performance of this PR is better than master's, and the performance is about the same across blockSize 1000 to 8000.
But I agree with the issue @MLnick mentioned: the code now looks convoluted. Can you try to simplify it?
Thanks @WeichenXu123, I will think about how to simplify the code.
Hi @holdenk, this is the PR we discussed at the Strata conference. I have thought about the code again; for performance, we can continue to optimize it, because we can merge the block matrices before the shuffle.
Test build #84749 has finished for PR 18624 at commit
Force-pushed from 8ac8196 to f36706a
Test build #84750 has finished for PR 18624 at commit
retest this please
Test build #84752 has finished for PR 18624 at commit
Because I don't have the environment to continue this work, I will close it. Thanks.
What changes were proposed in this pull request?
In Spark 2.2, we optimized ALS recommendForAll with a hand-written matrix multiplication that gets the top-k items for each matrix block. That method effectively reduces the GC problem. However, with native BLAS GEMM (such as Intel MKL or OpenBLAS), matrix multiplication is about 10X faster than the hand-written method.
I have rewritten recommendForAll with GEMM and got about a 50% improvement over the master recommendForAll method.
The key points of this optimization:
1) Use GEMM to replace the hand-written matrix multiplication.
2) Use a matrix to hold temporary results, which largely reduces GC and computing time. The master method creates many small objects, which is why using GEMM directly cannot achieve good performance.
3) Use sort-and-merge to get the top-k items, which avoids calling the priority queue twice.
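Key point 3) can be sketched as follows (an illustrative implementation under my own assumptions, not the PR's exact code; it assumes each partial result is already sorted descending and the two inputs together hold at least k scores): merging two sorted partial top-k results is a single linear pass, with no second priority-queue stage.

```scala
// Merge two descending-sorted score arrays into their combined top-k.
// Assumes a.length + b.length >= k.
def mergeTopK(a: Array[Double], b: Array[Double], k: Int): Array[Double] = {
  val out = new Array[Double](k)
  var i = 0; var j = 0; var n = 0
  while (n < k) {
    // Take from a when b is exhausted or a's head is at least as large.
    if (j >= b.length || (i < a.length && a(i) >= b(j))) { out(n) = a(i); i += 1 }
    else { out(n) = b(j); j += 1 }
    n += 1
  }
  out
}
```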
Test Result:
479818 users, 13727 products, rank = 10, topK = 20.
3 workers, each with 35 cores. Native BLAS is Intel MKL.
| Block Size | 1000 | 2000 | 4000 | 8000 |
| --- | --- | --- | --- | --- |
| Master method | 40s | 39.4s | 39.5s | 39.1s |
| This method | 26.5s | 25.9s | 26s | 27.1s |

Performance improvement: (OldTime - NewTime) / NewTime = about 50%
How was this patch tested?