From f23d126a07629973266094144f1ae1099ed70e4d Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Thu, 30 Jan 2025 03:02:58 +0000
Subject: [PATCH] fix VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0

Signed-off-by: Lucas Wilkinson
---
 vllm/attention/backends/mla/utils.py  | 13 ++++++++-----
 vllm/attention/backends/triton_mla.py |  4 ++--
 vllm/envs.py                          |  8 +++++---
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index e3203aca1880f..6ad41ceee23e8 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -188,8 +188,9 @@ def _q_proj_and_k_up_proj(self, x):
             return torch.matmul(x, self.W_Q_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
         else:
-            x = torch.matmul(x, self.W_Q)
-            return torch.matmul(x, self.W_UK.T)\
+            x = torch.matmul(x, self.W_Q)\
+                .view(-1, self.num_heads, self.qk_nope_head_dim)
+            return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
 
     def process_weights_after_loading(self):
@@ -249,13 +250,15 @@ def process_weights_after_loading(self):
                 self.W_UV_O.shape[0] * tp_size,
                 self.W_UV_O.shape[1],
                 bias=False,
-                #quant_config=self.o_proj.quant_method, TODO
+                # TODO(lucas) figure out how to properly forward quant_method
+                #quant_config=self.o_proj.quant_method,
             )
 
             self.o_proj_absored.weight = torch.nn.Parameter(self.W_UV_O.T)
         else:
-            print("Not absorbing weights")
-            self.W_UK, self.W_UV, self.W_Q = W_UK, W_UV, W_Q
+            self.W_UV = W_UV
+            self.W_UK = W_UK
+            self.W_Q = W_Q.flatten(start_dim=1)
 
     @abstractmethod
     def _forward_prefill(
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index f52edea9dd9d3..43f5caf338b1f 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -124,7 +124,7 @@ def begin_forward(self, model_input):
 
 @dataclass(kw_only=True)
 class TritonMLAMetadata(MLAMetadataCommon):
-    """Metadata for FlashAttentionBackend.
+    """Metadata for TritonMLABackend.
 
     NOTE: Any python object stored here is not updated when it is
     cuda-graph replayed. If you have values that need to be changed
@@ -189,7 +189,7 @@ class TritonMLAMetadata(MLAMetadataCommon):
 
     num_prefill_tokens: int
 
-    num_kv_splits: int = 4
+    num_kv_splits: int = 4  # TODO(lucas) add heuristic
 
     attn_logits: Optional[torch.Tensor] = None
     req_idx: Optional[torch.Tensor] = None
diff --git a/vllm/envs.py b/vllm/envs.py
index bd6d97f331928..c8b7340c0d251 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -512,9 +512,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
 
-    # Flag that can control whether
-    #
-    #
+    # Flag controlling whether or not we perform matrix absorption for MLA
+    # decode, i.e. absorb W_UK into W_Q (forming W_Q_UK) and W_UV into W_O
+    # (forming W_UV_O). Absorbing the matrices reduces the runtime FLOPs
+    # needed to compute MLA, but requires storing the larger W_Q_UK and
+    # W_UV_O weights, so it can increase memory usage. Enabled by default.
     "VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
     lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
 }
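
Note on the _q_proj_and_k_up_proj change: the old else-branch multiplied the
flattened (B, num_heads * qk_nope_head_dim) activations by W_UK.T, which mixes
heads together; the einsum keeps the contraction per head. Below is a quick
standalone check (not part of the patch) that the unabsorbed path matches the
absorbed W_Q_UK path; the dimensions and the W_Q_UK construction here are
illustrative assumptions, not the patch's code:

    import torch

    B, H = 4, 16                      # tokens, num_heads
    D, P, L = 256, 64, 32             # hidden, qk_nope_head_dim, kv_lora_rank

    torch.manual_seed(0)
    W_Q = torch.randn(D, H, P, dtype=torch.float64)   # per-head q projection
    W_UK = torch.randn(L, H, P, dtype=torch.float64)  # per-head k up-projection
    x = torch.randn(B, D, dtype=torch.float64)

    # Absorbed: fold W_UK into W_Q once at load time, one matmul per step.
    W_Q_UK = torch.einsum("dhp,lhp->dhl", W_Q, W_UK).flatten(start_dim=1)
    out_absorbed = torch.matmul(x, W_Q_UK).view(-1, H, L)

    # Unabsorbed, mirroring the patched else-branch: project to per-head
    # nope dims, then contract each head with its own W_UK block.
    q = torch.matmul(x, W_Q.flatten(start_dim=1)).view(-1, H, P)
    out_unabsorbed = torch.einsum("bnp,lnp->bnl", q, W_UK)

    assert torch.allclose(out_absorbed, out_unabsorbed)

Both paths produce the same (B, num_heads, kv_lora_rank) latent queries;
absorption trades per-step FLOPs for the larger precomputed W_Q_UK weight.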
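
To exercise the path this patch fixes, set VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0
in the environment before vLLM starts. A minimal sketch of the parsing
semantics of the envs.py entry above (standalone, no vLLM import needed):

    import os

    # Mirrors the lambda in envs.py: unset or "1" -> True, "0" -> False.
    def perform_absorption() -> bool:
        return bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))

    assert perform_absorption() is True                    # default: absorb
    os.environ["VLLM_MLA_PERFORM_MATRIX_ABSORPTION"] = "0"
    assert perform_absorption() is False                   # use the fixed path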