From deedcc95b3138c41d6ef30cea6ecf7847e64e1b7 Mon Sep 17 00:00:00 2001
From: kewang2
Date: Tue, 26 Nov 2024 04:58:48 -0700
Subject: [PATCH] add kv cache remap method for quark format

---
 vllm/model_executor/models/commandr.py    | 12 ++++++++++++
 vllm/model_executor/models/gpt_j.py       | 12 ++++++++++++
 vllm/model_executor/models/mixtral.py     | 14 +++++++++++++-
 vllm/model_executor/models/mllama.py      | 11 +++++++++++
 vllm/model_executor/models/nemotron.py    | 11 +++++++++++
 vllm/model_executor/models/phimoe.py      | 12 ++++++++++++
 vllm/model_executor/models/qwen2.py       | 11 +++++++++++
 vllm/model_executor/models/qwen2_audio.py | 10 ++++++++++
 8 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 85e24ca660686..c919d43faa0a5 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -419,6 +419,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(scale_name)
+                continue
+
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 4829578a56959..8764a6b53bc70 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -313,6 +313,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
         for name, loaded_weight in weights:
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
+
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(scale_name)
+                continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index a5b364fe5ec85..de43139f18be0 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -347,6 +347,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = quant_config
         self.model = MixtralModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
@@ -427,7 +428,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = (loaded_weight if loaded_weight.dim() == 0
+                                     else loaded_weight[0])
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(scale_name)
+                continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 6536f9807730c..d024ac218f128 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1117,6 +1117,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -1430,6 +1431,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 name = name.replace('patch_embedding.weight',
                                     'patch_embedding._linear.weight')
                 loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                    weight_loader(param, loaded_weight)
+                    updated_params.add(scale_name)
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index c7b4c22b6896b..ead0f3005d6e6 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -411,6 +411,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = quant_config
         self.model = NemotronModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
@@ -493,6 +494,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(scale_name)
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index 1febd62f2f705..fe197999367a7 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -546,6 +546,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = vllm_config.quant_config
         self.model = PhiMoEModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
@@ -622,6 +623,17 @@ def load_weights(self, weights: Iterable[Tuple[str,
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(scale_name)
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 87943e53d861c..a153c96afc067 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -279,6 +279,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                              ))
         self.config = config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -364,6 +365,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(scale_name)
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index a0605fee82aca..cd9b7e2b18a9f 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -454,6 +454,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
             if (self.config.text_config.tie_word_embeddings
                     and "lm_head.weight" in name):
                 continue
+            if scale_names := self.quant_config.get_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                for scale_name in scale_names:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(scale_name)
+                continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)