From dd2a6a82e3f41b4673b1dbb24b2e99230ea96981 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 2 Sep 2024 23:48:56 +0800
Subject: [PATCH] [Bugfix] Fix internlm2 tensor parallel inference (#8055)

---
 vllm/model_executor/models/internlm2.py | 47 ++++++++++++++++++-------
 1 file changed, 34 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 9b7cada187ce1..23669b540f561 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -7,7 +8,10 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -70,20 +74,21 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -122,11 +127,27 @@ def __init__(
                               quant_config=quant_config)
 
     def split_qkv(self, qkv: torch.Tensor):
-        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
-        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
-        q = q.reshape(-1, self.q_size)
-        k = k.reshape(-1, self.kv_size)
-        v = v.reshape(-1, self.kv_size)
+        seq_len = qkv.shape[0]
+        if self.tp_size > 1:
+            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
+            qkv = tensor_model_parallel_all_gather(qkv)
+            qkv = torch.split(qkv, qkv_map, dim=-1)
+            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
+            qkv = torch.cat(qkv, dim=-1)
+
+        qkv = qkv.view(seq_len, self.total_num_kv_heads,
+                       self.key_value_groups + 2, self.head_dim)
+        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
+        q = q.reshape(seq_len, self.q_size * self.tp_size)
+        k = k.reshape(seq_len, self.kv_size * self.tp_size)
+        v = v.reshape(seq_len, self.kv_size * self.tp_size)
+
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
         return q, k, v
 
     def forward(