review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 30, 2025
1 parent 8bdc14a commit 09d814c
Showing 4 changed files with 24 additions and 7 deletions.
16 changes: 16 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -276,3 +276,19 @@ def forward(
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+class MLAAttentionImpl(AttentionImpl):
+
+    @abstractmethod
+    def forward(
+        self,
+        layer: AttentionLayer,
+        hidden_states_or_cq: torch.Tensor,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: T,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
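For context, the new MLAAttentionImpl interface swaps AttentionImpl's usual (query, key, value) inputs for MLA-specific tensors: the query input (or its compressed latent), the normalized compressed KV latent kv_c_normed, and the decoupled rotary key component k_pe. A minimal sketch of a concrete subclass, assuming only the signature above (ToyMLAImpl and its comments are hypothetical, not part of this commit):

```python
from typing import Optional

import torch

from vllm.attention.backends.abstract import (AttentionLayer,
                                              MLAAttentionImpl)


class ToyMLAImpl(MLAAttentionImpl):
    """Hypothetical subclass illustrating the MLA forward contract."""

    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,  # query input (or compressed c_q)
        kv_c_normed: torch.Tensor,          # normalized compressed KV latent
        k_pe: torch.Tensor,                 # decoupled rotary key component
        kv_cache: torch.Tensor,
        attn_metadata,                      # backend-specific metadata (T)
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # A real backend would write (kv_c_normed, k_pe) into kv_cache
        # and run MLA attention; this stub only shows the signature.
        raise NotImplementedError
```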
9 changes: 5 additions & 4 deletions vllm/attention/backends/mla/utils.py
@@ -6,8 +6,9 @@

 from vllm import _custom_ops as ops
 from vllm import envs
-from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
-                                              AttentionMetadata)
+from vllm.attention.backends.abstract import (AttentionLayer,
+                                              AttentionMetadata,
+                                              MLAAttentionImpl)
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -18,11 +19,11 @@

 @dataclass(kw_only=True)
 class MLAMetadataCommon(AttentionMetadata):
     # Input positions for rotrary embeddings since for MLA the rotarty
-    # position encoding
+    # position embeddings are applied inside the attention backend
     input_positions: torch.Tensor


-class MLAImplCommon(AttentionImpl):
+class MLACommonImpl(MLAAttentionImpl):
     """
     Common class for implementing repeated parts
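The comment fix above also documents the design: unlike standard attention, MLA applies rotary position embeddings inside the backend, to the decoupled q_pe/k_pe components, so the positions travel with the attention metadata instead of being consumed in the model layer. A rough sketch of the idea, assuming a RotaryEmbedding-style callable (apply_rope_in_backend and its wiring are illustrative, not this commit's API):

```python
import torch


def apply_rope_in_backend(rotary_emb, attn_metadata, q_pe, k_pe):
    # Illustrative only: the positions stored on MLAMetadataCommon are
    # consumed here, inside the backend, rather than in the model layer.
    # rotary_emb is assumed to behave like vLLM's RotaryEmbedding:
    # (positions, query, key) -> (query, key) with RoPE applied.
    positions = attn_metadata.input_positions
    q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
    return q_pe, k_pe
```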
4 changes: 2 additions & 2 deletions vllm/attention/backends/triton_mla.py
@@ -20,7 +20,7 @@
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
-from vllm.attention.backends.mla.utils import MLAImplCommon, MLAMetadataCommon
+from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
@@ -585,7 +585,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         )


-class TritonMLAImpl(MLAImplCommon):
+class TritonMLAImpl(MLACommonImpl):

     def __init__(
         self,
2 changes: 1 addition & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -331,7 +331,7 @@ class DeepseekV2MLAAttention(nn.Module):
     Main reference: DeepseekV2 paper, and FlashInfer Implementation
     (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
-    For more info see MLAImplCommon in: vllm/attention/backends/mla/utils.py
+    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
     """

     def __init__(
