-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-33518][ML] Improve performance of ML ALS recommendForAll by GEMV #30468
Conversation
use array agg
dataset: ml-latest/ratings.csv env: Ubuntu 20.04 train (in 2.4.7):
test code:
|
GEMM:
different from previous impl (like #18624) which use However, it is still slower than existing impl. |
GEMV:
Then I switch to GEMV, which brings siginificent speedup. The size of |
Then I try to directly aggregate on arrays (containing topK elements) obtained in each gemv. But it does not bring visible improvement. |
Finally, I found that Guava.Ordering seems much more efficient than
Compared to |
It looks like that: Above tests are done locally, since I do not have a clean cluster for now. friendly ping @srowen @MLnick @mpjlu @jkbradley @mengxr @WeichenXu123, because of your comments in previous prs (#17742, #17845, #18624) |
Kubernetes integration test starting |
Test build #131546 has finished for PR 30468 at commit
|
Kubernetes integration test status success |
b1: (Array[Int], Array[Float]), | ||
b2: (Array[Int], Array[Float])): (Array[Int], Array[Float]) = { | ||
val (ids1, scores1) = b1 | ||
val (ids2, socres2) = b2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: typo in socres2
} | ||
pq.clear() | ||
} ++ { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little confused by this part - why null these out? the flatMap in which they are declared is done here. Maybe I misread
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is to make sure buffer
is marked ready for GC, but it doesn't matter, I will remove it.
@@ -551,6 +563,21 @@ object ALSModel extends MLReadable[ALSModel] { | |||
model | |||
} | |||
} | |||
|
|||
/** select top indices based on values. */ | |||
private[recommendation] class TopSelector(val values: Array[Float]) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this need to be a class? looks like this code is called once. May be less code/indirection
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will consider whether a class is needed.
* Works on rows of the form (ScrId, DstIds, Scores). | ||
* Finds the top `num` DstIds and Scores. | ||
*/ | ||
private[recommendation] class TopKArrayAggregator(num: Int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't love the copy-paste - are the other uses of the class above able to use your new idea?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since I just find that aggregating on array bring little improvement, so I will remove this class
Test build #131585 has finished for PR 30468 at commit
|
back to TopByKeyAggregator
commit 543a41f: scala> val start = System.currentTimeMillis; scala> model.recommendForAllUsers(10).count scala> model.recommendForAllItems(10).count scala> val end = System.currentTimeMillis; scala> end - start I also try using
scala> val start = System.currentTimeMillis; scala> model.recommendForAllUsers(10).count scala> model.recommendForAllItems(10).count scala> val end = System.currentTimeMillis; scala> end - start |
Test build #131623 has finished for PR 30468 at commit
|
retest this please |
Test build #131626 has finished for PR 30468 at commit
|
Test build #131629 has finished for PR 30468 at commit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks OK pending tests
Jenkins retest this please |
Test build #131684 has finished for PR 30468 at commit
|
retest this please |
Test build #131711 has finished for PR 30468 at commit
|
I think it's OK to merge when you're ready |
performance test between
results: n=512, k=5: pq0=1040, pq1=823, ord0=1236, ord1=726 It seems that |
Seems fine |
I will leave this PR open now, and will merge it to 3.2.0 if no more comments. |
I think it's fine to merge - should I? |
@srowen I think it is ready to merge. |
Jenkins retest this please |
Kubernetes integration test starting |
Test build #132979 has finished for PR 30468 at commit
|
Kubernetes integration test status failure |
Merged to master (3.2.0) |
### What changes were proposed in this pull request? 1, update related doc; 2, MatrixFactorizationModel use GEMV; ### Why are the changes needed? see performance gain in #30468 ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? existing testsuites Closes #31279 from zhengruifeng/als_follow_up. Authored-by: Ruifeng Zheng <ruifengz@foxmail.com> Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
### What changes were proposed in this pull request? 1, update related doc; 2, MatrixFactorizationModel use GEMV; ### Why are the changes needed? see performance gain in apache#30468 ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? existing testsuites Closes apache#31279 from zhengruifeng/als_follow_up. Authored-by: Ruifeng Zheng <ruifengz@foxmail.com> Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
What changes were proposed in this pull request?
There were a lot of works on improving ALS's recommendForAll
For now, I found that it maybe futhermore optimized by
1, using GEMV and sharing a pre-allocated buffer per task;
2, using guava.ordering instead of BoundedPriorityQueue;
Why are the changes needed?
In my test, using
f2jBLAS.sgemv
, it is about 2.3X faster than existing impl.Does this PR introduce any user-facing change?
No
How was this patch tested?
existing testsuites