Skip to content

Commit

Permalink
[Model] Refactor and decouple weight loading logic for InternVL2 model (
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored and sfc-gh-mkeralapura committed Aug 12, 2024
1 parent b6e9401 commit f135882
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 55 deletions.
11 changes: 10 additions & 1 deletion vllm/model_executor/models/intern_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional
from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn
Expand All @@ -16,6 +16,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

NORM2FN = {
'rms_norm': RMSNorm,
Expand Down Expand Up @@ -268,3 +269,11 @@ def forward(
encoder_outputs = self.encoder(inputs_embeds=hidden_states)

return encoder_outputs

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
82 changes: 28 additions & 54 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
Expand Down Expand Up @@ -414,58 +415,31 @@ def sample(
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
(".gate_up_proj", ".w1", 0),
(".gate_up_proj", ".w3", 1),
]
params_dict = dict(self.named_parameters())
def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.text_config.tie_word_embeddings \
and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# We only do sharding for language model
# and not vision model for now.
if "vision_embed_tokens" in name and self.vision_embed_tokens:
continue
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config.text_config
kv_groups = (config.num_attention_heads //
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim,
loaded_weight.shape[-1])
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, 'q')
weight_loader(param, wk, 'k')
weight_loader(param, wv, 'v')
continue
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
name = name.split(".")
if prefix == name.pop(0):
name = ".".join(name)
yield name, loaded_weight

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

# load vision encoder
vit_weights = self._filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights)

# load mlp projector
mlp_weights = self._filter_weights(mlp_weights, "mlp1")
mlp_params_dict = dict(self.mlp1.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = self._filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)

0 comments on commit f135882

Please sign in to comment.