[BUG] Code snippet runs 4x slower since mlx 0.19.0 #1918
Comments
It runs a lot slower on M3 Max from 0.18.1 to 0.19.0 🤔
It must be related to the addition of the fused op #1497, and it must be hitting an unusual case (since you're doing some slicing + transposes on the KV).
The problem is the transpose: `kv[...].transpose(0, 2, 1, 3)` incurs a copy of both the keys and values before calling the kernel, and that slows things down a lot. I'm not sure if we can support transposed keys natively in the op… maybe. Usually the KV cache does not need a transpose because you are updating an already transposed KV cache with the new key. Is your actual computation quite different from that?
On the minimal reproducible example, the bug can be avoided by adding `kv = mx.array(np.array(kv.astype(mx.float32))).astype(mx.bfloat16)` just before the first eval. Though it's quite counter-intuitive, it makes the bug go away, and we get to 6 ms per step with mlx 0.23.1. A similar method was used on the full-size model and achieved a speedup on mlx 0.23.1. But while the full-size model was faster with this trick, it was not as fast as with mlx 0.18.1 (0.18.1 is still ~20% faster than 0.23.1), so it seems there are other things at play here. I could once again try to reduce the whole model into a minimal reproducible example with the numpy trick, but it's quite time-consuming, so maybe I could open a new issue once this one is fixed?
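For reference, the workaround line quoted above, shown in context (the `kv` shape here is a stand-in for illustration, not taken from the issue):

```python
import numpy as np
import mlx.core as mx

# Stand-in for the kv tensor from the minimal example (shape is an assumption).
kv = mx.random.normal((1, 1024, 2, 16, 64)).astype(mx.bfloat16)

# The workaround from the comment above: round-trip through NumPy (via float32,
# since NumPy has no bfloat16), which materializes a contiguous copy of kv.
kv = mx.array(np.array(kv.astype(mx.float32))).astype(mx.bfloat16)
```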
It makes sense, since it results in the … What would be helpful is if you could share more about how your actual computation works (not the toy example), so that we make sure we optimize the right thing here and find an optimal solution. In your actual computation, do you use a growing KV cache? If so, the simplest fix would be to store it already transposed. If not, what are you doing instead? Is it more like an encoder/decoder model with cross attention to the …
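For illustration, a growing KV cache stored in the already-transposed `(B, H, L, D)` layout suggested above could look roughly like the sketch below; the class name, shapes, and update logic are assumptions, not the issue author's actual code.

```python
import mlx.core as mx

class KVCache:
    """Keeps keys and values in (B, H, L, D) layout so attention never needs a transpose."""

    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, k, v):
        # k and v arrive already transposed as (B, H, new_tokens, D);
        # append them along the sequence axis (axis 2).
        if self.keys is None:
            self.keys, self.values = k, v
        else:
            self.keys = mx.concatenate([self.keys, k], axis=2)
            self.values = mx.concatenate([self.values, v], axis=2)
        return self.keys, self.values

# Usage: per decode step, project the new token to k, v of shape (B, H, 1, D),
# call cache.update(k, v), and feed the returned tensors straight into
# mx.fast.scaled_dot_product_attention without any transpose or copy.
```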
Describe the bug
When developing our model, we noticed that changing the MLX version had a huge impact on performance; notably, mlx 0.18.1 was faster than all versions released afterwards.
To Reproduce
Here is a minimal reproducible example:
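The script itself did not survive in this copy of the issue. A rough sketch of what such a benchmark might look like, based on the slicing + transpose pattern and the per-step timing discussed in the comments, is below; all shapes, sizes, and the timing loop are assumptions.

```python
# Hypothetical reconstruction of the benchmark described in this issue; the
# original bench.py was not included, so shapes and the step loop are assumed.
import time
import mlx.core as mx

B, H, L, D = 1, 16, 1024, 64  # batch, heads, sequence length, head dim (assumed)

# KV stored fused and un-transposed, so attention needs a slice + transpose.
kv = mx.random.normal((B, L, 2, H, D)).astype(mx.bfloat16)
q = mx.random.normal((B, H, 1, D)).astype(mx.bfloat16)

def step():
    # The slice + transpose below is what triggers the copy mentioned in the thread.
    k = kv[:, :, 0].transpose(0, 2, 1, 3)  # (B, H, L, D)
    v = kv[:, :, 1].transpose(0, 2, 1, 3)
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5)
    mx.eval(out)

# Warm-up, then report the average time per step.
step()
n = 100
tic = time.perf_counter()
for _ in range(n):
    step()
print(f"{(time.perf_counter() - tic) / n * 1e3:.2f} ms per step")
```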
Expected behavior
When a new MLX version is released, it should be as fast as, or faster than, previous versions for any code already written.
Desktop (please complete the following information):
Model Name: MacBook Air
Model Identifier: Mac15,12
Chip: Apple M3
Total Number of Cores: 8 (4 performance and 4 efficiency)
Memory: 16 GB
System Firmware Version: 10151.81.1
OS Loader Version: 10151.81.1
ProductName: macOS
ProductVersion: 14.3
BuildVersion: 23D2057
As previously mentioned, the MLX version influences the average time per step, with more recent versions performing worse. I run the script against a given version with:

`uv add --script bench.py mlx==0.23.1 && uv run bench.py`

Additional information
I also reproduced this slowdown on a Mac mini (Mac16,11) with 24 GB of RAM and 12 cores. The slowdown there is between 4x and 7x, depending on the version.