Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement Flash Attention 2 for webgpu EP (#23576)
### Description This change implements FlashAttention 2 for the webgpu EP for the MHA operator. Numbers from Alderlake device show a 2.2x speed up for prefill, which considering that Attention is 50% of prefill phase (other 50% being MatMul) implies 4x speed up for Attention with this implementation. This is inline with the expected perf gain of 2-4x with FlashAttention over regular attention. ``` Baseline PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 9.54997e+06 <<<<< avg (tokens/s): 104.817 p50 (us): 9.49218e+06 stddev (us): 251442 n: 5 * 1001 token(s) ------ With FlashAttention 2 PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000 Batch size: 1, prompt tokens: 1001, tokens to generate: 128 Prompt processing (time to first token): avg (us): 4.27937e+06 <<<<< avg (tokens/s): 233.913 p50 (us): 4.27687e+06 stddev (us): 5344.1 n: 5 * 1001 token(s) ``` ### Motivation and Context On integrated GPUs memory bandwidth is premium, Flash attention makes softmax computation (and therefore output attention vector computation) a running operation instead of maintaining full QKt attention scores in memory. As a result, we see significant improvements in prefill speed - 200% speed up measured here. This change uses techniques from co-operative matrix multiply to use registers from a subgroup for fast in register matrix multiply. Without the co-operative matrix multiply technique ALD showed about 6.0s prefill time. Tested on ALD/TGL intel integrated and Nvidia 4070. ### Future Work - Fine tuning and profiling optimizations. - Current implement is for prefill only, a generation phase optimized FA2 implementation is possible, however attention is a tiny part of the generation phase.
- Loading branch information