[Attention] Deepseek v3 MLA support with FP8 compute #12601
@@ -1,17 +1,29 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Generic, List, Optional
+from typing import Any, Dict, Generic, List, Optional, Tuple

 import torch
+from compressed_tensors.quantization import QuantizationStrategy

 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               RowParallelLinear)
+                                               LinearBase, RowParallelLinear,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsLinearMethod)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsW8A8Fp8)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    scaled_dequantize, scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):

class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    """
    Common class for implementing repeated parts

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    Deepseek's MLA attention works the following way:
    * Use a single latent vector to represent the entire KV cache.
    * The attention "simulates" a multi-head attention, while the compute is

@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    * V: V head dim.
    * kv_c: latent/compressed KV
    * q_c: latent/compressed Q

    #
    # Outside the MLA attention backend
    #

@@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
       kv_c_k_pe (B, Lkv+R).
    2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
       and kv_c are normalized.

    #
    # Inside the MLA attention backend
    #

    * if prefill:
    3. The q_c is then projected up into the multi-head version.
       * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
         (B, N, P) and q_pe (B, N, R).
    4. q_pe, k_pe are then passed through rotary embeddings.
    5. kv_c and k_pe are concatenated and inserted into the cache
    6. The kv_c is then projected up into the multi-head version.
       * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
         dimensions for K and V, which is split into k_nope (B, N, P)
         and v (B, N, V).
    7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
       q_nope, q_pe, k_nope, k_pe.

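The numbered steps above are easier to follow with shapes in hand. The sketch below walks through steps 3-7 on toy tensors; the sizes are made up, `W_UQ` and `W_UKV` are illustrative stand-ins for the `q_proj`/`kv_b_proj` weights rather than vLLM attribute names, and step 4 (rotary embeddings) and step 5 (the cache write) are omitted.

import torch

# Toy sizes only: B tokens, N heads, P = qk_nope_head_dim, R = rope dim,
# V = v_head_dim, Lq / Lkv = latent ranks. Not DeepSeek's real dimensions.
B, N, P, R, V, Lq, Lkv = 2, 4, 8, 4, 8, 16, 16

q_c = torch.randn(B, Lq)    # latent/compressed Q (steps 1-2 happen outside)
kv_c = torch.randn(B, Lkv)  # latent/compressed KV
k_pe = torch.randn(B, R)    # rope part of K, shared across heads

W_UQ = torch.randn(Lq, N, P + R)    # stand-in for the q up-projection (step 3)
W_UKV = torch.randn(Lkv, N, P + V)  # stand-in for the kv up-projection (step 6)

# step 3: q_c -> (B, N, P+R), split into q_nope and q_pe
q = torch.einsum("bl,lnd->bnd", q_c, W_UQ)
q_nope, q_pe = q.split([P, R], dim=-1)

# step 6: kv_c -> (B, N, P+V), split into k_nope and v
kv = torch.einsum("bl,lnd->bnd", kv_c, W_UKV)
k_nope, v = kv.split([P, V], dim=-1)

# step 7: assemble q and k; the single k_pe is broadcast across the N heads
q_full = torch.cat([q_nope, q_pe], dim=-1)
k_full = torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1)
assert q_full.shape == k_full.shape == (B, N, P + R)
assert v.shape == (B, N, V)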
@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    From @tsu-bin's calculation, we only want to use the absorption technique
    for decode. The prefill algorithm should still use the up-projected MHA
    for less flops and memory usage.
    """

    def __init__(

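For the decode path, the absorption technique referenced above (and in @tsu-bin's analysis) is just a reassociation of matmuls: instead of up-projecting every cached latent through W_UK, W_UK is folded into the query once, so attention scores can be computed directly against the compressed KV cache. A minimal sketch with toy shapes, assuming nothing beyond the einsum conventions used in this file:

import torch

torch.set_default_dtype(torch.float64)  # keep the numerical check tight

B, N, P, L = 2, 4, 8, 16        # toy sizes: tokens, heads, nope dim, latent
q_nope = torch.randn(B, N, P)   # "nope" part of the query, per head
kv_c = torch.randn(B, L)        # one cached latent per token
W_UK = torch.randn(L, N, P)     # latent -> k_nope up-projection

# Unabsorbed: materialize k_nope from the cache, then dot with q_nope.
k_nope = torch.einsum("bl,lnp->bnp", kv_c, W_UK)
scores_unabsorbed = torch.einsum("bnp,bnp->bn", q_nope, k_nope)

# Absorbed: fold W_UK into the query once (what W_Q_UK fuses in below),
# then attend over the compressed latent directly.
q_latent = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)
scores_absorbed = torch.einsum("bnl,bl->bn", q_latent, kv_c)

torch.testing.assert_close(scores_unabsorbed, scores_absorbed)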
@@ -162,15 +174,32 @@ def __init__(

     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
-            return self.o_proj_absorbed(
-                x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
+            if is_fp8(self.W_UV_O):
+                output_parallel = apply_fp8_linear_generic(
+                    x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape)
+            else:
+                output_parallel = torch.matmul(x.flatten(start_dim=1),
+                                               self.W_UV_O)
+            if self.tp_size > 1:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+            else:
+                output = output_parallel
+            return output
         else:
             x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
             return self.o_proj(x.reshape(-1,
                                          self.num_heads * self.v_head_dim))[0]

     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            if is_fp8(self.W_Q_UK):
+                return apply_fp8_linear_generic(
+                    x, self.W_Q_UK, self.W_Q_UK_scales,
+                    self.reqaunt_input_group_shape,
+                    self.reqaunt_weight_group_shape).view(
+                        -1, self.num_heads, self.kv_lora_rank)
             return torch.matmul(x, self.W_Q_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)
         else:

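`_v_up_proj_and_o_proj` leans on the same reassociation on the output side: `W_UV_O`, built in `process_weights_after_loading` further down, fuses the value up-projection with `o_proj`, so the absorbed decode path goes from per-head latent outputs to the hidden size in one matmul. A rough equivalence check with toy shapes (tensor parallelism and FP8 ignored; names other than `W_UV_O` are illustrative):

import torch

torch.set_default_dtype(torch.float64)  # keep the numerical check tight

B, N, L, V, H = 2, 4, 16, 8, 32  # toy: tokens, heads, latent, v dim, hidden
x = torch.randn(B, N, L)         # attention output, still in latent space
W_UV = torch.randn(L, N, V)      # latent -> per-head V up-projection
W_O = torch.randn(H, N, V)       # o_proj weight viewed as (H, N, V)

# Unfused: up-project to per-head V, then apply the output projection.
v = torch.einsum("bnl,lnv->bnv", x, W_UV)
out_ref = v.reshape(B, N * V) @ W_O.reshape(H, N * V).T

# Fused: precompute W_UV_O once, then a single matmul per decode step.
W_UV_O = torch.einsum("lnv,hnv->nlh", W_UV, W_O).flatten(0, 1)  # (N*L, H)
out_fused = x.reshape(B, N * L) @ W_UV_O

torch.testing.assert_close(out_ref, out_fused)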
@@ -179,8 +208,91 @@ def _q_proj_and_k_up_proj(self, x):
             return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
                 .view(-1, self.num_heads, self.kv_lora_rank)

-    def process_weights_after_loading(self):
-        kv_b_proj_weight = self.kv_b_proj.weight.T
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+
+        def is_layer_fp8(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, Fp8LinearMethod) or\
+                (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
+                and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
+
+        def quantization_scheme_supported(layer: LinearBase) -> bool:
+            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
+                is_layer_fp8(layer)
+
+        # TODO(lucas) This is very gross, we need a more wide scale refactor of
+        # all the FP8 code with a more standard way of
+        # defining schemes/group-shapes, we should also potentially force
+        # quant_methods to support a decompress function
+        #
+        # returns input_group_shape, weight_group_shape
+        def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
+                Tuple[Tuple[int, int], Tuple[int, int]]:
+            if isinstance(layer.quant_method, Fp8LinearMethod):
+                if layer.quant_method.block_quant is not None:
+                    weight_block_size = \
+                        layer.quant_method.quant_config.weight_block_size
+                    # per-token-group (1, X), block-quantized (X, Y)
+                    return (1, weight_block_size[-1]), weight_block_size
+                else:
+                    return (-1, -1), (-1, -1)  # per-tensor, per-tensor
+            elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
+                    and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                # this is hacky but we always assume that for
+                # CompressedTensorsW8A8Fp8 the input is dynamic per-token;
+                # we ignore if it is static-per-tensor since we are going to
+                # requantize afterwards anyway
+                strategy = layer.scheme.strategy
+                if strategy == QuantizationStrategy.TENSOR:
+                    return (1, -1), (-1, -1)  # per-token, per-tensor
+                elif strategy == QuantizationStrategy.CHANNEL:
+                    return (1, -1), (-1, 1)  # per-token, per-channel
+                else:
+                    raise NotImplementedError(
+                        f"QuantizationStrategy.{strategy} is not supported for "
+                        "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
+            else:
+                raise NotImplementedError(
+                    "Can't determine scale group shapes for "
+                    f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
+                )
+
+        def get_scales(layer: LinearBase) -> torch.Tensor:
+            if hasattr(layer, "weight_scale_inv"):
+                return layer.weight_scale_inv
+            return layer.weight_scale
+
+        def get_and_maybe_dequant_weights(layer: LinearBase):
+            if is_layer_fp8(layer):
+                if isinstance(layer.quant_method, \
+                        CompressedTensorsLinearMethod) and \
+                        isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
+                    # NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
+                    # seems to store weights as (input, output) instead of
+                    # (output, input) so we need to transpose
+                    weight = layer.weight.T  # standardize to (output, input)
+                else:
+                    weight = layer.weight
+                _, weight_scale_group_shape = \
+                    get_scale_group_shapes_for_fp8(layer)
+                scales = get_scales(layer)
+
+                return scaled_dequantize(weight, scales,
+                                         weight_scale_group_shape)
+            else:
+                return layer.weight
+
+        if not (quantization_scheme_supported(self.kv_b_proj) and\
+                quantization_scheme_supported(self.q_proj) and\
+                quantization_scheme_supported(self.o_proj)):
+            raise NotImplementedError(
+                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                ", please run with VLLM_MLA_DISABLE=1")
+
+        weight_dtype = self.kv_b_proj.weight.dtype
+        assert self.o_proj.weight.dtype == weight_dtype
+        assert self.q_proj.weight.dtype == weight_dtype
+
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (

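The `(rows, cols)` group shapes above follow a convention where `-1` means the group spans that entire dimension, so `(-1, -1)` is per-tensor, `(1, -1)` per-token, `(-1, 1)` per-channel, and something like `(128, 128)` is block-quantized. The helper below is a hypothetical illustration of that convention, not vLLM's `scaled_dequantize`, and which axis counts as the "channel" depends on how the kernel stores the weight:

import torch

def dequant_with_group_shape(w_q: torch.Tensor, scales: torch.Tensor,
                             group_shape) -> torch.Tensor:
    """Illustrative only: expand one scale per (rows x cols) group and
    multiply. `scales` is assumed 2-D, e.g. (1, 1) for per-tensor."""
    rows, cols = group_shape
    rows = w_q.shape[0] if rows == -1 else rows
    cols = w_q.shape[1] if cols == -1 else cols
    expanded = scales.repeat_interleave(rows, dim=0)\
                     .repeat_interleave(cols, dim=1)
    return w_q.to(scales.dtype) * expanded[:w_q.shape[0], :w_q.shape[1]]

# e.g. block-quantized (128, 128): one scale per 128x128 tile of the weight
w_q = torch.randn(256, 512)   # stand-in for the stored fp8 weight
scales = torch.rand(2, 4)     # (256/128, 512/128) tile scales
w = dequant_with_group_shape(w_q, scales, (128, 128))
assert w.shape == w_q.shape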
@@ -198,18 +310,35 @@ def process_weights_after_loading(self):
         W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        q_proj = self.q_proj.weight.T\
+        q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
             .view(-1, self.num_heads, self.qk_head_dim)

         # can be W_Q or W_UQ depending on q_lora_rank, the former if
         # q_lora_rank is None, the latter otherwise. From the Attention backend
         # perspective though we call these both W_Q and rely on the layer
         # to pass in the correct matrix
-        W_Q = q_proj[..., :self.qk_nope_head_dim]
-        self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
+        W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
+        self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
             .flatten(start_dim=1).contiguous()

+        # W_QR is small so for simplicity we don't bother requantizing it
+        self.W_QR = self.W_QR.to(act_dtype)
+
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
+            requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
+            if is_fp8(weight_dtype) and requantization_enabled:
+                # This assumes it is wise to requantize using the same group
+                # shapes (i.e. strategy: per-tensor, per-channel, block, etc.)
+                # that the weights were originally quantized with
+                requant_input_group_shape, requant_weight_group_shape = \
+                    get_scale_group_shapes_for_fp8(self.q_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.kv_b_proj)
+                assert (requant_input_group_shape, requant_weight_group_shape)\
+                    == get_scale_group_shapes_for_fp8(self.o_proj)
+                self.reqaunt_input_group_shape = requant_input_group_shape
+                self.reqaunt_weight_group_shape = requant_weight_group_shape
+
             #
             # Perform matrix-absorption following
             # https://github.com/flashinfer-ai/flashinfer/pull/551

@@ -223,25 +352,44 @@ def process_weights_after_loading(self):
             # latter otherwise
             # basically if q_lora_rank is none we are absorbing into q_proj
             # instead of UQ
-            self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
+            W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
                 .flatten(start_dim=1).contiguous()

-            W_O = self.o_proj.weight\
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_Q_UK, W_Q_UK_scales = scaled_quantize(
+                    W_Q_UK,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_Q_UK = W_Q_UK.T.contiguous()
+                self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
+            else:
+                self.W_Q_UK = W_Q_UK.to(act_dtype)
+
+            W_O = get_and_maybe_dequant_weights(self.o_proj)\
                 .view(-1, self.num_heads, self.v_head_dim)
-            self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
+            W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
                 .flatten(start_dim=0, end_dim=1).contiguous()

-            tp_size = get_tensor_model_parallel_world_size()
-            self.o_proj_absorbed = RowParallelLinear(
-                self.W_UV_O.shape[0] * tp_size,
-                self.W_UV_O.shape[1],
-                bias=False,
-                # TODO(lucas) figure out how to properly forward quant_method
-                #quant_config=self.o_proj.quant_method,
[Review comment] Is it too onerous to construct a quant method here? (i.e. should we try to make this easier in the future?)

[Author reply] Yeah, because you'd have to make a quant_config and so on, and all the weight names are different; for per-channel we enter the compressed-tensors world, which is another can of worms :/ Agreed, we should try to make it easier in the future.
-            )
-
-            self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
+            if is_fp8(weight_dtype) and requantization_enabled:
+                W_UV_O, W_UV_O_scales = scaled_quantize(
+                    W_UV_O,
+                    self.reqaunt_weight_group_shape,
+                    quant_dtype=current_platform_fp8_dtype)
+                # For FP8 save the transpose so we can use
+                # `apply_w8a8_block_fp8_linear` directly
+                self.W_UV_O = W_UV_O.T.contiguous()
+                self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
+            else:
+                self.W_UV_O = W_UV_O.to(act_dtype)
+
+            self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            if is_fp8(weight_dtype):
+                raise NotImplementedError(
+                    "Currently fp8 requires matrix absorption")
+
             self.W_UV = W_UV
             self.W_UK = W_UK
             self.W_Q = W_Q.flatten(start_dim=1)

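To make the requantization step concrete: the sketch below is a toy blockwise FP8 quantizer in the spirit of what `scaled_quantize` is asked to do for a `(128, 128)` weight group shape, i.e. one scale per tile derived from the tile's absmax. The function name and details are assumptions, not vLLM's implementation (which also covers the non-block strategies); the final transpose mirrors how `W_Q_UK` / `W_UV_O` are stored pre-transposed above for the FP8 GEMM.

import torch
import torch.nn.functional as F

def blockwise_quantize_fp8(w: torch.Tensor, block=(128, 128)):
    """Toy sketch: quantize w to float8_e4m3fn with one scale per
    (block[0] x block[1]) tile. Not vLLM's scaled_quantize."""
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    bm, bn = block
    M, N = w.shape
    # Pad so the matrix tiles evenly, then view it as (tm, bm, tn, bn).
    w_p = F.pad(w, (0, (-N) % bn, 0, (-M) % bm))
    tiles = w_p.view(w_p.shape[0] // bm, bm, w_p.shape[1] // bn, bn)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scales = amax / fp8_max                       # one scale per tile
    w_q = (tiles / scales).to(fp8).view(w_p.shape)[:M, :N]
    return w_q, scales[:, 0, :, 0]                # fp8 weight, (tm, tn) scales

W_fused = torch.randn(500, 300)   # stand-in for an absorbed weight matrix
w_q, w_s = blockwise_quantize_fp8(W_fused)
# Store pre-transposed, mirroring W_Q_UK / W_UV_O above.
w_q_t, w_s_t = w_q.T.contiguous(), w_s.T.contiguous()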
[Review comment] Fp8LinearMethod.block_quant is a boolean, is there meant to be a check for False instead?

[Author reply] Yes, this is a bug; I fixed it here: #13181
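For reference, the flagged condition is `if layer.quant_method.block_quant is not None:` inside `get_scale_group_shapes_for_fp8` above. Since `block_quant` is a plain bool, that test is always true, so non-block FP8 checkpoints would be routed down the block-quantized branch. The follow-up fix presumably amounts to a truthiness check, roughly:

             if isinstance(layer.quant_method, Fp8LinearMethod):
-                if layer.quant_method.block_quant is not None:
+                if layer.quant_method.block_quant:
                     weight_block_size = \
                         layer.quant_method.quant_config.weight_block_size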