[Attention] Deepseek v3 MLA support with FP8 compute #12601

Merged
merged 41 commits, Feb 1, 2025

Changes from all commits (41 commits):
27ad92c - squashed commits (LucasWilkinson, Jan 30, 2025)
c34e5ca - fix VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0 (LucasWilkinson, Jan 30, 2025)
f2cac91 - more cleanups (LucasWilkinson, Jan 30, 2025)
068e672 - Update utils.py (LucasWilkinson, Jan 30, 2025)
31b802c - Update vllm/attention/backends/mla/utils.py (LucasWilkinson, Jan 30, 2025)
634eee6 - review comments (LucasWilkinson, Jan 30, 2025)
7487429 - renaming for consistency (LucasWilkinson, Jan 30, 2025)
d27826d - Update vllm/config.py (LucasWilkinson, Jan 30, 2025)
8bdc14a - review comments (LucasWilkinson, Jan 30, 2025)
09d814c - review comments (LucasWilkinson, Jan 30, 2025)
4a46014 - Update vllm/attention/backends/mla/utils.py (LucasWilkinson, Jan 30, 2025)
0881475 - disable MLA for v3 for now (LucasWilkinson, Jan 30, 2025)
37e39f4 - fix failing test (LucasWilkinson, Jan 30, 2025)
cfb2d26 - fix mypy (LucasWilkinson, Jan 30, 2025)
5afc1bf - fix mypy (LucasWilkinson, Jan 30, 2025)
54ba87d - add cuda graph support (LucasWilkinson, Jan 30, 2025)
31c34bf - ci fix (LucasWilkinson, Jan 30, 2025)
433322b - Revert "add cuda graph support" (LucasWilkinson, Jan 31, 2025)
f2b2500 - Fix TP > 1 cuda graphs (LucasWilkinson, Jan 31, 2025)
2d61054 - cleanup (LucasWilkinson, Jan 31, 2025)
645622c - cleanup (LucasWilkinson, Jan 31, 2025)
0ccbcce - deepseek v3 support (LucasWilkinson, Jan 31, 2025)
076cbe5 - Merge branch 'main' into mla-fp8 (LucasWilkinson, Jan 31, 2025)
a57cd3d - Merge branch 'main' of github.com:vllm-project/vllm into mla-fp8 (simon-mo, Jan 31, 2025)
548ec44 - simon changes (LucasWilkinson, Jan 31, 2025)
3d12a04 - working but messy (LucasWilkinson, Jan 31, 2025)
f51cbe0 - review comments (LucasWilkinson, Jan 31, 2025)
3cdd2ce - cleanup (LucasWilkinson, Jan 31, 2025)
c9d72cb - more cleanup (LucasWilkinson, Jan 31, 2025)
4251506 - fixes (LucasWilkinson, Jan 31, 2025)
9829fae - misc (LucasWilkinson, Jan 31, 2025)
1621381 - Update vllm/model_executor/model_loader/loader.py (LucasWilkinson, Jan 31, 2025)
e144da8 - Update vllm/model_executor/model_loader/loader.py (LucasWilkinson, Jan 31, 2025)
db2c583 - filter compressed tensor models better (LucasWilkinson, Feb 1, 2025)
fac827f - Merge remote-tracking branch 'origin/main' into mla-fp8 (LucasWilkinson, Feb 1, 2025)
5002734 - simplification (LucasWilkinson, Feb 1, 2025)
0d66687 - Update loader.py (simon-mo, Feb 1, 2025)
5fe1d1d - format (LucasWilkinson, Feb 1, 2025)
5d5071c - reduce split kv amount (LucasWilkinson, Feb 1, 2025)
7ac6f52 - fix none type error (LucasWilkinson, Feb 1, 2025)
dc0e2af - ci fix (LucasWilkinson, Feb 1, 2025)
vllm/attention/backends/mla/utils.py (220 changes: 184 additions & 36 deletions)
@@ -1,17 +1,29 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional
from typing import Any, Dict, Generic, List, Optional, Tuple

import torch
from compressed_tensors.quantization import QuantizationStrategy

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):

class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
"""
Common class for implementing repeated parts

Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the entire KV cache.
* The attention "simulates" a multi-head attention, while the compute is
@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
* V: V head dim.
* kv_c: latent/compressed KV
* q_c: latent/compressed Q

#
# Outside the MLA attention backend
#
Expand All @@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_c_k_pe (B, Lkv+R).
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
and kv_c are normalized.

#
# Inside the MLA attention backend
#

* if prefill:
3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).
4. q_pe, k_pe are then passed through rotary embeddings.
5. kv_c and k_pe are concatenated and inserted into the cache
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
and v (B, N, V).
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
q_nope, q_pe, k_nope, k_pe.
@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
From @tsu-bin's calculation, we only want to use the absorption technique
for decode. The prefill algorithm should still use the up-projected MHA
for less flops and memory usage.

"""

def __init__(
@@ -162,15 +174,32 @@ def __init__(

def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
return self.o_proj_absorbed(
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
if is_fp8(self.W_UV_O):
output_parallel = apply_fp8_linear_generic(
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape)
else:
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O)
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]

def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_fp8_linear_generic(
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
@@ -179,8 +208,91 @@ def _q_proj_and_k_up_proj(self, x):
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
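A small sketch (mine, not the PR's) of why the matrix absorption used above is legal: folding W_UK into the query side lets decode take attention dot products directly against the cached latent kv_c instead of up-projecting K per head, matching the einsum in the non-absorbed branch:

```python
import torch

B, T, N, Lkv, P = 2, 5, 4, 32, 16           # queries, cached tokens, heads, dims
q_nope = torch.randn(B, N, P)
kv_c = torch.randn(T, Lkv)                  # one cached latent vector per token
W_UK = torch.randn(Lkv, N, P)

# naive path: up-project the cache to per-head K, then dot with q_nope
k_nope = torch.einsum("tl,lnp->tnp", kv_c, W_UK)
logits_naive = torch.einsum("bnp,tnp->bnt", q_nope, k_nope)

# absorbed path: move the query into latent space once, dot with the raw cache
q_latent = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)
logits_absorbed = torch.einsum("bnl,tl->bnt", q_latent, kv_c)

torch.testing.assert_close(logits_naive, logits_absorbed, rtol=1e-4, atol=1e-4)
```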

def process_weights_after_loading(self):
kv_b_proj_weight = self.kv_b_proj.weight.T
def process_weights_after_loading(self, act_dtype: torch.dtype):

def is_layer_fp8(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, Fp8LinearMethod) or\
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))

def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)

# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant is not None:
Review comment (Contributor): Fp8LinearMethod.block_quant is a boolean, is there meant to be a check for False instead?

Reply (Member): Yes this is a bug, I fixed it here #13181
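To make the flagged issue concrete (hypothetical snippet, not from the diff): block_quant is a plain bool, so the `is not None` test above never filters anything; the intended guard, per the follow-up fix in #13181, is a plain truthiness check:

```python
block_quant = False
if block_quant is not None:   # always True, even when block quantization is off
    print("taken for every Fp8LinearMethod layer")
if block_quant:               # intended: only when block quantization is enabled
    print("taken only for block-quantized layers")
```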

weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky, but for CompressedTensorsW8A8Fp8 we always assume
# the input is dynamic per-token; we ignore the static-per-tensor
# case since we are going to requantize later anyway
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
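A rough illustration (my own helper, not vLLM's scaled_quantize/scaled_dequantize API) of the (rows, cols) group-shape convention returned above: -1 means one group spanning that whole dimension, and the number of scales per dimension is the dimension size divided by the group size. Which dimension corresponds to tokens or channels depends on how the tensor is laid out.

```python
def scales_shape(tensor_shape, group_shape):
    # one scale per group; -1 collapses the whole dimension into a single group
    return tuple(1 if g == -1 else dim // g
                 for dim, g in zip(tensor_shape, group_shape))

W = (4096, 2048)                      # a 2-D tensor (dim0, dim1)
print(scales_shape(W, (-1, -1)))      # one scale for the whole tensor -> (1, 1)
print(scales_shape(W, (-1, 1)))       # one scale per dim1 slice       -> (1, 2048)
print(scales_shape(W, (1, -1)))       # one scale per dim0 slice (row) -> (4096, 1)
print(scales_shape(W, (128, 128)))    # one scale per 128x128 block    -> (32, 16)
```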

def get_scales(layer: LinearBase) -> torch.Tensor:
if hasattr(layer, "weight_scale_inv"):
return layer.weight_scale_inv
return layer.weight_scale

def get_and_maybe_dequant_weights(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)

return scaled_dequantize(weight, scales,
weight_scale_group_shape)
else:
return layer.weight

if not (quantization_scheme_supported(self.kv_b_proj) and\
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")

weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype

kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -198,18 +310,35 @@ def process_weights_after_loading(self):
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)

q_proj = self.q_proj.weight.T\
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)

# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj[..., :self.qk_nope_head_dim]
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()

# W_QR is small so for simplicity we don't bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)

if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it is wise to requantize using the same group shapes
# (i.e. strategy: per-tensor, per-channel, block, etc.) with which the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
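A condensed sketch of the dequantize, absorb, requantize flow this block sets up for FP8 checkpoints; dequant and quant_per_tensor below are simplified stand-ins for the imported scaled_dequantize/scaled_quantize, and all shapes are toy values:

```python
import torch

FP8 = torch.float8_e4m3fn

def dequant(w_fp8, scale):                 # stand-in for scaled_dequantize
    return w_fp8.to(torch.float32) * scale

def quant_per_tensor(w):                   # stand-in for scaled_quantize
    scale = w.abs().amax() / torch.finfo(FP8).max
    return (w / scale).to(FP8), scale

# dequantize the checkpoint weights, absorb in high precision, requantize
W_Q  = dequant(torch.randn(64, 4, 16).to(FP8), torch.tensor(2.0))
W_UK = dequant(torch.randn(32, 4, 16).to(FP8), torch.tensor(2.0))
W_Q_UK = torch.einsum("qnd,lnd->qnl", W_Q, W_UK).flatten(start_dim=1)
W_Q_UK_fp8, W_Q_UK_scale = quant_per_tensor(W_Q_UK)
```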

#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
@@ -223,25 +352,44 @@ def process_weights_after_loading(self):
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()

W_O = self.o_proj.weight\
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)

W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()

tp_size = get_tensor_model_parallel_world_size()
self.o_proj_absorbed = RowParallelLinear(
self.W_UV_O.shape[0] * tp_size,
self.W_UV_O.shape[1],
bias=False,
# TODO(lucas) figure out how to properly forward quant_method
#quant_config=self.o_proj.quant_method,
Review comment (Collaborator): Is it too onerous to construct a quant method here? (i.e. should we try to make this easier in the future?)

Reply (Collaborator Author): Yeah, because you have to make a quant_config and all the weight names are different, and for per-channel we enter the compressed-tensors world, which is another can of worms. Agreed, we should try to make it easier in the future.

)

self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)

self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")

self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
vllm/attention/backends/triton_mla.py (18 changes: 7 additions & 11 deletions)
@@ -57,14 +57,12 @@ def get_state_cls() -> Type["TritonMLAState"]:

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
kv_lora_rank: int, # passed via head_size
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
# TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
k_pe_size = kv_lora_rank // 8
return (num_blocks, block_size, kv_lora_rank + k_pe_size)
return (num_blocks, block_size, head_size)

@staticmethod
def swap_blocks(
@@ -83,7 +81,7 @@ def copy_blocks(

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [512]
return [576]
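The new single supported head size is the latent width plus the rope dimension stored alongside it in the cache; with DeepSeek's kv_lora_rank of 512 and qk_rope_head_dim of 64 this reproduces the 1/8 ratio the removed hard-coding assumed:

```python
kv_lora_rank = 512        # DeepSeek V2/V3 latent KV width
qk_rope_head_dim = 64     # equals kv_lora_rank // 8, as the old TODO noted
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576   # matches get_supported_head_sizes() above
```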


class TritonMLAState(AttentionState):
@@ -624,8 +622,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
self.multimodal_placeholder_maps.items()
}

num_kv_splits = 8

return TritonMLAMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
Expand All @@ -645,7 +641,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
num_kv_splits=num_kv_splits,
num_kv_splits=4, # TODO(lucas) add heuristic
head_dim=self.runner.model_config.get_head_size(),
)

vllm/attention/layer.py (4 changes: 2 additions & 2 deletions)
@@ -200,9 +200,9 @@ def extra_repr(self) -> str:
s += f", backend={self.impl.__class__.__name__}"
return s

def process_weights_after_loading(self):
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading()
self.impl.process_weights_after_loading(act_dtype)


class MultiHeadAttention(nn.Module):