more cleanups
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 30, 2025
1 parent c34e5ca commit f2cac91
Showing 3 changed files with 8 additions and 36 deletions.
vllm/attention/backends/mla/utils.py (6 additions, 20 deletions)
@@ -7,7 +7,7 @@
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata)
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -159,21 +159,6 @@ def __init__(
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj

-        unsupported_features = [
-            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
-        ]
-        if any(unsupported_features):
-            raise NotImplementedError(
-                "FlashInferMLAImpl does not support one of the following: "
-                "alibi_slopes, sliding_window, blocksparse_params, "
-                "logits_soft_cap")
-
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashInferMLAImpl")
-
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             return self.o_proj_absored(
@@ -225,7 +210,7 @@ def process_weights_after_loading(self):

         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             #
-            # Perform matrix-absorbtion following
+            # Perform matrix-absorption following
             # https://github.com/flashinfer-ai/flashinfer/pull/551
             # for decode, as a result we end up with absorbed weights for decode
             # and another copy of raw weights for prefill.
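For readers unfamiliar with the trick referenced above: "matrix absorption" folds the per-head value up-projection into the output projection once at weight-loading time, so the decode path can operate directly on the compressed KV latents. Below is a minimal, self-contained sketch of that idea; the names (W_UV, W_O, latents) and shapes are illustrative assumptions, not the actual vLLM tensors.

import torch

torch.manual_seed(0)
num_heads, v_head_dim, kv_lora_rank, hidden_size = 4, 8, 16, 32

# Hypothetical weights: per-head value up-projection and the output projection.
W_UV = torch.randn(num_heads, kv_lora_rank, v_head_dim)
W_O = torch.randn(num_heads * v_head_dim, hidden_size)

# Absorb W_UV into W_O once, at load time: (num_heads, kv_lora_rank, hidden_size).
W_O_heads = W_O.view(num_heads, v_head_dim, hidden_size)
W_absorbed = torch.einsum("hrv,hvd->hrd", W_UV, W_O_heads)

# Per-token, per-head latents, e.g. produced by attention over the compressed cache.
latents = torch.randn(3, num_heads, kv_lora_rank)

# Unabsorbed path: up-project to v, concatenate heads, then output-project.
v = torch.einsum("thr,hrv->thv", latents, W_UV)
out_ref = v.reshape(3, -1) @ W_O

# Absorbed path: one fused projection straight from the latent space.
out_fused = torch.einsum("thr,hrd->td", latents, W_absorbed)

assert torch.allclose(out_ref, out_fused, rtol=1e-3, atol=1e-3)

The key up-projection can be absorbed analogously into the query side, which is why the decode path can run entirely against the compressed cache while prefill keeps a copy of the raw weights.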
@@ -292,14 +277,14 @@ def forward(
     ) -> torch.Tensor:
         if output is not None:
             raise NotImplementedError(
-                "output is not yet supported for TritonMLAImpl")
+                "output is not yet supported for MLAImplBase")

         is_decode = attn_metadata.decode_metadata is not None
         is_prefill = attn_metadata.prefill_metadata is not None

         if (is_decode and is_prefill):
             raise NotImplementedError(
-                "chunked prefill is not supported for MLAImplBase")
+                "chunked prefill is not supported for MLAImplBase")

         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
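The last two context lines cover a small but easy-to-miss detail: the rope part of the key arrives without a head dimension, so it is unsqueezed back before the rotary embedding is applied. A toy illustration of the shape change (sizes are made up):

import torch

num_tokens, qk_rope_head_dim = 5, 64

# k_pe arrives with the head dim squeezed away: one rope component per token.
k_pe = torch.randn(num_tokens, qk_rope_head_dim)

# Restore the head dim so it lines up with (num_tokens, num_heads, head_dim)
# tensors in the rotary-embedding step.
k_pe = k_pe.unsqueeze(1)
assert k_pe.shape == (num_tokens, 1, qk_rope_head_dim)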
@@ -355,7 +340,8 @@ def _forward_prefill_flash(

         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        # For MLA the v head dim is smaller than the
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)

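A small standalone sketch of the padding described in the new comment (dimensions are illustrative, not the model's actual sizes): flash-style kernels expect q, k, and v to share a head dim, so v is zero-padded up to the qk head dim and the padded columns are dropped from the output afterwards.

import torch
import torch.nn.functional as F

num_tokens, num_heads = 4, 2
qk_head_dim, v_head_dim = 192, 128   # in MLA the qk head dim exceeds the v head dim

q = torch.randn(num_tokens, num_heads, qk_head_dim)
v = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad only the last dimension of v up to q's head dim.
v_padded = F.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
assert v_padded.shape == (num_tokens, num_heads, qk_head_dim)
assert torch.equal(v_padded[..., :v_head_dim], v)

# Because the padded columns of v are zero, the corresponding columns of the
# attention output are zero as well, so they can simply be sliced off afterwards.
attn_out_padded = torch.randn(num_tokens, num_heads, qk_head_dim)  # stand-in output
attn_out = attn_out_padded[..., :v_head_dim]
assert attn_out.shape == (num_tokens, num_heads, v_head_dim)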
vllm/attention/backends/triton_mla.py (1 addition, 1 deletion)
@@ -653,7 +653,7 @@ def _forward_decode(
                        dtype=q.dtype,
                        device=q.device)

-        # TODO(lucas) Allocate ahead of prefill
+        # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
             (
                 B,
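The corrected TODO hints at allocating this scratch buffer once instead of on every decode call. A rough sketch of what that could look like is below; the buffer name, shape, and the split dimension are assumptions for illustration, not the actual kernel interface.

import torch

# Assumed worst-case sizes; not taken from vLLM.
MAX_BATCH_SIZE, NUM_HEADS, NUM_KV_SPLITS, LOGITS_DIM = 256, 16, 4, 513

# Allocate the decode-attention logits workspace once, ahead of time.
_attn_logits_buf = torch.empty(
    (MAX_BATCH_SIZE, NUM_HEADS, NUM_KV_SPLITS, LOGITS_DIM),
    dtype=torch.float32,
)

def get_attn_logits(batch_size: int) -> torch.Tensor:
    # Hand out a view sized to the current batch instead of calling
    # torch.empty inside every forward pass.
    assert batch_size <= MAX_BATCH_SIZE
    return _attn_logits_buf[:batch_size]

logits = get_attn_logits(8)
assert logits.shape == (8, NUM_HEADS, NUM_KV_SPLITS, LOGITS_DIM)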
vllm/config.py (1 addition, 15 deletions)
@@ -75,20 +75,6 @@
                           PretrainedConfig]]


-def _is_flashinfer_available() -> bool:
-    """Check if FlashInfer is available.
-
-    Returns:
-        bool: True if FlashInfer is installed and available, False otherwise.
-    """
-    try:
-        from flashinfer import (  # noqa:F401
-            BatchDecodeMlaWithPagedKVCacheWrapper)
-        return True
-    except ImportError:
-        return False
-
-
 class SupportsHash(Protocol):

     def compute_hash(self) -> str:
@@ -832,7 +818,7 @@ def get_total_num_kv_heads(self) -> int:
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU."""
         if self.should_use_mla:
-            # TODO(simon): feature flag MLA
+            # When using MLA during decode it becomes MQA
            return 1

         total_num_kv_heads = self.get_total_num_kv_heads()
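The new comment is worth unpacking: MLA caches a single compressed latent per token that all query heads share at decode time, which is exactly the MQA pattern, so the per-GPU KV-head count is reported as 1 regardless of tensor parallelism. A small illustrative stand-in for how such a helper typically behaves (the non-MLA branch is an assumption about the usual total-heads over tensor-parallel-size split, not a copy of vllm/config.py):

def get_num_kv_heads(total_num_kv_heads: int,
                     tensor_parallel_size: int,
                     uses_mla: bool) -> int:
    """Illustrative stand-in, not the vLLM implementation."""
    if uses_mla:
        # MLA keeps one compressed KV latent per token, shared by every query
        # head during decode, so each GPU effectively has a single KV head.
        return 1
    # Typical grouped/multi-query split: divide the KV heads across tensor
    # parallel ranks, but never report fewer than one per GPU.
    return max(1, total_num_kv_heads // tensor_parallel_size)


assert get_num_kv_heads(8, tensor_parallel_size=4, uses_mla=False) == 2
assert get_num_kv_heads(8, tensor_parallel_size=4, uses_mla=True) == 1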