[Misc] Add BNB support to GLM4-V model #12184

Merged · 8 commits · Jan 19, 2025
15 changes: 11 additions & 4 deletions vllm/model_executor/model_loader/loader.py
@@ -1105,15 +1105,22 @@ def _load_weights(self, model_config: ModelConfig,
                     weight_name,
                     index,
             ) in self.modules_mapping.inverse_packed_mapping.items():
-                shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                 # from being incorrectly identified as being present in
                 # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight'
-                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
+                shard_pos = quant_param_name.find(shard_name)
+                can_correct_rename = (shard_pos > 0) and (
+                    quant_param_name[shard_pos - 1] == ".")
+                # If the quant_param_name is packed, it won't occur in the
+                # param_dict before renaming.
+                new_quant_param_name = quant_param_name.replace(
+                    shard_name, weight_name)
+                need_rename = (quant_param_name not in param_dict) \
+                    and (new_quant_param_name in param_dict)
+                if can_correct_rename and need_rename:
                     shard_index = index
-                    quant_param_name = quant_param_name.replace(
-                        shard_name, weight_name)
+                    quant_param_name = new_quant_param_name
                     break

         # Models like Clip/Siglip may skip some layers in initialization,
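
The interaction of the two new guards is easiest to see in isolation. Below is a minimal, self-contained sketch of the rename logic; `param_dict` and `inverse_packed_mapping` are simplified stand-ins for illustration, not vLLM's real structures:

from typing import Dict, Optional, Set, Tuple

def rename_shard(quant_param_name: str, param_dict: Set[str],
                 inverse_packed_mapping: Dict[str, Tuple[str, int]]
                 ) -> Tuple[str, Optional[int]]:
    """Return (possibly renamed parameter name, shard index or None)."""
    for shard_name, (weight_name, index) in inverse_packed_mapping.items():
        shard_pos = quant_param_name.find(shard_name)
        # Only treat the match as a real module component: 'kv_proj'
        # inside 'qkv_proj' is preceded by 'q', not '.', so it is skipped.
        can_correct_rename = (shard_pos > 0) and (
            quant_param_name[shard_pos - 1] == ".")
        new_name = quant_param_name.replace(shard_name, weight_name)
        # Only rename when the packed name (and not the original one)
        # actually exists among the model's parameters.
        need_rename = (quant_param_name not in param_dict
                       and new_name in param_dict)
        if can_correct_rename and need_rename:
            return new_name, index
    return quant_param_name, None

params = {"vpm.encoder.layers.0.self_attn.qkv_proj.weight"}
mapping = {"kv_proj": ("qkv_proj", 1)}
# Already-packed name: neither guard fires, so the name is left untouched.
print(rename_shard("vpm.encoder.layers.0.self_attn.qkv_proj.weight",
                   params, mapping))
# -> ('vpm.encoder.layers.0.self_attn.qkv_proj.weight', None)
# Genuine shard name: both guards fire and the rename happens.
print(rename_shard("vpm.encoder.layers.0.self_attn.kv_proj.weight",
                   params, mapping))
# -> ('vpm.encoder.layers.0.self_attn.qkv_proj.weight', 1)
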
95 changes: 47 additions & 48 deletions vllm/model_executor/models/chatglm.py
@@ -41,7 +41,7 @@
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -605,9 +605,50 @@ def forward(
             return IntermediateTensors({"hidden_states": hidden_states})
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
+            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "rotary_pos_emb.inv_freq" in name:
+                    continue
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
 
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={".word_embeddings": ""}, )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
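
For context, the `shard_id` passed to `weight_loader` above tells a fused parameter which slice of its output dimension to fill. A rough sketch of the idea, assuming equal-sized shards; vLLM's actual `MergedColumnParallelLinear` loader additionally handles tensor-parallel partitioning:

import torch

def merged_shard_loader(param: torch.Tensor, loaded_weight: torch.Tensor,
                        shard_id: int) -> None:
    # Copy one shard into its slice of the fused output dimension.
    shard_size = loaded_weight.shape[0]
    start = shard_id * shard_size
    param.data[start:start + shard_size].copy_(loaded_weight)

# merged_proj fuses gate_proj (shard 0) and dense_h_to_4h (shard 1):
gate = torch.randn(4, 8)
up = torch.randn(4, 8)
merged = torch.empty(8, 8)
merged_shard_loader(merged, gate, 0)
merged_shard_loader(merged, up, 1)
assert torch.equal(merged[:4], gate) and torch.equal(merged[4:], up)
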
@@ -660,52 +701,9 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
-        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
-            "transformer.vision.linear_proj.merged_proj.weight": {
-                "transformer.vision.linear_proj.gate_proj.weight": None,
-                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
-            }
-        }
-
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            is_weight_to_be_merge = False
-            for _, merged_weight_dict in merged_weights_dict.items():
-                if name in merged_weight_dict:
-                    assert merged_weight_dict[name] is None
-                    merged_weight_dict[name] = loaded_weight
-                    is_weight_to_be_merge = True
-            if is_weight_to_be_merge:
-                continue
-            if "rotary_pos_emb.inv_freq" in name:
-                continue
-            if "word_embeddings" in name:
-                name = name.replace(".word_embeddings", "")
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-
-        for combined_name, merged_weight_dict in merged_weights_dict.items():
-            if combined_name in params_dict:
-                param = params_dict[combined_name]
-                combined_weight = torch.cat(list(merged_weight_dict.values()),
-                                            dim=0)
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, combined_weight)
-                loaded_params.add(combined_name)
-        return loaded_params
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


class ChatGLM(ChatGLMBaseModel):
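
The `AutoWeightsLoader`/`WeightsMapper` pair replaces the hand-written merge loop: the mapper renames checkpoint keys, and the loader recursively dispatches each weight to the submodule (here `ChatGLMModel.load_weights`) that knows how to stack it. A stripped-down sketch of the substring rename only, not vLLM's full implementation:

from typing import Dict, Iterable, Iterator, Tuple
import torch

def apply_substr_mapping(
        weights: Iterable[Tuple[str, torch.Tensor]],
        orig_to_new_substr: Dict[str, str],
) -> Iterator[Tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        for old, new in orig_to_new_substr.items():
            name = name.replace(old, new)
        yield name, tensor

# '.word_embeddings' is stripped so HF checkpoint keys line up with the
# vLLM module tree (the old loop did this with an inline replace):
weights = [("transformer.embedding.word_embeddings.weight", torch.zeros(1))]
for name, _ in apply_substr_mapping(weights, {".word_embeddings": ""}):
    print(name)  # transformer.embedding.weight
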
@@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel):
 
 
 class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
+
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"],
@@ -777,7 +776,7 @@ def __new__(
     ) -> None:
         config = vllm_config.model_config.hf_config
         # Initialize VL
-        if hasattr(config, "visual"):
+        if hasattr(config, "vision_config"):
             return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
         # Initialize LLM
         else:
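
The `__new__` dispatch is a small factory pattern: the wrapper class inspects the HF config at construction time and returns either the VL or the text-only variant; the diff fixes the attribute it checks to `vision_config`, which GLM-4V HF configs actually carry. A hypothetical, minimal sketch of the pattern (class names are illustrative, not vLLM's):

from types import SimpleNamespace

class TextModel:
    def __init__(self, config):
        self.config = config

class VLModel(TextModel):
    pass

class AutoGLM:
    """Factory facade: picks an implementation at construction time."""
    def __new__(cls, config):
        # Configs with a 'vision_config' section get the VL variant;
        # everything else falls through to the text-only model.
        if hasattr(config, "vision_config"):
            return VLModel(config)
        return TextModel(config)

print(type(AutoGLM(SimpleNamespace(vision_config={}))).__name__)  # VLModel
print(type(AutoGLM(SimpleNamespace())).__name__)                  # TextModel
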
3 changes: 2 additions & 1 deletion vllm/model_executor/models/glm4_vision_encoder.py
@@ -42,7 +42,8 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
         torch.Tensor
             Transformed tensor with shape (B, L, D)
         """
-        images = images.to(self.proj.weight.device)
+        images = images.to(device=self.proj.weight.device,
+                           dtype=self.proj.weight.dtype)
         x = self.proj(images)
         x = x.flatten(2).transpose(1, 2)
         cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
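
The added `dtype=` matters once the vision tower runs in reduced precision, as it does under BNB quantization: `images` arrive as float32 pixel values while `self.proj.weight` may be float16/bfloat16, and the conv kernel requires matching dtypes. A small repro sketch of the failure mode and the fix (bfloat16 so it runs on CPU):

import torch

proj = torch.nn.Conv2d(3, 8, kernel_size=14, stride=14).to(torch.bfloat16)
images = torch.rand(1, 3, 224, 224)  # float32 pixel values

# proj(images)  # would raise: conv input dtype must match weight dtype

images = images.to(device=proj.weight.device,
                   dtype=proj.weight.dtype)  # the fix from this diff
x = proj(images).flatten(2).transpose(1, 2)
print(x.shape, x.dtype)  # torch.Size([1, 256, 8]) torch.bfloat16
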