Skip to content

Commit

Permalink
working but messy
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
  • Loading branch information
LucasWilkinson committed Jan 31, 2025
1 parent 548ec44 commit 3d12a04
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 51 deletions.
147 changes: 117 additions & 30 deletions vllm/attention/backends/mla/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional
from typing import Any, Dict, Generic, List, Optional, Tuple

import torch
from compressed_tensors.quantization import QuantizationStrategy

from vllm import _custom_ops as ops
from vllm import envs
Expand All @@ -12,10 +13,15 @@
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear)
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear, block_quantize, is_fp8, scaled_dequant)
apply_fp8_linear_generic, fp8_quantize, is_fp8, scaled_dequant)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func

Expand Down Expand Up @@ -167,12 +173,10 @@ def __init__(
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O):
output_parallel = apply_w8a8_block_fp8_linear(
x.flatten(start_dim=1),
self.W_UV_O,
[128, 128],
self.W_UV_O_scales,
)
output_parallel = apply_fp8_linear_generic(
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape)
else:
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O)
Expand All @@ -189,12 +193,11 @@ def _v_up_proj_and_o_proj(self, x):
def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_w8a8_block_fp8_linear(
x,
self.W_Q_UK,
[128, 128],
self.W_Q_UK_scales,
).view(-1, self.num_heads, self.kv_lora_rank)
return apply_fp8_linear_generic(
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
Expand All @@ -203,20 +206,85 @@ def _q_proj_and_k_up_proj(self, x):
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)

def process_weights_after_loading(self):
def process_weights_after_loading(self, act_dtype: torch.dtype):

def get_and_maybe_dequant_weights(layer: LinearBase):
def is_layer_fp8(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, Fp8LinearMethod) or\
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))

def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)

# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
# TODO(lucas) support non block quantized
assert hasattr(layer, "weight_scale_inv") and \
layer.quant_method.block_quant is not None
return scaled_dequant(
layer.weight, layer.weight_scale_inv,
layer.quant_method.quant_config.weight_block_size)\
.to(torch.bfloat16)
if layer.quant_method.block_quant is not None:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)

def get_scales(layer: LinearBase) -> torch.Tensor:
if hasattr(layer, "weight_scale_inv"):
return layer.weight_scale_inv
return layer.weight_scale

def get_and_maybe_dequant_weights(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)

return scaled_dequant(weight, scales, weight_scale_group_shape)
else:
return layer.weight

if not (quantization_scheme_supported(self.kv_b_proj) and\
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")

weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
Expand Down Expand Up @@ -250,7 +318,24 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()

# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)

if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape

#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
Expand All @@ -267,28 +352,30 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()

if is_fp8(weight_dtype):
W_Q_UK, W_Q_UK_scales = block_quantize(W_Q_UK, (128, 128))
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = fp8_quantize(
W_Q_UK, self.reqaunt_weight_group_shape)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK
self.W_Q_UK = W_Q_UK.to(act_dtype)

W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()

if is_fp8(weight_dtype):
W_UV_O, W_UV_O_scales = block_quantize(W_UV_O, (128, 128))
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = fp8_quantize(
W_UV_O, self.reqaunt_weight_group_shape)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O
self.W_UV_O = W_UV_O.to(act_dtype)

self.tp_size = get_tensor_model_parallel_world_size()
else:
Expand Down
4 changes: 2 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def extra_repr(self) -> str:
s += f", backend={self.impl.__class__.__name__}"
return s

def process_weights_after_loading(self):
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading()
self.impl.process_weights_after_loading(act_dtype)


class MultiHeadAttention(nn.Module):
Expand Down
12 changes: 11 additions & 1 deletion vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -519,7 +520,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
# the is enabled by default
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),

# When running MLA with matrix-absorption enabled and fp8 quantized weights
# we perform the matrix-absorption in float32 precision, after the matrices
# are absorbed we requantize the weights back to fp8, this flag can be used
# to disable the requantization step, and instead convert the absorbed
# matrices to match the activation type. This can lead to higher memory and
# compute usage but better preserves the accuracy of the original model.
"VLLM_MLA_DISABLE_REQUANTIZATION":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0")))
}

# end-env-vars-definition
Expand Down
Loading

0 comments on commit 3d12a04

Please sign in to comment.