Metal kernel mv_f16_f32_l4 performance issue for long contexts, too many threads #6089
Comments
Thank you for looking into this! Any help with optimizing the Metal kernels would be appreciated. I myself don't even know how to use the profiling tools, so there may well be a lot of room for optimization.
Lines 1353 to 1379 in b5f4ae0
It is disabled by default because I don't know how to decide when the mat-vec kernel is better than the mat-mat kernel (#3524 (comment)). Anyway, if you have suggestions or ideas for improving performance, feel free to open PRs.
This issue was closed because it has been inactive for 14 days since being marked as stale.
Feature Description
This is more of a question: is anyone looking (or has anyone already looked) at the following long-context performance issue in Metal? I did not find anything in the repo history, but I may have missed it.
When profiling long contexts (starting at roughly 25K tokens), I found that block processing latency becomes dominated by kernel_mul_mv_f16_f32_l4 (matrix width small, e.g. 128; height large, e.g. 32768 for a context length slightly below 32K; input vector length 32768). This kernel accounts for ~80% of the total time, and its runtime is dominated by very low utilization of the GPU execution units: the dispatch launches 32768 threadgroups, one per matrix row, and each does only a very small chunk of work. There is no memory pressure.
So the dispatch code

```objc
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
```

and the kernel could be changed so that the number of threadgroups no longer equals the number of matrix rows (ne01); instead, the work should be split into larger chunks.
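To make the proposed re-chunking concrete, here is a rough work-accounting sketch. It assumes 32 threads per threadgroup (one SIMD group) and lanes that stride over each row one float4 at a time; `matvec_grid`, `MatVecGrid` and `rows_per_tg` are hypothetical names for illustration, not ggml code. With `rows_per_tg = 1` it models the current one-threadgroup-per-row dispatch shown above.

```cpp
#include <cstddef>

// Hypothetical work accounting for a mat-vec over an ne01 x ne00 matrix.
// Assumptions (illustrative, not read from ggml): each threadgroup has
// threads_per_tg lanes, lanes stride over a row in float4 steps, and each
// threadgroup covers rows_per_tg consecutive rows.
struct MatVecGrid {
    std::size_t threadgroups;     // grid x-dimension
    std::size_t total_threads;    // threadgroups * threads_per_tg
    std::size_t float4_per_lane;  // float4 dot-product steps of the busiest lane
};

MatVecGrid matvec_grid(std::size_t ne00, std::size_t ne01,
                       std::size_t rows_per_tg, std::size_t threads_per_tg = 32) {
    // ceil(ne01 / rows_per_tg) threadgroups cover all rows.
    const std::size_t tgs = (ne01 + rows_per_tg - 1) / rows_per_tg;
    // Number of float4 chunks in one row (assumes ne00 % 4 == 0).
    const std::size_t n4 = ne00 / 4;
    // float4 steps one lane takes per row, rounded up.
    const std::size_t per_row = (n4 + threads_per_tg - 1) / threads_per_tg;
    return {tgs, tgs * threads_per_tg, per_row * rows_per_tg};
}
```

For a 128-wide, 32768-row matrix, `matvec_grid(128, 32768, 1)` gives 32768 threadgroups in which each lane handles a single float4, while `matvec_grid(128, 32768, 16)` cuts the grid to 2048 threadgroups at the cost of 16 float4 steps per lane. Which trade-off is fastest would have to be settled by profiling.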
By the way, in this f16_f32_l4 matrix-vector kernel, the lines

```metal
for (int r1 = 0; r1 < nrows; ++r1) {
    device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
    ...
}
```

are redundant, because the kernel is always called with nrows == 1. When I replaced them with just

```metal
device const float4 * y4 = (device const float4 *) (src1 + im*nb12);
```

nothing changed (except that the code became a little cleaner).
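The equivalence follows from the offset arithmetic: with nrows == 1 the loop body runs only for r1 = 0, so the r1*nb11 term vanishes. A minimal sketch of that arithmetic; the function names are hypothetical, and only the strides nb11, nb12 and the index im come from the kernel.

```cpp
#include <cstddef>
#include <vector>

// Byte offsets into src1 visited by the original loop, one per r1 in [0, nrows).
std::vector<std::size_t> y4_offsets_loop(int nrows, std::size_t nb11,
                                         std::size_t nb12, std::size_t im) {
    std::vector<std::size_t> offs;
    for (int r1 = 0; r1 < nrows; ++r1) {
        offs.push_back(static_cast<std::size_t>(r1) * nb11 + im * nb12);
    }
    return offs;
}

// Offset used by the simplified code, which drops the r1 loop entirely.
std::size_t y4_offset_simplified(std::size_t nb12, std::size_t im) {
    return im * nb12;
}
```

With nrows == 1 the loop produces exactly one offset, equal to `im*nb12`; only for nrows > 1 would the loop visit additional rows of src1.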
Motivation
The significant performance drop for long-context prompts in Metal is caused by inefficient scheduling of Metal threads. Once this is fixed, I expect throughput to degrade less as the context grows.
For example, this is what I measured for one of the common models running on an M3 Max:

| Context (tokens) | Throughput (t/s) |
|---|---|
| 323 | 7.2 |
| 2248 | 6.3 |
| 5908 | 5.1 |
| 10314 | 4.4 |
| 15112 | 3.65 |
| 20556 | 3.1 |
| 24588 | 3 |
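Converting these numbers to per-token latency makes the trend easier to read. The sketch below (hypothetical helper names) computes 1000 / (t/s) in milliseconds and an ordinary least-squares slope of latency versus context length over the measurements above.

```cpp
#include <vector>

// One measurement: context length in tokens and throughput in tokens/s.
struct Sample {
    double context;
    double tokens_per_s;
};

// Measurements copied from the M3 Max numbers above.
const std::vector<Sample> kMeasured = {
    {323, 7.2}, {2248, 6.3}, {5908, 5.1}, {10314, 4.4},
    {15112, 3.65}, {20556, 3.1}, {24588, 3.0},
};

// Per-token latency in milliseconds for a given throughput.
double ms_per_token(double tokens_per_s) { return 1000.0 / tokens_per_s; }

// Ordinary least-squares slope of latency (ms) against context length (tokens).
double latency_slope_ms_per_ctx_token(const std::vector<Sample>& samples) {
    const double n = static_cast<double>(samples.size());
    double sx = 0.0, sy = 0.0, sxx = 0.0, sxy = 0.0;
    for (const Sample& s : samples) {
        const double y = ms_per_token(s.tokens_per_s);
        sx += s.context;
        sy += y;
        sxx += s.context * s.context;
        sxy += s.context * y;
    }
    return (n * sxy - sx * sy) / (n * sxx - sx * sx);
}
```

On these seven points, latency rises from about 139 ms to about 333 ms per token, and the fitted slope is roughly 0.0083 ms: on average, each additional token of context adds about 8 µs to the per-token latency over the measured range.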
Possible Implementation
Change kernel_mul_mv_f16_f32_l4, or possibly add a kernel_mul_mv_f16_f32_l4_long_vector variant with a different number of threadgroups and threads per threadgroup.
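A CPU model can help check a variant like this for correctness before writing the Metal version. The sketch below is hypothetical and is not the ggml kernel. The f16 matrix is simplified to float, each threadgroup is one 32-lane SIMD group that loops over `rows_per_tg` consecutive rows, and a plain sum stands in for the SIMD reduction. With `rows_per_tg = 1` it reduces to one threadgroup per row.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical "long vector" mapping: each threadgroup (one 32-lane SIMD
// group here) computes rows_per_tg consecutive rows instead of a single row.
constexpr std::size_t kLanes = 32;

// Computes y = A * x for a row-major ne01 x ne00 matrix A and returns the
// number of threadgroups the grid would need.
std::size_t matvec_rows_per_tg(const std::vector<float>& A,
                               const std::vector<float>& x,
                               std::vector<float>& y,
                               std::size_t ne00, std::size_t ne01,
                               std::size_t rows_per_tg) {
    y.assign(ne01, 0.0f);
    const std::size_t n4 = ne00 / 4;  // float4 chunks per row; assumes ne00 % 4 == 0
    const std::size_t n_tg = (ne01 + rows_per_tg - 1) / rows_per_tg;
    for (std::size_t tg = 0; tg < n_tg; ++tg) {
        const std::size_t first = tg * rows_per_tg;
        // The last threadgroup may cover fewer rows than rows_per_tg.
        for (std::size_t row = first; row < first + rows_per_tg && row < ne01; ++row) {
            float lane_sum[kLanes] = {};
            for (std::size_t lane = 0; lane < kLanes; ++lane) {
                // Each lane strides over the row one float4 at a time.
                for (std::size_t i = lane; i < n4; i += kLanes) {
                    for (std::size_t k = 0; k < 4; ++k) {
                        lane_sum[lane] += A[row * ne00 + 4 * i + k] * x[4 * i + k];
                    }
                }
            }
            float total = 0.0f;  // stands in for the SIMD-group reduction
            for (float s : lane_sum) total += s;
            y[row] = total;
        }
    }
    return n_tg;
}
```

Each row is still reduced by a single SIMD group, so only the grid shape and the amount of work per lane change. Those are the two knobs the proposal is about.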