From 7428ffd9324c1cc7a6351a51c66c1ff08697377a Mon Sep 17 00:00:00 2001
From: Matthew Wong
Date: Fri, 7 Feb 2025 07:39:47 +0000
Subject: [PATCH] Use FP16-native PA after support in
 https://github.com/ROCm/aiter/pull/97

---
 vllm/attention/ops/paged_attn_ater.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/vllm/attention/ops/paged_attn_ater.py b/vllm/attention/ops/paged_attn_ater.py
index 3dbf3d5424216..1ac0a2738d8cf 100644
--- a/vllm/attention/ops/paged_attn_ater.py
+++ b/vllm/attention/ops/paged_attn_ater.py
@@ -156,14 +156,8 @@ def forward_decode(
         elif "fp8" in kv_cache_dtype:
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
             value_cache = value_cache.view(torch.float8_e4m3fnuz)
-        else:
-            key_cache = key_cache.view(torch.int8)
-            value_cache = value_cache.view(torch.int8)
-        dtype=out.dtype
-        aiter.pa_fwd_asm(query.to(torch.bfloat16), key_cache, value_cache, block_tables, seq_lens, max_num_blocks_per_seq, k_scale, v_scale,out)
-        if dtype==torch.float16:
-            # aiter.pa_fwd_as only support bf16 output for now
-            out.copy_(out.view(torch.bfloat16).to(torch.float16))
+        aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, seq_lens,
+                         max_num_blocks_per_seq, k_scale, v_scale, out)
         return out
 
     @staticmethod
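
Note for reviewers: the sketch below restates, for context, the fp16
workaround this patch deletes. It is illustrative only; the standalone
wrapper name is made up here, and the pa_fwd_asm argument order is taken
verbatim from the diff above rather than from aiter's documentation.

    import torch
    import aiter

    def _decode_with_bf16_roundtrip(query, key_cache, value_cache,
                                    block_tables, seq_lens,
                                    max_num_blocks_per_seq,
                                    k_scale, v_scale, out):
        # Pre-ROCm/aiter#97 behavior: the assembly paged-attention kernel
        # emitted bf16 output only, so fp16 queries were cast up front...
        dtype = out.dtype
        aiter.pa_fwd_asm(query.to(torch.bfloat16), key_cache, value_cache,
                         block_tables, seq_lens, max_num_blocks_per_seq,
                         k_scale, v_scale, out)
        if dtype == torch.float16:
            # ...and the fp16 result was recovered by reinterpreting the
            # bf16 bits in the output buffer, costing an extra cast and
            # an extra copy per decode step.
            out.copy_(out.view(torch.bfloat16).to(torch.float16))
        return out

With fp16-native support in the kernel, the whole round trip collapses
into the single direct call kept by this patch:

    aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, seq_lens,
                     max_num_blocks_per_seq, k_scale, v_scale, out)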