Use FP16-native PA after support in ROCm/aiter#97
mawong-amd committed Feb 7, 2025
1 parent be2e940 commit 7428ffd
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions vllm/attention/ops/paged_attn_ater.py
@@ -156,14 +156,8 @@ def forward_decode(
         elif "fp8" in kv_cache_dtype:
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
             value_cache = value_cache.view(torch.float8_e4m3fnuz)
-        else:
-            key_cache = key_cache.view(torch.int8)
-            value_cache = value_cache.view(torch.int8)
-        dtype = out.dtype
-        aiter.pa_fwd_asm(query.to(torch.bfloat16), key_cache, value_cache, block_tables, seq_lens, max_num_blocks_per_seq, k_scale, v_scale, out)
-        if dtype == torch.float16:
-            # aiter.pa_fwd_asm only supports bf16 output for now
-            out.copy_(out.view(torch.bfloat16).to(torch.float16))
+        aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, seq_lens,
+                         max_num_blocks_per_seq, k_scale, v_scale, out)
         return out

     @staticmethod
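In short: before ROCm/aiter#97, aiter.pa_fwd_asm could only write bfloat16 output, so fp16 queries were upcast to bf16 and the result copied back to fp16 afterwards. With fp16 supported natively, that round-trip goes away. A minimal before/after sketch, reusing the variable names from forward_decode (query, key_cache, value_cache, block_tables, seq_lens, max_num_blocks_per_seq, k_scale, v_scale, and out are assumed to be prepared by the surrounding vLLM code, which is not shown here):

    import torch
    import aiter

    # Old path: the kernel emitted bf16 only, forcing a round-trip for fp16.
    dtype = out.dtype
    aiter.pa_fwd_asm(query.to(torch.bfloat16), key_cache, value_cache,
                     block_tables, seq_lens, max_num_blocks_per_seq,
                     k_scale, v_scale, out)
    if dtype == torch.float16:
        out.copy_(out.view(torch.bfloat16).to(torch.float16))

    # New path: the kernel accepts fp16 directly, so the query and output
    # tensors keep their original dtype with no extra copy.
    aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, seq_lens,
                     max_num_blocks_per_seq, k_scale, v_scale, out)

The commit also drops the fallback else branch that reinterpreted the KV cache tensors as int8 when kv_cache_dtype was not an fp8 variant.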
