diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5190de65d7956..5e1d63a6a62eb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -423,7 +423,10 @@ def weight_loader(self, shard_offset = shard_offset // param.pack_factor param_data = param_data.narrow(output_dim, shard_offset, shard_size) - shard_id = tp_rank // self.num_kv_head_replicas + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)