[Attention] Deepseek v3 MLA support with FP8 compute #12601

Merged: 41 commits merged on Feb 1, 2025
Commits
27ad92c
squashed commits
LucasWilkinson Jan 30, 2025
c34e5ca
fix VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0
LucasWilkinson Jan 30, 2025
f2cac91
more cleanups
LucasWilkinson Jan 30, 2025
068e672
Update utils.py
LucasWilkinson Jan 30, 2025
31b802c
Update vllm/attention/backends/mla/utils.py
LucasWilkinson Jan 30, 2025
634eee6
review comments
LucasWilkinson Jan 30, 2025
7487429
renaming for consistency
LucasWilkinson Jan 30, 2025
d27826d
Update vllm/config.py
LucasWilkinson Jan 30, 2025
8bdc14a
review comments
LucasWilkinson Jan 30, 2025
09d814c
review comments
LucasWilkinson Jan 30, 2025
4a46014
Update vllm/attention/backends/mla/utils.py
LucasWilkinson Jan 30, 2025
0881475
disable MLA for v3 for now
LucasWilkinson Jan 30, 2025
37e39f4
fix failing test
LucasWilkinson Jan 30, 2025
cfb2d26
fix mypy
LucasWilkinson Jan 30, 2025
5afc1bf
fix mypy
LucasWilkinson Jan 30, 2025
54ba87d
add cuda graph support
LucasWilkinson Jan 30, 2025
31c34bf
ci fix
LucasWilkinson Jan 30, 2025
433322b
Revert "add cuda graph support"
LucasWilkinson Jan 31, 2025
f2b2500
Fix TP > 1 cuda graphs
LucasWilkinson Jan 31, 2025
2d61054
cleanup
LucasWilkinson Jan 31, 2025
645622c
cleanup
LucasWilkinson Jan 31, 2025
0ccbcce
deepseek v3 support
LucasWilkinson Jan 31, 2025
076cbe5
Merge branch 'main' into mla-fp8
LucasWilkinson Jan 31, 2025
a57cd3d
Merge branch 'main' of github.com:vllm-project/vllm into mla-fp8
simon-mo Jan 31, 2025
548ec44
simon changes
LucasWilkinson Jan 31, 2025
3d12a04
working but messy
LucasWilkinson Jan 31, 2025
f51cbe0
review comments
LucasWilkinson Jan 31, 2025
3cdd2ce
cleanup
LucasWilkinson Jan 31, 2025
c9d72cb
more cleanup
LucasWilkinson Jan 31, 2025
4251506
fixes
LucasWilkinson Jan 31, 2025
9829fae
misc
LucasWilkinson Jan 31, 2025
1621381
Update vllm/model_executor/model_loader/loader.py
LucasWilkinson Jan 31, 2025
e144da8
Update vllm/model_executor/model_loader/loader.py
LucasWilkinson Jan 31, 2025
db2c583
filter compressed tensor models better
LucasWilkinson Feb 1, 2025
fac827f
Merge remote-tracking branch 'origin/main' into mla-fp8
LucasWilkinson Feb 1, 2025
5002734
simplification
LucasWilkinson Feb 1, 2025
0d66687
Update loader.py
simon-mo Feb 1, 2025
5fe1d1d
format
LucasWilkinson Feb 1, 2025
5d5071c
reduce split kv amount
LucasWilkinson Feb 1, 2025
7ac6f52
fix none type error
LucasWilkinson Feb 1, 2025
dc0e2af
ci fix
LucasWilkinson Feb 1, 2025
120 changes: 87 additions & 33 deletions vllm/attention/backends/mla/utils.py
@@ -9,9 +9,13 @@
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
LinearBase, RowParallelLinear)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear, block_quantize, is_fp8, scaled_dequant)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -25,11 +29,11 @@ class MLACommonMetadata(AttentionMetadata):

class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
"""
Common class for implementing repeated parts
Common class for implementing repeated parts

Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the entire KV cache.
* The attention "simulates" a multi-head attention, while the compute is
@@ -46,7 +50,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
* V: V head dim.
* kv_c: latent/compressed KV
* q_c: latent/compressed Q

#
# Outside the MLA attention backend
#
@@ -55,21 +59,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_c_k_pe (B, Lkv+R).
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
and kv_c are normalized.

#
# Inside the MLA attention backend
#

* if prefill:
3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).

3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).
4. q_pe, k_pe are then passed through rotary embeddings.
5. kv_c and k_pe are concatenated and inserted into the cache
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
and v (B, N, V).
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
q_nope, q_pe, k_nope, k_pe.
@@ -112,7 +116,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
From @tsu-bin's calculation, we only want to use the absorption technique
for decode. The prefill algorithm should still use the up-projected MHA
for less flops and memory usage.

"""

def __init__(
@@ -162,15 +166,35 @@ def __init__(

def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
return self.o_proj_absorbed(
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
if is_fp8(self.W_UV_O):
output_parallel = apply_w8a8_block_fp8_linear(
x.flatten(start_dim=1),
self.W_UV_O,
[128, 128],
self.W_UV_O_scales,
)
else:
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O)
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]

def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_w8a8_block_fp8_linear(
x,
self.W_Q_UK,
[128, 128],
self.W_Q_UK_scales,
).view(-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
@@ -180,7 +204,24 @@ def _q_proj_and_k_up_proj(self, x):
.view(-1, self.num_heads, self.kv_lora_rank)

def process_weights_after_loading(self):
kv_b_proj_weight = self.kv_b_proj.weight.T

def get_and_maybe_dequant_weights(layer: LinearBase):
if isinstance(layer.quant_method, Fp8LinearMethod):
# TODO(lucas) support non block quantized
assert hasattr(layer, "weight_scale_inv") and \
layer.quant_method.block_quant is not None
return scaled_dequant(
layer.weight, layer.weight_scale_inv,
layer.quant_method.quant_config.weight_block_size)\
.to(torch.bfloat16)
else:
return layer.weight

weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype

kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -198,15 +239,15 @@ def process_weights_after_loading(self):
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)

q_proj = self.q_proj.weight.T\
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)

# can be W_Q or W_UQ depending on q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj[..., :self.qk_nope_head_dim]
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()

if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -223,25 +264,38 @@ def process_weights_after_loading(self):
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()

W_O = self.o_proj.weight\
if is_fp8(weight_dtype):
W_Q_UK, W_Q_UK_scales = block_quantize(W_Q_UK, (128, 128))
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK

W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()

tp_size = get_tensor_model_parallel_world_size()
self.o_proj_absorbed = RowParallelLinear(
self.W_UV_O.shape[0] * tp_size,
self.W_UV_O.shape[1],
bias=False,
# TODO(lucas) figure out how to properly forward quant_method
#quant_config=self.o_proj.quant_method,
)

Collaborator:
Is it too onerous to construct a quant method here? (i.e. should we try to make this easier in the future?)

Collaborator Author:
Ya, because you have to make a quant_config and stuff, and all the weight names are different, and for per-channel we enter the compressed-tensors world, which is another can of worms :/

Agreed, we should try to make it easier in the future.
if is_fp8(weight_dtype):
W_UV_O, W_UV_O_scales = block_quantize(W_UV_O, (128, 128))
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O

self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")

self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
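To make the matrix absorption described in the docstring concrete, here is a small self-contained sketch (not part of this diff; shapes are made up for illustration) checking that folding `W_UK` into the query projection (`W_Q_UK`) and `W_UV` into `o_proj` (`W_UV_O`) reproduces the unabsorbed computation, mirroring the einsums used in `process_weights_after_loading`:

```python
import torch

torch.set_default_dtype(torch.float64)  # tight tolerances for the checks below

B, N = 2, 4          # decode tokens, attention heads
Lq, Lkv = 16, 8      # query / KV latent (lora) ranks
P, V, H = 6, 6, 32   # qk_nope head dim, v head dim, hidden size

q_c = torch.randn(B, Lq)        # compressed query
kv_c = torch.randn(B, Lkv)      # compressed KV (what the cache stores)
W_Q = torch.randn(Lq, N, P)     # query up-projection (nope part)
W_UK = torch.randn(Lkv, N, P)   # key up-projection
W_UV = torch.randn(Lkv, N, V)   # value up-projection
W_O = torch.randn(H, N, V)      # o_proj weight viewed as (hidden, heads, v_dim)

# Unabsorbed path: up-project q and k, then take per-head dot products.
q_nope = torch.einsum("bq,qnp->bnp", q_c, W_Q)
k_nope = torch.einsum("cl,lnp->cnp", kv_c, W_UK)
scores_ref = torch.einsum("bnp,cnp->bnc", q_nope, k_nope)

# Absorbed path: precompute W_Q_UK so decode attends over kv_c directly.
W_Q_UK = torch.einsum("qnp,lnp->qnl", W_Q, W_UK)
q_latent = torch.einsum("bq,qnl->bnl", q_c, W_Q_UK)
scores_abs = torch.einsum("bnl,cl->bnc", q_latent, kv_c)
assert torch.allclose(scores_ref, scores_abs)

# Likewise W_UV is folded into o_proj so the attention output can stay in
# latent space until the final projection.
attn_latent = torch.randn(B, N, Lkv)   # attention output, still in latent space
v_up = torch.einsum("bnl,lnv->bnv", attn_latent, W_UV)
out_ref = torch.einsum("bnv,hnv->bh", v_up, W_O)
W_UV_O = torch.einsum("lnv,hnv->nlh", W_UV, W_O)
out_abs = torch.einsum("bnl,nlh->bh", attn_latent, W_UV_O)
assert torch.allclose(out_ref, out_abs)
```

This is why the FP8 path can simply block-quantize `W_Q_UK` and `W_UV_O` once after loading and run decode through `apply_w8a8_block_fp8_linear`.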
74 changes: 54 additions & 20 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -7,13 +7,21 @@
import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
if isinstance(x, torch.Tensor):
x = x.dtype
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -57,39 +65,65 @@ def input_to_float8(
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


def scaled_dequant(
x_q: torch.Tensor,
x_s: torch.Tensor,
block_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
if block_size is not None:
assert x_s.shape[-1] == x_q.shape[-1] // block_size[1]
assert x_s.shape[-2] == x_q.shape[-2] // block_size[0]

x_s = group_broadcast(x_s, x_q.shape)
return x_q.to(torch.float32) * x_s


def block_quant_to_tensor_quant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise
quantization. The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
The outputs are tensor-wise quantization tensor and tensor-wise
quantization scale. Note only float8 is supported for now.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = scaled_dequant(x_q_block, x_s)
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale

x_dq_block = x_q_block.to(torch.float32)

x_dq_block_tiles = [[
x_dq_block[
j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k),
] for i in range(k_tiles)
] for j in range(n_tiles)]
for i in range(k_tiles):
    for j in range(n_tiles):
        x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale


def block_quantize(
    x: torch.Tensor,
    block_size: Tuple[int, int],
    dtype: Optional[torch.dtype] = None,
):
    if dtype is None:
        dtype = (torch.float8_e4m3fnuz
                 if current_platform.is_rocm() else torch.float8_e4m3fn)
    finfo = torch.finfo(dtype)

    # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
    assert x.ndim == 2
    assert x.shape[0] % block_size[0] == 0 and x.shape[1] % block_size[1] == 0
    blk_m, blk_n = x.shape[0] // block_size[0], x.shape[1] // block_size[1]
    x_blkd = x.reshape(blk_m, block_size[0], blk_n, block_size[1])
    # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
    x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
    # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
    x_blkd_permd = x_blkd_permd.flatten(start_dim=2)
    min_val, max_val = x_blkd_permd.aminmax(dim=-1)
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    # Apply scale and convert from
    # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) back to (M, N)
    x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
        .clamp(min=finfo.min, max=finfo.max)\
        .reshape(blk_m, blk_n, block_size[0], block_size[1])\
        .permute(0, 2, 1, 3)\
        .reshape(x.shape)

    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
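As a rough usage sketch of these helpers (not from the PR, and assuming a PyTorch build with `float8_e4m3fn` support plus this branch's import paths), quantize a weight into (128, 128) blocks and dequantize it back:

```python
import torch

# Import path as added by this PR (assumes running on this branch).
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    block_quantize, is_fp8, scaled_dequant)

w = torch.randn(256, 512)                     # both dims divisible by 128
w_q, w_s = block_quantize(w, (128, 128))      # fp8 data + per-block dequant scales
assert is_fp8(w_q)
assert w_s.shape == (256 // 128, 512 // 128)  # one scale per (128, 128) block

# Dequantize with the block scales; get_and_maybe_dequant_weights in
# mla/utils.py does the same thing before re-absorbing the matrices.
w_dq = scaled_dequant(w_q, w_s, (128, 128))
print((w - w_dq).abs().max())                 # small fp8 rounding error
```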


@triton.jit
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -452,3 +452,27 @@ def awq_pack(
q_w = q_w.reshape((-1, size_n)).contiguous()

return pack_cols(q_w, num_bits, size_k, size_n)


# We treat N-dimensional group scaling as extended numpy-style broadcasting.
# Numpy broadcasting simply stretches dimensions with an extent of 1 to match
# the target shape by repeating the data along that dimension. We extend these
# semantics: if the extent of a dimension in the source shape is not 1 and
# does not match the target shape, we repeat each element along that dimension
# target_shape[dim] // src_shape[dim] times.
# Example, if we have:
#   a = [[1, 2],    and target_shape = (2, 4)
#        [3, 4]]
# then we would expand a to:
#   a = [[1, 1, 2, 2],
#        [3, 3, 4, 4]]
# NOTE: this function does not explicitly broadcast dimensions with an extent
# of 1, since this can be done implicitly by pytorch.
def group_broadcast(t, shape):
for i, s in enumerate(shape):
if t.shape[i] != s and t.shape[i] != 1:
assert s % t.shape[i] == 0
t = t.unsqueeze(i + 1)\
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
.flatten(i, i + 1)
return t
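A quick illustrative check of `group_broadcast` (not part of the diff), reproducing the example from the comment above:

```python
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    group_broadcast)

a = torch.tensor([[1., 2.],
                  [3., 4.]])
print(group_broadcast(a, (2, 4)))
# tensor([[1., 1., 2., 2.],
#         [3., 3., 4., 4.]])
```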
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -398,7 +398,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
elif isinstance(module, Attention) and \
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after loading
# (currently only used by MLA)
@@ -1272,7 +1272,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:

class RunaiModelStreamerLoader(BaseModelLoader):
"""
Model loader that can load safetensors
Model loader that can load safetensors
files from local FS or S3 bucket.
"""
