From 8b03cb05f22d7b67aeda664f339e2077be3837bd Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Thu, 30 Jan 2025 01:19:33 +0000
Subject: [PATCH] different prefill and decode scales

Signed-off-by: Lucas Wilkinson
---
 vllm/attention/backends/triton_mla.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 4d185a1106344..2b44d6e152e8e 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -1,4 +1,3 @@
-import math
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -197,8 +196,6 @@ class TritonMLAMetadata(MLAMetadataCommon):
     # The dimension of the attention heads
     head_dim: Optional[int] = None
 
-    sm_scale: float = 0.0
-
     def __post_init__(self):
         supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
         if self.head_dim is not None and self.head_dim \
@@ -207,11 +204,6 @@ def __post_init__(self):
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")
 
-        # Note(simon): for MLA: soft max scale needs to be
-        # `1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)`.
-        assert self.head_dim is not None
-        self.sm_scale = 1.0 / math.sqrt(self.head_dim + self.head_dim // 8)
-
     @property
     def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
         if self.num_prefills == 0:
@@ -684,7 +676,7 @@ def _forward_decode(
         decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                              decode_meta.block_tables,
                              decode_meta.seq_lens_tensor, attn_logits,
-                             attn_metadata.num_kv_splits, decode_meta.sm_scale,
+                             attn_metadata.num_kv_splits, self.scale,
                              PAGE_SIZE)
 
         return self._v_up_proj_and_o_proj(o)
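
Note (reviewer sketch, not part of the patch): the deleted __post_init__ code
recomputed the softmax scale from the decode-side head_dim alone, hard-coding
the rope dimension as head_dim // 8, while the removed comment states that the
MLA softmax scale should be 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim).
After this change the decode kernel is passed the implementation-level
self.scale instead of a per-metadata sm_scale. The sketch below is
illustrative only; the class name and the qk_nope_head_dim / qk_rope_head_dim
constructor arguments are assumptions about where such a scale would be
computed, not code from this file.

    import math

    class MLAScaleSketch:
        """Illustrative: compute one softmax scale up front and reuse it,
        rather than re-deriving it in the metadata's __post_init__."""

        def __init__(self, qk_nope_head_dim: int, qk_rope_head_dim: int):
            # Per the comment removed by this patch, the MLA softmax scale is
            # 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim).
            self.scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)

    # Example with DeepSeek-V2-style head dims (128 nope + 64 rope):
    # MLAScaleSketch(128, 64).scale == 1 / sqrt(192)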