From 6e712269328bb6928bb30bf067155e60b75048e0 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 27 Jun 2024 21:00:20 +0000 Subject: [PATCH] support both weight loading methods --- vllm/model_executor/__init__.py | 2 +- vllm/model_executor/layers/linear.py | 363 +++++++++++++++++- .../compressed_tensors/__init__.py | 1 + .../compressed_tensors/compressed_tensors.py | 2 + .../schemes/compressed_tensors_w4a16_24.py | 6 +- .../schemes/compressed_tensors_w8a8.py | 3 +- .../schemes/compressed_tensors_wNa16.py | 6 +- vllm/model_executor/parameter.py | 6 +- 8 files changed, 373 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index c0768bb2dbeae..5d2698b4b9518 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,6 +1,6 @@ +from vllm.model_executor.parameter import * from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.model_executor.parameter import * __all__ = [ "SamplingMetadata", diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 159e7cdf7dbd1..f52faae442831 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -10,11 +10,39 @@ split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.compressed_tensors import ( + CompressedTensorsLinearMethod) from vllm.model_executor.parameter import vLLMParameter from vllm.model_executor.utils import set_weight_attrs +logger = init_logger(__name__) + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def adjust_bitsandbytes_shard(param: Parameter, + qkv_offsets: Dict[str, Tuple[int, int]], + loaded_shard_id: str) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = qkv_offsets["total"] + orig_offset, orig_size = qkv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -224,6 +252,7 @@ def __init__(self, if output_sizes is None: output_sizes = [output_size] + self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size, @@ -231,7 +260,9 @@ def __init__(self, input_size=self.input_size, output_size=self.output_size, params_dtype=self.params_dtype, - weight_loader=self.weight_loader_v2) + weight_loader=(self.weight_loader_v2 if isinstance( + self.quant_method, CompressedTensorsLinearMethod) else + self.weight_loader)) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -243,6 +274,28 @@ def __init__(self, else: self.register_parameter("bias", None) + def weight_loader(self, param: Parameter, 
loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() param_data = param.data @@ -322,8 +375,126 @@ def __init__(self, params_dtype=params_dtype, quant_config=quant_config) - def _default_loading(self, param: vLLMParameter, param_data, loaded_weight, - loaded_shard_id): + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + + if loaded_shard_id is None: + # Loaded weight is already packed. + if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. 
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + + # Special case for Fp8 scales. 
+ elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _default_loading(self, param: vLLMParameter, param_data: torch.Tensor, + loaded_weight: torch.Tensor, loaded_shard_id: int): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() @@ -481,8 +652,8 @@ def _get_shard_size_mapping(self, loaded_shard_id: str): } return shard_size_mapping.get(loaded_shard_id) - def _default_loading(self, param: vLLMParameter, param_data, loaded_weight, - loaded_shard_id): + def _default_loading(self, param: vLLMParameter, param_data: torch.Tensor, + loaded_weight: torch.Tensor, loaded_shard_id: str): tp_rank = get_tensor_model_parallel_rank() shard_offset = self._get_shard_offset_mapping(loaded_shard_id) @@ -495,7 +666,11 @@ def _default_loading(self, param: vLLMParameter, param_data, loaded_weight, param_data = param_data.narrow(param.output_dim, shard_offset, shard_size) - shard_id = tp_rank if loaded_shard_id == "q" else tp_rank // self.num_kv_head_replicas + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + loaded_weight.narrow(param.output_dim, shard_id * shard_size, shard_size) @@ -563,6 +738,150 @@ def weight_loader_v2(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + + if loaded_shard_id is None: + # Loaded weight is already packed. 
+ if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", (self.total_num_heads + self.total_num_kv_heads) * + self.head_size, self.total_num_kv_heads * self.head_size), + ] + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, + self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) + } + shard_size, shard_offset = adjust_bitsandbytes_shard( + param, orig_qkv_offsets, loaded_shard_id) + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, + shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + + # Special case for Fp8 scales. 
+ elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + class RowParallelLinear(LinearBase): """Linear layer with row parallelism. @@ -616,7 +935,9 @@ def __init__(self, input_size=self.input_size, output_size=self.output_size, params_dtype=self.params_dtype, - weight_loader=self.weight_loader_v2) + weight_loader=(self.weight_loader_v2 if isinstance( + self.quant_method, CompressedTensorsLinearMethod) else + self.weight_loader)) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -631,6 +952,32 @@ def __init__(self, else: self.register_parameter("bias", None) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + + tp_rank = get_tensor_model_parallel_rank() + input_dim = getattr(param, "input_dim", None) + param_data = param.data + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + + if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + def weight_loader_v2(self, param: vLLMParameter, loaded_weight: torch.Tensor): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py index e69de29bb2d1d..93758e58b6579 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py @@ -0,0 +1 @@ +from .compressed_tensors import CompressedTensorsLinearMethod # noqa: F401 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index c69e2f3bcf9fa..1c2d441cab7c2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -15,6 +15,8 @@ CompressionFormat, QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match) +__all__ = ["CompressedTensorsLinearMethod"] + class CompressedTensorsConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index bc65cdc450e98..e45e5751c4946 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -2,12 +2,13 @@
 
 import torch
 from torch.nn import Parameter
+
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
-from vllm.model_executor.parameter import vLLMParameter, PackedvLLMParameter
+from vllm.model_executor.parameter import PackedvLLMParameter, vLLMParameter
 
 __all__ = ["CompressedTensorsW4A16Sparse24"]
 W4A16SPARSE24_SUPPORTED_BITS = [4]
@@ -65,7 +66,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             output_dim=1,
             input_dim=input_dim,
             use_col_loading=True,
-            use_row_loading=True if input_dim is not None else False,
+            use_row_loading=True #noqa: SIM210
+            if input_dim is not None else False,
             weight_loader=weight_loader)
 
         weight_shape = vLLMParameter(data=torch.empty(2, dtype=torch.int64),
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
index fb8e4e35df4a4..a1cc86060dd27 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
@@ -6,7 +6,8 @@
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationStrategy)
-from vllm.model_executor.parameter import vLLMParameter, ScalerToArrayvLLMParameter
+from vllm.model_executor.parameter import (ScalerToArrayvLLMParameter,
+                                           vLLMParameter)
 
 
 class CompressedTensorsW8A8(CompressedTensorsScheme):
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index e1498adba86c0..e8de70856fbe3 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -2,13 +2,14 @@
 
 import torch
 from torch.nn import Parameter
+
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
     marlin_permute_scales)
-from vllm.model_executor.parameter import vLLMParameter, PackedvLLMParameter
+from vllm.model_executor.parameter import PackedvLLMParameter, vLLMParameter
 
 __all__ = ["CompressedTensorsWNA16"]
 WNA16_SUPPORTED_BITS = [4, 8]
@@ -68,7 +69,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
             input_dim=weight_scale_dim,
             output_dim=0,
             use_col_loading=True,
-            use_row_loading=True if weight_scale_dim is not None else False,
+            use_row_loading=True # noqa: SIM210
+            if weight_scale_dim is not None else False,
             weight_loader=weight_loader,
             data=torch.empty(
                 output_size_per_partition,
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 20dd499a68f87..36791caefe718 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -1,6 +1,8 @@
-from torch.nn import Parameter
-from typing import Optional, Callable, List, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
+
 import torch
+from torch.nn import Parameter
+
 from vllm.logger import init_logger
 
 __all__ = [