[Misc] Support FP8 kv cache scales from compressed-tensors #6528

Merged: 7 commits, Jul 23, 2024
7 changes: 7 additions & 0 deletions tests/quantization/test_compressed_tensors.py
@@ -150,3 +150,10 @@ def test_compressed_tensors_fp8(vllm_runner):

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output


def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output
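
For orientation, here is a hedged offline-inference sketch mirroring what the new test exercises; the LLM/SamplingParams usage is assumed from vLLM's public Python API and is not part of this diff.

# Sketch only: load the same compressed-tensors checkpoint used by the test
# above and enable the fp8 kv cache so its k/v scales are applied.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme",
          kv_cache_dtype="fp8")
params = SamplingParams(temperature=0.0, max_tokens=20)
outputs = llm.generate(["Hello world!"], params)
print(outputs[0].outputs[0].text)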
23 changes: 11 additions & 12 deletions vllm/attention/layer.py
@@ -9,7 +9,7 @@
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod


class Attention(nn.Module):
@@ -59,19 +59,18 @@ def __init__(
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
assert isinstance(quant_method, Fp8KVCacheMethod)
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if "fp8" in self.kv_cache_dtype:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The k/v_scale will then be converted back to
# self._kv_scale in a native float32 value after weight loading
self.quant_method = quant_method
self.quant_method.create_weights(self)
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that they can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
@@ -5,7 +5,7 @@

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsUnquantized,
@@ -15,18 +15,23 @@
CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform


class CompressedTensorsConfig(QuantizationConfig):

def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
quant_format: str):
def __init__(self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None):

self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme

def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
Expand All @@ -50,9 +55,12 @@ def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["CompressedTensorsLinearMethod"]:
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
return None

@classmethod
@@ -85,7 +93,8 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":

return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format)
quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))

@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -309,3 +318,47 @@ def apply(self,
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)


class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from compressed-tensors
checkpoints.
"""

def __init__(self, quant_config: CompressedTensorsConfig):
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
super().__init__(quant_config)

@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
"""
Validator for the kv cache scheme. Useful for controlling which
kv cache quantization schemes are supported in vLLM.
:param kv_cache_scheme: the compressed-tensors kv cache scheme
"""
if kv_cache_scheme is None:
return

type_ = kv_cache_scheme.get("type")
num_bits = kv_cache_scheme.get("num_bits")

if type_ != "float" and num_bits != 8:
raise NotImplementedError(
"Currently supported kv cache quantization is "
"num_bits=8, type=float, however "
f"received num_bits={num_bits}, type={type_}")

strategy = kv_cache_scheme.get("strategy")
if strategy != "tensor":
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f"Expected strategy: tensor, found strategy: {strategy}")

is_symmetric = kv_cache_scheme.get("symmetric")
if not is_symmetric:
raise NotImplementedError(
"Only support symmetric scaling factor "
"for compressed-tensors KV cache. "
f"However found symmetric: {is_symmetric}")
@@ -209,6 +209,23 @@ def _find_first_match(value: str,
return None


def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM

:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None


def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
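
A small illustration (not part of the diff) of the rename performed by get_compressed_tensors_cache_scale above; the Llama-style parameter names are assumptions chosen for the example, and the import path matches the one used in llama.py below.

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)

# k_proj/v_proj output scales are remapped onto the Attention module's
# k_scale/v_scale parameters; anything else falls through to None.
print(get_compressed_tensors_cache_scale(
    "model.layers.0.self_attn.k_proj.output_scale"))
# model.layers.0.self_attn.attn.k_scale
print(get_compressed_tensors_cache_scale(
    "model.layers.0.self_attn.q_proj.weight"))
# None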
63 changes: 5 additions & 58 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -11,6 +11,7 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -400,64 +401,10 @@ def apply(self,
topk_group=topk_group)


class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka k_scale and v_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scales to -1.0, which is an invalid value.
# If the k/v_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = Parameter(torch.tensor(1.0), requires_grad=False)
v_scale = Parameter(torch.tensor(1.0), requires_grad=False)
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()

if not isinstance(k_scale, float) or not isinstance(
v_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")

# These are used in the final Attention.forward()
layer._k_scale = k_scale
layer._v_scale = v_scale
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype):
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")

del layer.k_scale
del layer.v_scale
super().__init__(quant_config)
78 changes: 78 additions & 0 deletions vllm/model_executor/layers/quantization/kv_cache.py
@@ -0,0 +1,78 @@
import torch

from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.utils import print_warning_once


class BaseKVCacheMethod(QuantizeMethodBase):
"""
Quant method that adds `_k_scale` and `_v_scale` attributes to the
Attention layer to support loading those scaling factors from checkpoints.
The k/v_scale will be used to:
- quantize k/v_cache entries before saving them to the cache
- dequantize k/v_cache entries after fetching them from the cache
:param quant_config: the appropriate QuantizationConfig
"""

def __init__(self, quant_config: QuantizationConfig):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module):
"""
Create "weight" (aka k_scale and v_scale) for an attention layer.
"""
# Initialize the KV cache scales to -1.0, which is an invalid value.
# If the k/v_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(
f"{self.__class__.__name__}.apply should not be called.")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless of whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = torch.nn.Parameter(torch.tensor(1.0),
requires_grad=False)
v_scale = torch.nn.Parameter(torch.tensor(1.0),
requires_grad=False)
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()

if not isinstance(k_scale, float) or not isinstance(
v_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")

# These are used in the final Attention.forward()
layer._k_scale = k_scale
layer._v_scale = v_scale
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype):
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")

del layer.k_scale
del layer.v_scale
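
A minimal sketch (not part of the diff) of how BaseKVCacheMethod resolves checkpoint scales after weight loading; the DummyAttention stand-in, the None quant_config, and the hard-coded scale values are invented for illustration, while the real wiring goes through Attention.quant_method in vllm/attention/layer.py.

import torch

from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod


class DummyAttention(torch.nn.Module):
    """Stand-in carrying only the attributes BaseKVCacheMethod reads."""
    kv_cache_dtype = "fp8_e4m3"


layer = DummyAttention()
method = BaseKVCacheMethod(quant_config=None)  # quant_config unused here

# create_weights registers the -1.0 "not loaded" sentinel scales.
method.create_weights(layer)

# Pretend the checkpoint supplied per-tensor k/v scales.
layer.k_scale = torch.nn.Parameter(torch.tensor(0.5), requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(0.25), requires_grad=False)

# The scales are materialized as plain Python floats on the layer.
method.process_weights_after_loading(layer)
print(layer._k_scale, layer._v_scale)  # 0.5 0.25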
10 changes: 10 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -39,6 +39,8 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -467,6 +469,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue