add kv cache remap method for quark format
kewang2 committed Nov 29, 2024
1 parent 3b9d6bd commit deedcc9
Showing 8 changed files with 92 additions and 1 deletion.
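
The hunks below only show the call sites; the remap method named in the commit title lives in the Quark quantization config and is not part of this page. For orientation, here is a minimal sketch of what such a get_cache_scale helper could look like — the method name and the list-valued return are taken from the call sites, while the checkpoint key suffixes and the fused-kv case are assumptions about how Quark exports kv cache scales, not code from this commit.

# Minimal sketch only: the suffixes below are assumed, not taken from this commit.
from typing import List, Optional


class QuarkConfig:  # stand-in for vLLM's Quark quantization config
    def get_cache_scale(self, name: str) -> Optional[List[str]]:
        """Remap a checkpoint scale name onto vLLM attention scale params.

        Returns None for tensors that are not kv cache scales, so callers
        can fall through to their normal weight-loading logic.
        """
        if name.endswith(".k_proj.output_scale"):
            return [name.replace(".k_proj.output_scale", ".attn.k_scale")]
        if name.endswith(".v_proj.output_scale"):
            return [name.replace(".v_proj.output_scale", ".attn.v_scale")]
        if name.endswith(".self_attn.kv_scale"):
            # A fused kv scale maps onto both attention scale parameters,
            # which is why the loaders below iterate over a list of names.
            base = name.replace(".self_attn.kv_scale", ".self_attn.attn")
            return [base + ".k_scale", base + ".v_scale"]
        return None
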
12 changes: 12 additions & 0 deletions vllm/model_executor/models/commandr.py
@@ -419,6 +419,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:

if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
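
In each of these hunks the scale is normalized before being handed to the weight loader: loaded_weight.dim()==0 keeps a 0-d tensor as-is, while a 1-element tensor is reduced to its first element. The standalone snippet below just illustrates that expression with made-up values, on the assumption that Quark checkpoints may serialize a per-tensor scale in either shape.

import torch


def normalize_scale(loaded_weight: torch.Tensor) -> torch.Tensor:
    # Same expression as in the hunks above: keep a 0-d scale as-is,
    # otherwise reduce a 1-element tensor to a scalar.
    return loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]


print(normalize_scale(torch.tensor(0.0123)))    # 0-d tensor, returned unchanged
print(normalize_scale(torch.tensor([0.0123])))  # shape (1,), reduced to 0-d
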
12 changes: 12 additions & 0 deletions vllm/model_executor/models/gpt_j.py
@@ -313,6 +313,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name:
continue

if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
14 changes: 13 additions & 1 deletion vllm/model_executor/models/mixtral.py
@@ -347,6 +347,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config

self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@@ -427,7 +428,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue


if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
11 changes: 11 additions & 0 deletions vllm/model_executor/models/mllama.py
@@ -1117,6 +1117,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles
@@ -1430,6 +1431,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
name = name.replace('patch_embedding.weight',
'patch_embedding._linear.weight')
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
updated_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
11 changes: 11 additions & 0 deletions vllm/model_executor/models/nemotron.py
@@ -411,6 +411,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.config = config
self.lora_config = lora_config
self.quant_config = quant_config

self.model = NemotronModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@@ -493,6 +494,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
12 changes: 12 additions & 0 deletions vllm/model_executor/models/phimoe.py
@@ -546,6 +546,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = vllm_config.quant_config

self.model = PhiMoEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@@ -622,6 +623,17 @@ def load_weights(self, weights: Iterable[Tuple[str,
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue

if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
11 changes: 11 additions & 0 deletions vllm/model_executor/models/qwen2.py
@@ -279,6 +279,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
))

self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

@@ -364,6 +365,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
10 changes: 10 additions & 0 deletions vllm/model_executor/models/qwen2_audio.py
@@ -454,6 +454,16 @@ def load_weights(self, weights: Iterable[Tuple[str,
if (self.config.text_config.tie_word_embeddings
and "lm_head.weight" in name):
continue
if scale_names := self.quant_config.get_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
for scale_name in scale_names:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight if loaded_weight.dim()==0 else loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
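
For context, the k_scale/v_scale parameters loaded above are what the fp8 KV cache path consumes at inference time. A usage sketch, with the caveats that the model path is a placeholder and that passing quantization="quark" assumes a vLLM build in which the Quark method is registered:

from vllm import LLM, SamplingParams

# Placeholder checkpoint path; quantization="quark" is an assumption about
# which quantization methods this build registers.
llm = LLM(
    model="/path/to/quark-quantized-model",
    quantization="quark",
    kv_cache_dtype="fp8",  # consumes the k_scale/v_scale loaded above
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)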
