Implement Flash Attention 2 for webgpu EP (#23576)
### Description
This change implements FlashAttention-2 for the MHA operator of the WebGPU EP.

Numbers from an Alder Lake device show a 2.2x speedup for prefill. Given that attention accounts for about 50% of the prefill phase (the other 50% being MatMul), this implies roughly a 4x speedup for attention itself, in line with the 2-4x gain expected from FlashAttention over regular attention.

```
Baseline
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000
Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       9.54997e+06   <<<<<
        avg (tokens/s): 104.817
        p50 (us):       9.49218e+06
        stddev (us):    251442
        n:              5 * 1001 token(s)
------
With FlashAttention 2
PS C:\onnxruntime> C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web\ -l 1000
Batch size: 1, prompt tokens: 1001, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       4.27937e+06     <<<<<
        avg (tokens/s): 233.913
        p50 (us):       4.27687e+06
        stddev (us):    5344.1
        n:              5 * 1001 token(s)
```

### Motivation and Context

On integrated GPUs, memory bandwidth is at a premium. FlashAttention turns the softmax computation (and therefore the output attention-vector computation) into a running, tile-by-tile operation instead of materializing the full QK^T attention-score matrix in memory. As a result, we see a significant improvement in prefill speed: the 2.2x speedup measured here.
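
To illustrate the idea, here is a minimal NumPy sketch of the online-softmax recurrence that FlashAttention relies on, not the WGSL shader in this change. The function name `flash_attention_rowwise`, the tile size, and the single-query-row simplification are all illustrative assumptions; the point is that only running statistics (max, denominator, output accumulator) are kept, never the full score row.

```python
import numpy as np

def flash_attention_rowwise(q, K, V, tile=64):
    """One query row; K, V are (seq, d). Processes K/V in tiles,
    keeping a running max, running softmax denominator, and running
    output so the full q.K^T score row is never materialized."""
    d = q.shape[0]
    scale = 1.0 / np.sqrt(d)
    m = -np.inf           # running max of scores seen so far
    l = 0.0               # running softmax denominator
    o = np.zeros(d)       # running (unnormalized) output accumulator
    for s in range(0, K.shape[0], tile):
        k_t, v_t = K[s:s + tile], V[s:s + tile]
        scores = (k_t @ q) * scale       # partial score row for this tile
        m_new = max(m, scores.max())
        alpha = np.exp(m - m_new)        # rescale previous accumulators
        p = np.exp(scores - m_new)
        l = l * alpha + p.sum()
        o = o * alpha + p @ v_t
        m = m_new
    return o / l
```

Run per query row, this matches `softmax(q.K^T / sqrt(d)) @ V` up to floating-point error; in the shader the same recurrence runs per tile of queries rather than per row.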

This change also borrows techniques from cooperative matrix multiplication, using a subgroup's registers for fast in-register matrix multiply. Without the cooperative-matmul technique, the Alder Lake device showed about 6.0 s of prefill time.
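
One way to picture the cooperative-matmul dataflow, as a hedged Python model rather than the actual shader: each lane of a subgroup keeps one row of A in registers, and rows of B are broadcast lane by lane (subgroupBroadcast-style) so every lane accumulates its output row without round-tripping through shared memory. `SUBGROUP_SIZE = 16` and the `coop_matmul` helper are assumptions for illustration only.

```python
import numpy as np

SUBGROUP_SIZE = 16  # assumed lane count; varies by hardware

def coop_matmul(A, B):
    """Conceptual model of subgroup cooperative matmul for one
    SUBGROUP_SIZE x SUBGROUP_SIZE tile. Lane `lane` owns row `lane`
    of A and of the output C; B rows are shared via broadcast."""
    n = SUBGROUP_SIZE
    assert A.shape == (n, n) and B.shape == (n, n)
    C = np.zeros((n, n))
    for lane in range(n):           # each lane computes its own output row
        acc = np.zeros(n)
        for src in range(n):        # models subgroupBroadcast of B[src, :]
            acc += A[lane, src] * B[src]
        C[lane] = acc
    return C
```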

Tested on Alder Lake and Tiger Lake Intel integrated GPUs and on an NVIDIA RTX 4070.

### Future Work
- Further fine-tuning and profiling optimizations.
- The current implementation covers prefill only. A generation-phase-optimized FA2 implementation is possible, but attention is only a tiny part of the generation phase.
sushraja-msft authored Feb 7, 2025
1 parent a6ea57b commit 82840f6
Showing 3 changed files with 525 additions and 0 deletions.