From f1358820410187b2582aeefbc0999f03e72abb3f Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 3 Aug 2024 13:36:14 +0800
Subject: [PATCH] [Model] Refactor and decouple weight loading logic for
 InternVL2 model (#7067)

---
 vllm/model_executor/models/intern_vit.py | 11 +++-
 vllm/model_executor/models/internvl.py   | 82 ++++++++----------------
 2 files changed, 38 insertions(+), 55 deletions(-)

diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index c6c692deca2e1..54c933e3e4959 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -4,7 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -16,6 +16,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 NORM2FN = {
     'rms_norm': RMSNorm,
@@ -268,3 +269,11 @@ def forward(
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
         return encoder_outputs
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index eabc283b1efdb..4749251271487 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -4,6 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+import itertools
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
@@ -414,58 +415,31 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-            (".gate_up_proj", ".w1", 0),
-            (".gate_up_proj", ".w3", 1),
-        ]
-        params_dict = dict(self.named_parameters())
+    def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                        prefix: str):
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.config.text_config.tie_word_embeddings \
-                and "lm_head.weight" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                # We only do sharding for language model
-                # and not vision model for now.
-                if "vision_embed_tokens" in name and self.vision_embed_tokens:
-                    continue
-                if weight_name not in name:
-                    continue
-                param = params_dict[name.replace(weight_name, param_name)]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                if "wqkv" in name:
-                    config = self.config.text_config
-                    kv_groups = (config.num_attention_heads //
-                                 config.num_key_value_heads)
-                    head_dim = config.hidden_size // config.num_attention_heads
-                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
-                                                       head_dim,
-                                                       loaded_weight.shape[-1])
-                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
-                                             dim=1)
-                    wq = wq.reshape(-1, wq.shape[-1])
-                    wk = wk.reshape(-1, wk.shape[-1])
-                    wv = wv.reshape(-1, wv.shape[-1])
-                    weight_loader = param.weight_loader
-                    weight_loader(param, wq, 'q')
-                    weight_loader(param, wk, 'k')
-                    weight_loader(param, wv, 'v')
-                    continue
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            name = name.split(".")
+            if prefix == name.pop(0):
+                name = ".".join(name)
+                yield name, loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # prepare weight iterators for components
+        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+
+        # load vision encoder
+        vit_weights = self._filter_weights(vit_weights, "vision_model")
+        self.vision_model.load_weights(vit_weights)
+
+        # load mlp projector
+        mlp_weights = self._filter_weights(mlp_weights, "mlp1")
+        mlp_params_dict = dict(self.mlp1.named_parameters())
+        for name, loaded_weight in mlp_weights:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = self._filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)