
deepseek v3 support
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson committed Jan 31, 2025
1 parent 645622c commit 0ccbcce
Showing 5 changed files with 305 additions and 43 deletions.
94 changes: 74 additions & 20 deletions vllm/attention/backends/mla/utils.py
@@ -9,9 +9,13 @@
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
LinearBase, RowParallelLinear)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear, block_quantize, is_fp8, scaled_dequant)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -162,15 +166,35 @@ def __init__(

def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
return self.o_proj_absorbed(
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
if is_fp8(self.W_UV_O):
output_parallel = apply_w8a8_block_fp8_linear(
x.flatten(start_dim=1),
self.W_UV_O,
[128, 128],
self.W_UV_O_scales,
)
else:
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O)
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]

def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_w8a8_block_fp8_linear(
x,
self.W_Q_UK,
[128, 128],
self.W_Q_UK_scales,
).view(-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
@@ -180,7 +204,24 @@ def _q_proj_and_k_up_proj(self, x):
.view(-1, self.num_heads, self.kv_lora_rank)

def process_weights_after_loading(self):
kv_b_proj_weight = self.kv_b_proj.weight.T

def get_and_maybe_dequant_weights(layer: LinearBase):
if isinstance(layer.quant_method, Fp8LinearMethod):
# TODO(lucas) support non block quantized
assert hasattr(layer, "weight_scale_inv") and \
layer.quant_method.block_quant is not None
return scaled_dequant(
layer.weight, layer.weight_scale_inv,
layer.quant_method.quant_config.weight_block_size)\
.to(torch.bfloat16)
else:
return layer.weight

weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype

kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -198,15 +239,15 @@ def process_weights_after_loading(self):
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)

q_proj = self.q_proj.weight.T\
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)

# can be W_Q or W_UQ depending on q_lora_rank: the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend's
# perspective, though, we call both of these W_Q and rely on the layer
# to pass in the correct matrix.
W_Q = q_proj[..., :self.qk_nope_head_dim]
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()

if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -223,25 +264,38 @@
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()

W_O = self.o_proj.weight\
if is_fp8(weight_dtype):
W_Q_UK, W_Q_UK_scales = block_quantize(W_Q_UK, (128, 128))
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK

W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()

tp_size = get_tensor_model_parallel_world_size()
self.o_proj_absorbed = RowParallelLinear(
self.W_UV_O.shape[0] * tp_size,
self.W_UV_O.shape[1],
bias=False,
# TODO(lucas) figure out how to properly forward quant_method
#quant_config=self.o_proj.quant_method,
)
if is_fp8(weight_dtype):
W_UV_O, W_UV_O_scales = block_quantize(W_UV_O, (128, 128))
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O

self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")

self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
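
As a side note, here is a minimal standalone sketch of the matrix-absorption identity the code above relies on (plain torch, hypothetical sizes, independent of the vLLM classes): pre-folding W_UK into W_Q and doing a single matmul per step should match projecting with W_Q and then contracting with W_UK per head.

import torch

B, Din = 4, 64              # tokens, q input dim (hypothetical sizes)
H, D, L = 8, 16, 32         # num_heads, qk_nope_head_dim, kv_lora_rank

x = torch.randn(B, Din)
W_Q = torch.randn(Din, H, D)    # per-head query projection ("nope" part)
W_UK = torch.randn(L, H, D)     # per-head up-projection of the latent kv cache

# Two-step path: project to per-head queries, then contract with W_UK.
q_nope = torch.einsum("bq,qnd->bnd", x, W_Q)
out_two_step = torch.einsum("bnd,lnd->bnl", q_nope, W_UK)

# Absorbed path: fold W_UK into W_Q once, then one matmul at runtime.
W_Q_UK = torch.einsum("qnd,lnd->qnl", W_Q, W_UK).flatten(start_dim=1)
out_absorbed = torch.matmul(x, W_Q_UK).view(B, H, L)

assert torch.allclose(out_two_step, out_absorbed, atol=1e-4)

The same identity is what lets W_UV and W_O collapse into W_UV_O on the output side; under FP8 the commit additionally block-quantizes the absorbed matrices and stores them transposed for apply_w8a8_block_fp8_linear, which this sketch does not cover.
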
74 changes: 54 additions & 20 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -1,13 +1,21 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
from vllm.platforms import current_platform


def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
if isinstance(x, torch.Tensor):
x = x.dtype
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -51,39 +59,65 @@ def input_to_float8(
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


def scaled_dequant(
x_q: torch.Tensor,
x_s: torch.Tensor,
block_size: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if block_size is not None:
assert x_s.shape[-1] == x_q.shape[-1] // block_size[1]
assert x_s.shape[-2] == x_q.shape[-2] // block_size[0]

x_s = group_broadcast(x_s, x_q.shape)
return x_q.to(torch.float32) * x_s


def block_quant_to_tensor_quant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise
quantization. The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
The outputs are tensor-wise quantization tensor and tensor-wise
quantization scale. Note only float8 is supported for now.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = scaled_dequant(x_q_block, x_s)
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale

x_dq_block = x_q_block.to(torch.float32)

x_dq_block_tiles = [[
x_dq_block[
j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k),
] for i in range(k_tiles)
] for j in range(n_tiles)]
def block_quantize(
x: torch.Tensor,
block_size: Tuple[int, int],
dtype: Optional[torch.dtype] = None,
):
if dtype is None:
dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else torch.float8_e4m3fn)
finfo = torch.finfo(dtype)

for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
assert x.ndim == 2
assert x.shape[0] % block_size[0] == 0 and x.shape[1] % block_size[1] == 0
blk_m, blk_n = x.shape[0] // block_size[0], x.shape[1] // block_size[1]
x_blkd = x.reshape(blk_m, block_size[0], blk_n, block_size[1])
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
x_blkd_permd = x_blkd_permd.flatten(start_dim=2)
min_val, max_val = x_blkd_permd.aminmax(dim=-1)
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
# Apply scale and convert from:
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
.clamp(min=finfo.min, max=finfo.max)\
.reshape(blk_m, blk_n, block_size[0], block_size[1])\
.permute(0, 2, 1, 3)\
.reshape(x.shape)

x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


@triton.jit
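
A hedged usage sketch of the helpers added above (block_quantize, scaled_dequant, is_fp8). It assumes this commit's vLLM is importable and that the tensor dimensions are exact multiples of the 128x128 block size, as the asserts in block_quantize require; block_quantize returns the dequantization scale (amax / fp8_max), which is what scaled_dequant multiplies by.

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    block_quantize, is_fp8, scaled_dequant)

w = torch.randn(256, 512, dtype=torch.bfloat16)

# Quantize to fp8 in 128x128 blocks; one scale per block.
w_q, w_s = block_quantize(w, (128, 128))
assert is_fp8(w_q)
assert w_s.shape == (256 // 128, 512 // 128)

# Dequantize with the per-block scales for a rough round-trip check.
w_dq = scaled_dequant(w_q, w_s, block_size=(128, 128))
print((w.float() - w_dq).abs().max())   # small, fp8-sized quantization error

In the attention backend above, the quantized weights and their scales are then stored transposed so that apply_w8a8_block_fp8_linear can consume them directly.
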
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -452,3 +452,27 @@ def awq_pack(
q_w = q_w.reshape((-1, size_n)).contiguous()

return pack_cols(q_w, num_bits, size_k, size_n)


# We treat N-dimensional group scaling as extended numpy-style broadcasting.
# In numpy, broadcasting stretches dimensions with an extent of 1 to match
# the target shape by repeating the data along that dimension. We extend
# these semantics: if the extent of a dimension in the source shape is not 1
# and does not match the target shape, we repeat each element along that
# dimension target_shape[dim] // src_shape[dim] times.
# Example: if we have
#   a = [[1, 2],    and target_shape = (2, 4)
#        [3, 4]]
# then we would expand a to:
#   a = [[1, 1, 2, 2],
#        [3, 3, 4, 4]]
# NOTE: this function does not explicitly broadcast dimensions with an
# extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t, shape):
for i, s in enumerate(shape):
if t.shape[i] != s and t.shape[i] != 1:
assert s % t.shape[i] == 0
t = t.unsqueeze(i + 1)\
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
.flatten(i, i + 1)
return t
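
A quick standalone check of the group-broadcast semantics described in the comment above, reproducing its 2x2 -> (2, 4) example (assumes this commit's vLLM is importable):

import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    group_broadcast)

a = torch.tensor([[1., 2.],
                  [3., 4.]])
out = group_broadcast(a, (2, 4))
# Each element along dim 1 is repeated 4 // 2 = 2 times:
# tensor([[1., 1., 2., 2.],
#         [3., 3., 4., 4.]])
print(out)
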
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -398,7 +398,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
elif isinstance(module, Attention) and \
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after loading,
# currently only used by MLA

