metal : optimize FA kernels #10171
Conversation
This PR should be gucci now.

The performance increase looks about the same with M3 Max.

Note that to test batch sizes larger than the default 2048 with `llama-bench`, you have to also pass a larger `-b` value (as in the commands below).
Thanks, I forgot about that. As a data point, running some tests as a function of the micro-batch size (`-ub`):

```
./llama-bench -m ./models/llama-3.2-3b-instruct/ggml-model-f16.gguf -fa 1 -p 1024,2048,4096,8192,16384 -b 16384 -ub 512,1024,2048,4096,8192 -n 0
```

build: 59792ff (4057)

```
./llama-bench -m ./models/qwen2.5-7b-coder/ggml-model-q8_0.gguf -fa 1 -p 1024,2048,4096,8192,16384 -b 16384 -ub 512,1024,2048,4096,8192 -n 0
```

build: 1888c1f (4057)

My guess is that the logic for skipping the computation of attention blocks when the mask is full of -INF in that block is now more efficient. I'm wondering if this optimization could be viable for the CUDA FA as well.
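For context on the block skipping mentioned above, here is a rough sketch in plain C++ (not the actual Metal kernel; the names and the block width are made up) of the general idea: before processing a KV block, check whether the mask for that block is entirely -INF and, if so, skip the `Q*K^T`/softmax/`V` work for it.

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical block width; the real kernels tile the KV dimension differently.
constexpr std::size_t BLOCK = 32;

// Returns true if every mask entry in [kv0, kv0 + BLOCK) for this query row
// is -INF, i.e. the whole block is masked out and can be skipped.
static bool block_fully_masked(const float * mask_row, std::size_t kv0, std::size_t n_kv) {
    for (std::size_t j = kv0; j < kv0 + BLOCK && j < n_kv; ++j) {
        if (mask_row[j] != -INFINITY) {
            return false; // at least one attendable position -> must compute
        }
    }
    return true;
}

// Caller sketch:
//
//   for (std::size_t kv0 = 0; kv0 < n_kv; kv0 += BLOCK) {
//       if (block_fully_masked(mask_row, kv0, n_kv)) {
//           continue; // skip Q*K^T*scale, softmax and V accumulation for this block
//       }
//       // ... process the block as usual ...
//   }
```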
1888c1f
to
bc143ec
Compare
* ggml : add ggml_flash_attn_ext_get_prec (usage sketch below)
* metal : use F16 precision in FA kernels ggml-ci
* metal : minor clean-up
* metal : compile-guard bf16 FA kernels ggml-ci
* build : remove obsolete compile flag [no ci]
* metal : prevent int overflows [no ci]
* cuda : disable BF16 FA ggml-ci
* metal : fix BF16 requirement for FA kernels ggml-ci
* make : clean-up [no ci]
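As a reference for the first commit in the list, a minimal usage sketch of the precision query, assuming the getter mirrors the existing `ggml_flash_attn_ext_set_prec` (tensor in, `ggml_prec` out); `node` is a `GGML_OP_FLASH_ATTN_EXT` tensor built elsewhere:

```cpp
#include "ggml.h"

// Sketch: decide whether the faster F16 kernels may be used for a
// flash-attention node. GGML_PREC_F32 means the caller explicitly asked
// for F32 math. (Getter signature assumed here, not verified.)
static bool fa_can_use_f16(const struct ggml_tensor * node) {
    return ggml_flash_attn_ext_get_prec(node) == GGML_PREC_DEFAULT;
}
```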
tgt #10149
rel #8439

Various optimizations for the FA kernels.

The performance should be noticeably better at larger contexts. The kernels continue to use F32 accumulators for the `Q*K*scale` computation, so I hope there are no floating-point range issues, though some extra testing won't hurt.

The original idea of using full `BF16` math in the FA kernels did not produce satisfactory results. I think that `bfloat` performance is not great on Metal yet.
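To make the precision choice concrete, a minimal sketch in plain C++ (not the Metal shader; `half_t` and the function name are stand-ins): the Q and K data stay in F16, but the `Q*K*scale` dot product is accumulated in F32.

```cpp
#include <cstddef>

using half_t = _Float16; // assumes a compiler with _Float16 support (e.g. clang on Apple silicon)

// One attention score: scale * dot(q, k), with F16 loads and an F32 accumulator.
static float attn_score(const half_t * q, const half_t * k, std::size_t head_dim, float scale) {
    float acc = 0.0f; // F32 accumulator, as described above
    for (std::size_t i = 0; i < head_dim; ++i) {
        acc += (float) q[i] * (float) k[i];
    }
    return acc * scale;
}
```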
Here are some benches, using `llama-batched-bench` to show TG speed after large prompts (`S_TG` column): `master` vs `gg/metal-fa-f16` on M1 Pro.

TODO: