
Bug: Flash attention reduces Vulkan performance by ~50% #9572

Closed
tempstudio opened this issue Sep 21, 2024 · 5 comments
Labels: bug-unconfirmed, medium severity (used to report medium severity bugs in llama.cpp, e.g. malfunctioning features that are still usable)

Comments

@tempstudio

What happened?

Enabling flash attention reduces performance on Vulkan by far more than expected.
Even allowing for variation between hardware, a ~50% drop in generation speed looks like a bug.

Hardware is AMD RX 6800 XT

Name and Version

version: 3772 (23e0d70)
built with MSVC 19.29.30154.0 for x64

What operating system are you seeing the problem on?

Windows

Relevant log output

llama-b3772-bin-win-vulkan-x64> ./llama-cli.exe -m '..\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf' -p "to be or" -n 600 -c 4096 -ngl 99
Performance without flash attention:
llama_perf_sampler_print:    sampling time =      48.42 ms /   604 runs   (    0.08 ms per token, 12474.70 tokens per second)
llama_perf_context_print:        load time =   13033.53 ms
llama_perf_context_print: prompt eval time =     183.59 ms /     4 tokens (   45.90 ms per token,    21.79 tokens per second)
llama_perf_context_print:        eval time =    9458.98 ms /   599 runs   (   15.79 ms per token,    63.33 tokens per second)
llama_perf_context_print:       total time =    9765.68 ms /   603 tokens

llama-b3772-bin-win-vulkan-x64> ./llama-cli.exe -m '..\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf' -p "to be or" -n 600 -c 4096 -ngl 99 --flash-attn
Performance with flash attention:
llama_perf_sampler_print:    sampling time =      48.48 ms /   604 runs   (    0.08 ms per token, 12458.75 tokens per second)
llama_perf_context_print:        load time =    2709.09 ms
llama_perf_context_print: prompt eval time =     194.77 ms /     4 tokens (   48.69 ms per token,    20.54 tokens per second)
llama_perf_context_print:        eval time =   18321.90 ms /   599 runs   (   30.59 ms per token,    32.69 tokens per second)
llama_perf_context_print:       total time =   18617.86 ms /   603 tokens
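As a side note, both configurations can be compared in a single run with llama-bench from the same release zip (a sketch, assuming the usual llama-bench flags: -fa 0,1 benchmarks with flash attention off and on, -p and -n set prompt and generation token counts):

llama-b3772-bin-win-vulkan-x64> ./llama-bench.exe -m '..\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf' -p 512 -n 128 -ngl 99 -fa 0,1

llama-bench prints one result row per -fa value with tokens-per-second figures, which makes the regression easier to quantify than two separate llama-cli runs.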
@tempstudio added the bug-unconfirmed and medium severity labels on Sep 21, 2024
@slaren (Collaborator) commented Sep 21, 2024

The Vulkan backend does not support flash attention. If you enable it, the flash attention operation runs on the CPU instead, and a performance drop is expected.

@tempstudio (Author)
Thank you for the quick reply! Will close this.

@Mushoz commented Dec 24, 2024

Found this issue as I was running into the same problem. @slaren, is it possible to implement flash attention in the Vulkan backend, and if so, are there any plans to do so?

@slaren (Collaborator) commented Dec 24, 2024

See #10206

@Mushoz commented Dec 24, 2024

I have seen that merged PR, but from my understanding it is Nvidia-only, via the VK_NV_cooperative_matrix2 extension. I was curious whether a more generic implementation is possible and whether there are any plans for one. Would a VK_KHR_cooperative_matrix implementation be possible, for example?
