From ecc7aa74ddcfaba89901532abd07e18f2f36bff4 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 2 Aug 2024 13:16:09 +0800
Subject: [PATCH 1/2] refactor and decouple weight loading logic for internvl

---
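Reviewer note (below the --- cut line, so it stays out of the commit):
the model-level load_weights() no longer walks one params_dict for the
whole model. It tees the checkpoint stream and routes each slice, by name
prefix, to the sub-module that owns it. A rough standalone sketch of that
routing, with made-up weight names standing in for the real InternVL
checkpoint layout:

    import itertools

    def filter_weights(weights, prefix):
        # Same shape as the _filter_weights helper in this patch: keep
        # entries containing `prefix` and strip it before yielding.
        for name, weight in weights:
            if prefix in name:
                yield name.replace(prefix, ""), weight

    checkpoint = iter([
        ("vision_model.embeddings.class_embedding", "w0"),
        ("mlp1.1.weight", "w1"),
        ("language_model.model.embed_tokens.weight", "w2"),
    ])
    vit, mlp, llm = itertools.tee(checkpoint, 3)
    assert ([n for n, _ in filter_weights(vit, "vision_model.")]
            == ["embeddings.class_embedding"])
    assert [n for n, _ in filter_weights(mlp, "mlp1.")] == ["1.weight"]
    assert ([n for n, _ in filter_weights(llm, "language_model.")]
            == ["model.embed_tokens.weight"])

Each consumer only sees its own slice, so the vision tower, the mlp1
projector, and the language backbone can each keep a plain per-module
loader.
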
 vllm/model_executor/models/intern_vit.py | 11 +++-
 vllm/model_executor/models/internvl.py   | 79 ++++++++----------
 2 files changed, 35 insertions(+), 55 deletions(-)

diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index c6c692deca2e1..54c933e3e4959 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -4,7 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -16,6 +16,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -268,3 +269,11 @@ def forward(
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
         return encoder_outputs
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index eabc283b1efdb..c122d46759541 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -4,6 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
@@ -414,58 +415,28 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-            (".gate_up_proj", ".w1", 0),
-            (".gate_up_proj", ".w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
+    def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                        pattern: str):
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.config.text_config.tie_word_embeddings \
-                and "lm_head.weight" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # We only do sharding for language model
-                # and not vision model for now.
-                if "vision_embed_tokens" in name and self.vision_embed_tokens:
-                    continue
-                if weight_name not in name:
-                    continue
-                param = params_dict[name.replace(weight_name, param_name)]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                if "wqkv" in name:
-                    config = self.config.text_config
-                    kv_groups = (config.num_attention_heads //
-                                 config.num_key_value_heads)
-                    head_dim = config.hidden_size // config.num_attention_heads
-                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
-                                                       head_dim,
-                                                       loaded_weight.shape[-1])
-                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
-                                             dim=1)
-                    wq = wq.reshape(-1, wq.shape[-1])
-                    wk = wk.reshape(-1, wk.shape[-1])
-                    wv = wv.reshape(-1, wv.shape[-1])
-                    weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
-                    continue
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            if pattern in name:
+                name = name.replace(pattern, "")
+                yield name, loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # load vision encoder
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        vit_weights = self._filter_weights(vit_weights, "vision_model.")
+        self.vision_model.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = self._filter_weights(mlp_weights, "mlp1.")
+        mlp_params_dict = dict(self.mlp1.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = self._filter_weights(llm_weights, "language_model.")
+        self.language_model.load_weights(llm_weights)

From 80266a7457daff3b11f886e3ce655a45199215e5 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 2 Aug 2024 21:46:59 +0800
Subject: [PATCH 2/2] rewrite weights filter

---
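Reviewer note (below the --- cut line, so it stays out of the commit):
the v1 filter matched the pattern as a substring anywhere in the name and
stripped it with str.replace(), so a prefix occurring mid-name would be
matched and mangled. The rewrite compares the prefix against the first
dot-separated component only. A minimal comparison, using a hypothetical
weight name (not taken from a real checkpoint):

    def filter_v1(weights, pattern):
        # PATCH 1/2 behaviour: substring match, strips every occurrence.
        for name, weight in weights:
            if pattern in name:
                yield name.replace(pattern, ""), weight

    def filter_v2(weights, prefix):
        # PATCH 2/2 behaviour: exact match on the first dotted component.
        for name, weight in weights:
            parts = name.split(".")
            if prefix == parts.pop(0):
                yield ".".join(parts), weight

    weights = [("language_model.model.mlp1.weight", "w")]
    # v1 wrongly claims this weight for mlp1 and mangles the name:
    assert (list(filter_v1(weights, "mlp1."))
            == [("language_model.model.weight", "w")])
    # v2 correctly leaves it for the language-model pass:
    assert list(filter_v2(weights, "mlp1")) == []
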
 vllm/model_executor/models/internvl.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index c122d46759541..4749251271487 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -416,20 +416,23 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
-                        pattern: str):
+                        prefix: str):
         for name, loaded_weight in weights:
-            if pattern in name:
-                name = name.replace(pattern, "")
+            name = name.split(".")
+            if prefix == name.pop(0):
+                name = ".".join(name)
                 yield name, loaded_weight
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load vision encoder
+        # prepare weight iterators for components
         vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
-        vit_weights = self._filter_weights(vit_weights, "vision_model.")
+
+        # load vision encoder
+        vit_weights = self._filter_weights(vit_weights, "vision_model")
         self.vision_model.load_weights(vit_weights)
 
         # load mlp projector
-        mlp_weights = self._filter_weights(mlp_weights, "mlp1.")
+        mlp_weights = self._filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
         for name, loaded_weight in mlp_weights:
             param = mlp_params_dict[name]
@@ -438,5 +441,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = self._filter_weights(llm_weights, "language_model.")
+        llm_weights = self._filter_weights(llm_weights, "language_model")
         self.language_model.load_weights(llm_weights)
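
Reviewer note (after the series): for reference, the per-module convention
that the vision tower's new load_weights() follows can be sketched as
below. The toy module is illustrative only, and the local
default_weight_loader stands in for the real helper imported from
vllm.model_executor.model_loader.weight_utils:

    import torch
    import torch.nn as nn

    def default_weight_loader(param, loaded_weight):
        # Simplified stand-in: fill a plain parameter by copying.
        param.data.copy_(loaded_weight)

    class ToyModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def load_weights(self, weights):
            params_dict = dict(self.named_parameters())
            for name, loaded_weight in weights:
                param = params_dict[name]
                # Tensor-parallel/quantized layers attach their own
                # weight_loader; plain parameters fall back to the copy.
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    m = ToyModule()
    m.load_weights([("proj.weight", torch.ones(4, 4)),
                    ("proj.bias", torch.zeros(4))])
    assert torch.equal(m.proj.weight.data, torch.ones(4, 4))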