[Bugfix] Massage MLA's usage of flash attn for RoCM (vllm-project#13310)
tlrmchlsmth authored and kerthcet committed Feb 21, 2025
1 parent a1f6419 commit 61d1382
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions vllm/attention/backends/mla/utils.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple
@@ -183,6 +184,15 @@ def __init__(
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
 
+        # Handle the differences between the flash_attn_varlen from flash_attn
+        # and the one from vllm_flash_attn. The former is used on RoCM and the
+        # latter has an additional parameter to control FA2 vs FA3
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        if self.vllm_flash_attn_version is not None:
+            self.flash_attn_varlen_func = \
+                functools.partial(flash_attn_varlen_func,
+                                  fa_version=self.vllm_flash_attn_version)
+
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
@@ -487,7 +497,7 @@ def _forward_prefill_flash(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        attn_output = flash_attn_varlen_func(
+        attn_output = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v_padded,
@@ -497,7 +507,6 @@ def _forward_prefill_flash(
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
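The comment added in __init__ captures the whole fix: choose the varlen attention entry point once, instead of passing fa_version at every call site, because the upstream flash_attn function (used on ROCm) does not accept that parameter while the vllm_flash_attn one does. Below is a minimal standalone sketch of that dispatch pattern; the helper name pick_varlen_func is hypothetical and not part of the vLLM source.

import functools

def pick_varlen_func(flash_attn_varlen_func, vllm_flash_attn_version):
    # Hypothetical helper (illustration only): return a callable with a
    # uniform signature for both attention backends.
    if vllm_flash_attn_version is not None:
        # vllm_flash_attn path: pre-bind the FA2/FA3 selector so call
        # sites never pass fa_version themselves.
        return functools.partial(flash_attn_varlen_func,
                                 fa_version=vllm_flash_attn_version)
    # ROCm / upstream flash_attn path: the function has no fa_version
    # parameter, so it is used as-is.
    return flash_attn_varlen_func

Call sites then invoke the returned callable with identical keyword arguments on both backends, which is what the _forward_prefill_flash change above relies on when it drops the explicit fa_version argument.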
