Refactor quant param schema extraction + loading, address reviewer comments, bring in NCCL fixes from upstream
mawong-amd committed Apr 2, 2024
1 parent ef8d9fb commit 14d55ec
Showing 10 changed files with 125 additions and 116 deletions.
4 changes: 2 additions & 2 deletions docs/source/quantization/fp8_e4m3_kvcache.rst
@@ -34,12 +34,12 @@ Here is an example of how to enable this feature:
# https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own.
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=1.2, top_p=0.9)
sampling_params = SamplingParams(temperature=1.3, top_p=0.8)
# llm=LLM(model="meta-llama/Llama-2-7b-chat-hf", # if local:
llm = LLM(model="/data/models/llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
prompt = "Barcelona is the capital of"
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0]
print(out)
33 changes: 17 additions & 16 deletions examples/fp8/extract_scales.py
@@ -8,6 +8,8 @@
import torch
from safetensors.torch import safe_open

from vllm.model_executor.layers.quantization.schema import QuantParamSchema


# Adapted from vllm/model_executor/weight_utils.py
# The main differences are that we add the NPZ format and simplify
@@ -80,7 +82,7 @@ def _hf_tensorfile_iterator(filename: str, load_format: str,
yield name, param
elif use_safetensors:
with safe_open(filename, framework="pt") as f:
for name in f:
for name in f.keys(): # NOQA: SIM118
param = f.get_tensor(name)
yield name, param
else:
@@ -287,22 +289,21 @@ def main(args):
# Postprocess: formatting to the current schema. Consider pulling it
# out into a dedicated function should it ever become more complicated.
rank_scales_map = {
rank_keyword + str(rank): {k: scale[k]
for k in sorted(scale.keys())}
rank: {k: scale[k]
for k in sorted(scale.keys())}
for rank, scale in rank_scales_map.items()
}

# Consider generalizing and formalizing this into its own class
# (and other necessary subclasses) in the future
schema = { "model_type": recovered_metadata["model_type"],
"kv_cache": {
"dtype": "float8_e4m3fn" if len(rank_scales_map) > 0 \
else recovered_metadata["model_dtype"],
"scaling_factor": rank_scales_map
},
# TODO: Expand this with activation and weights scaling
# factors when they are used in the future
}
# TODO: Expand this with activation and weights scaling factors when
# they are used in the future
schema = QuantParamSchema(
model_type=recovered_metadata["model_type"],
kv_cache={
"dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
recovered_metadata["model_dtype"]),
"scaling_factor":
rank_scales_map
},
)

if args.output_dir is None:
output_file = os.path.join(args.quantized_model, args.output_name)
@@ -312,7 +313,7 @@ def main(args):
output_file = os.path.join(args.output_dir, args.output_name)

with open(output_file, 'w') as f:
json.dump(schema, f, indent=4)
f.write(schema.model_dump_json(indent=4))
print(f"Completed! KV cache scaling factors saved to {output_file}")


2 changes: 1 addition & 1 deletion tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json
@@ -3,7 +3,7 @@
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"rank0": {
"0": {
"0": 0.0230364128947258,
"1": 0.01979283057153225,
"2": 0.0241350457072258,
2 changes: 1 addition & 1 deletion tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json
@@ -3,7 +3,7 @@
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"rank0": {
"0": {
"0": 0.0152239128947258,
"1": 0.0188860222697258,
"2": 0.0354178324341774,
2 changes: 1 addition & 1 deletion vllm/attention/backends/xformers.py
@@ -411,7 +411,7 @@ def _check_use_naive_attention() -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
has_flash_attn = importlib.util.find_spec("flash_attn") is None
has_flash_attn = importlib.util.find_spec("flash_attn") is not None
if not has_flash_attn:
logger.warning("flash_attn is not installed. Using naive attention. "
"This will take significantly more GPU memory.")
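For reference, importlib.util.find_spec returns a ModuleSpec when the named module is importable and None when it is not, so the presence check must be "is not None", as in the corrected line above. A minimal sketch (the helper name is illustrative, not part of the codebase):

    import importlib.util

    def module_available(name: str) -> bool:
        # find_spec returns a ModuleSpec if the module can be imported,
        # or None if it cannot be found in the current environment.
        return importlib.util.find_spec(name) is not None

    print(module_available("json"))        # True: always present in the stdlib
    print(module_available("flash_attn"))  # True only if flash-attn is installed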
84 changes: 84 additions & 0 deletions vllm/model_executor/layers/quantization/schema.py
@@ -0,0 +1,84 @@
"""
This file contains the Pydantic schemas for various quantization-related
parameters. When a relevant quantization technique is specified, these
parameters are loaded in the form of a JSON alongside the model weights
and augment the model with additional information needed for use of that
technique. The format of this JSON should be specified by one or more
schemas contained here.
For example, when the KV cache is quantized to FP8-E4M3 (currently only
possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""

from typing import Dict, Optional

from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator


class KVCacheQuantSchema(BaseModel):
dtype: str
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[int, Dict[int, float]]

@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":
assert self.dtype == "float8_e4m3fn", (
"Loaded scaling factors intended for KV cache dtype = "
f"{self.dtype} rather than float8_e4m3fn!")
return self

@model_validator(mode="after")
def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
context = info.context
if context:
tp_size = context["tp_size"]
num_hidden_layers = context["num_hidden_layers"]
assert len(self.scaling_factor) == tp_size, (
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
f"but LLM engine is currently running with TP size {tp_size}.")
for tp_rank, layer_maps in self.scaling_factor.items():
assert len(layer_maps) == num_hidden_layers, (
f"KV cache scales map for TP rank {tp_rank} is malformed. "
f"Expected {num_hidden_layers} layers, got "
f"{len(layer_maps)}.")
for i in range(tp_size):
assert i in self.scaling_factor, (
f"KV cache scales map for TP rank {i} not found.")
return self

@model_validator(mode="after")
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
context = info.context
if context:
tp_rank = context["tp_rank"]
num_hidden_layers = context["num_hidden_layers"]
layer_scales_map = self.scaling_factor[tp_rank]
for i in range(num_hidden_layers):
assert i in layer_scales_map, (
f"Could not find KV cache scales for layer {i} in "
f"TP rank {tp_rank}.")
return self


class QuantParamSchema(BaseModel):
# TODO: Generalize and extend with more fields
# (e.g. weights/activations params) once functionality is enabled
model_config = ConfigDict(protected_namespaces=())
model_type: Optional[str]
kv_cache: KVCacheQuantSchema

@model_validator(mode="after")
def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
context = info.context
if context:
model_type = context.get("model_type", None)
if model_type is not None:
assert model_type == self.model_type, (
f"Model type is {model_type} but loaded "
f"scaling factors belonging to different "
f"model type {self.model_type}!")
return self
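To show how these schemas are consumed, here is a minimal sketch of validating a scales dictionary against QuantParamSchema. The JSON payload and the context values are made up for illustration; the context keys mirror the ones passed by kv_cache_scales_loader in weight_utils.py below.

    import json

    from vllm.model_executor.layers.quantization.schema import QuantParamSchema

    # Hypothetical payload in the format written by examples/fp8/extract_scales.py.
    raw = json.loads("""
    {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            "scaling_factor": {"0": {"0": 0.015, "1": 0.019}}
        }
    }
    """)

    schema = QuantParamSchema.model_validate(
        raw,
        context={
            "model_type": "llama",
            "tp_rank": 0,
            "tp_size": 1,
            "num_hidden_layers": 2,
        },
    )
    # Pydantic coerces the JSON string keys to ints per the
    # Dict[int, Dict[int, float]] annotation on scaling_factor.
    print(schema.kv_cache.scaling_factor[0])  # {0: 0.015, 1: 0.019}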
3 changes: 1 addition & 2 deletions vllm/model_executor/models/llama.py
@@ -415,7 +415,6 @@ def load_weights(self,
default_weight_loader)
weight_loader(param, loaded_weight)

# Should not be called unless the KV cache dtype is FP8 on ROCm (AMD GPU)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
@@ -437,5 +436,5 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
else:
raise RuntimeError("PagedAttention has no KV cache scaling "
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
4 changes: 2 additions & 2 deletions vllm/model_executor/parallel_utils/pynccl.py
@@ -39,9 +39,9 @@
f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
else:
if torch.version.cuda is not None:
so_file = "libnccl.so.2"
so_file = "libnccl.so"
elif torch.version.hip is not None:
so_file = "librccl.so.2"
so_file = "librccl.so"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.debug(f"Loading nccl from library {so_file}")
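The change above only affects which soname is handed to the dynamic loader. As an illustrative sketch of that resolution step, not the logic used by pynccl.py itself, a fallback across common sonames could look like this:

    import ctypes

    def load_comm_library(candidates=("libnccl.so", "libnccl.so.2", "librccl.so")):
        # Try each candidate soname and return the first one that dlopen succeeds on.
        for name in candidates:
            try:
                return ctypes.CDLL(name)
            except OSError:
                continue
        raise OSError("Could not locate an NCCL/RCCL shared library.")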
102 changes: 13 additions & 89 deletions vllm/model_executor/weight_utils.py
@@ -5,20 +5,20 @@
import json
import os
from collections import defaultdict
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Iterable, Iterator, List, Optional, Tuple

import filelock
import numpy as np
import torch
from huggingface_hub import HfFileSystem, snapshot_download
from pydantic import BaseModel, ConfigDict, computed_field, model_validator
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema

logger = init_logger(__name__)

@@ -285,96 +285,20 @@ def kv_cache_scales_loader(
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
Keep this function in sync with the output of
3rdparty/quantization/extract_scales.pyc
Keep this function in sync with the output of examples/fp8/extract_scales.py
"""

# TODO: Once a quantization params format is finalized, pull these Pydantic
# schemas out into a separate file that deals solely with quantization
# params and their related schemas so that they can be generalized and
# shared across various use cases.
class KVCacheQuantSchema(BaseModel):
dtype: str
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[str, Dict[int, float]]

@computed_field
@property
def rank_keyword(self) -> str:
# Each TP rank key should be prefixed by a common rank_keyword.
# Thus, recovering the alphabetical components of any key should
# return it.
rank_keyword = "".join(
char for char in next(iter(self.scaling_factor.keys()))
if char.isalpha())
return rank_keyword

@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":
assert self.dtype == "float8_e4m3fn", (
"Loaded scaling factors intended for KV cache dtype = "
f"{self.dtype} rather than float8_e4m3fn!")
return self

@model_validator(mode="after")
def check_tp_ranks(self) -> "KVCacheQuantSchema":
assert len(self.scaling_factor) == tp_size, (
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
f"but LLM engine is currently running with TP size {tp_size}.")
for tp_rank, layer_maps in self.scaling_factor.items():
assert tp_rank.startswith(self.rank_keyword), (
f"TP `{tp_rank}` does not start with `{self.rank_keyword}`"
)
assert len(layer_maps) == num_hidden_layers, (
f"KV cache scales map for TP `{tp_rank}` is malformed. "
f"Expected {num_hidden_layers} layers, got "
f"{len(layer_maps)}.")
for i in range(tp_size):
assert f"{self.rank_keyword}{i}" in self.scaling_factor, (
f"KV cache scales map for TP rank {i} not found.")
return self

@model_validator(mode="after")
def check_current_rank(self) -> "KVCacheQuantSchema":
layer_scales_map = self.scaling_factor[
f"{self.rank_keyword}{tp_rank}"]
for i in range(num_hidden_layers):
assert i in layer_scales_map, (
f"Could not find KV cache scales for layer {i} in "
f"TP rank {tp_rank}.")
return self

class QuantParamSchema(BaseModel):
# TODO: Generalize and extend with more fields
# (e.g. weights/activations params) once functionality is enabled
model_config = ConfigDict(protected_namespaces=())
model_type: Optional[str]
kv_cache: KVCacheQuantSchema

@model_validator(mode="after")
def check_model_type(self) -> "QuantParamSchema":
if model_type is not None:
assert model_type == self.model_type, (
f"Model type is {model_type} but loaded "
f"scaling factors belonging to different "
f"model type {self.model_type}!")
return self

try:
with open(filename) as f:
# Loading and processing the entire dictionary at once allows us
# to do sanity checks all at once and avoid a situation where we
# have to abort after having partially loaded scaling factors
# Since the number of layers is small and (for now) we use scalar
# scaling factors (so the size they use is also small), this is
# not a concern at present.
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct)
layer_scales_map = schema.kv_cache.scaling_factor[
f"{schema.kv_cache.rank_keyword}{tp_rank}"]
schema = QuantParamSchema.model_validate(schema_dct,
context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()

except FileNotFoundError:
@@ -389,7 +313,7 @@ def check_model_type(self) -> "QuantParamSchema":
logger.warning("Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading.")
return ()
return []


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
5 changes: 3 additions & 2 deletions vllm/worker/model_runner.py
@@ -23,7 +23,7 @@
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d,
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)

@@ -120,7 +120,8 @@ def load_model(self) -> None:
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)

if self.kv_cache_dtype == "fp8":
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(
