Commit

clean-up
dsikka committed Jun 27, 2024
1 parent 4b9165d commit 90094db
Showing 5 changed files with 113 additions and 149 deletions.
60 changes: 19 additions & 41 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
@@ -231,7 +231,7 @@ def __init__(self,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader_new)
weight_loader=self.weight_loader_v2)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
@@ -243,7 +243,7 @@ def __init__(self,
else:
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data

@@ -252,7 +252,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(param.output_dim, start_idx,
shard_size)
elif param.use_col_shard_splitting:
elif param.use_col_shard_split:
param_data, loaded_weight = param.col_shard_splitter(
param_data, loaded_weight, 0)

@@ -335,11 +335,6 @@ def _default_loading(self, param: vLLMParameter, param_data, loaded_weight,
shard_size, shard_offset = param.adjust_packed_shard(
shard_offset=shard_offset, shard_size=shard_size)

if param.use_bits_and_bytes:
shard_size = loaded_weight.shape[param.output_dim]
shard_offset = loaded_weight.shape[param.output_dim] * \
loaded_shard_id

param_data = param_data.narrow(param.output_dim, shard_offset,
shard_size)
loaded_weight.narrow(param.output_dim, tp_rank * shard_size,
@@ -366,12 +361,12 @@ def _load_no_shard_id(self, param: vLLMParameter, loaded_weight):
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_new(param, loaded_weight_shard, shard_id)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_new(self,
param: vLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
def weight_loader_v2(self,
param: vLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
if loaded_shard_id is None:
if param.output_dim is None: # TODO: why?
@@ -393,7 +388,7 @@ def weight_loader_new(self,
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
elif param.use_col_shard_splitting:
elif param.use_col_shard_split:
param_data, loaded_weight = param.col_shard_splitter(
param_data=param_data,
loaded_weight=loaded_weight,
@@ -498,18 +493,6 @@ def _default_loading(self, param: vLLMParameter, param_data, loaded_weight,
shard_size, shard_offset = param.adjust_packed_shard(
shard_offset=shard_offset, shard_size=shard_size)

if param.use_bits_and_bytes:
total = self._get_shard_offset_mapping("total")

# TODO: do we ever have a case where bits and bytes and packed?
# If not, these are the same
orig_offset = self._get_shard_offset_mapping(loaded_shard_id)
orig_size = self._get_shard_size_mapping(loaded_shard_id)

quantized_total = param.data.shape[0]
shard_offset = orig_offset * quantized_total // total
shard_size = orig_size * quantized_total // total

param_data = param_data.narrow(param.output_dim, shard_offset,
shard_size)
shard_id = tp_rank if loaded_shard_id == "q" else tp_rank // self.num_kv_head_replicas
@@ -542,12 +525,12 @@ def _load_no_shard_id(self, param: vLLMParameter, loaded_weight):
loaded_weight_shard = loaded_weight.narrow(param.output_dim,
shard_offset,
shard_size)
self.weight_loader_new(param, loaded_weight_shard, shard_id)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)

def weight_loader_new(self,
param: vLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
def weight_loader_v2(self,
param: vLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):

param_data = param.data
if loaded_shard_id is None: # special case for certain models
@@ -571,7 +554,7 @@ def weight_loader_new(self,
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
elif param.use_col_shard_splitting:
elif param.use_col_shard_split:
param_data, loaded_weight = param.col_shard_splitter(
param_data=param_data,
loaded_weight=loaded_weight,
@@ -633,7 +616,7 @@ def __init__(self,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader_new)
weight_loader=self.weight_loader_v2)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
@@ -648,8 +631,8 @@ def __init__(self,
else:
self.register_parameter("bias", None)

def weight_loader_new(self, param: vLLMParameter,
loaded_weight: torch.Tensor):
def weight_loader_v2(self, param: vLLMParameter,
loaded_weight: torch.Tensor):

param_data = param.data
tp_rank = get_tensor_model_parallel_rank()
@@ -659,12 +642,7 @@ def weight_loader_new(self, param: vLLMParameter,
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(param.input_dim, start_idx,
shard_size)
elif param.use_row_shard_splitting:
param_data, loaded_weight = param.row_shard_splitter(param_data,
loaded_weight,
shard_id=0)

if not param.use_row_shard_splitting and len(loaded_weight.shape) == 0:
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

assert param_data.shape == loaded_weight.shape
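The linear-layer hunks above all reduce to the same tensor-parallel idea: narrow the checkpoint tensor along the parameter's output (or input) dimension according to the current rank. A minimal standalone sketch of the column-parallel case, using plain torch and made-up sizes rather than the vLLM classes themselves:

import torch

def shard_for_rank(loaded_weight: torch.Tensor, output_dim: int,
                   tp_size: int, tp_rank: int) -> torch.Tensor:
    # Each tensor-parallel rank keeps one contiguous slice of the output dim.
    shard_size = loaded_weight.shape[output_dim] // tp_size
    start_idx = tp_rank * shard_size
    return loaded_weight.narrow(output_dim, start_idx, shard_size)

full = torch.randn(1024, 4096)                       # checkpoint weight [out, in]
local = shard_for_rank(full, output_dim=0, tp_size=4, tp_rank=1)
assert local.shape == (256, 4096)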
Original file line number Diff line number Diff line change
@@ -45,21 +45,28 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_dim=0,
output_dim=1,
packed_dim=1,
use_row_loading=True,
use_col_loading=True,
packed_factor=pack_factor,
marlin_tile_size=self.tile_size,
weight_loader=weight_loader)

input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)

scales = vLLMParameter(data=torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
output_dim=1,
input_dim=None if input_groups == 1 else 0,
weight_loader=weight_loader)
input_dim = None if input_groups == 1 else 0

scales = vLLMParameter(
data=torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
output_dim=1,
input_dim=input_dim,
use_col_loading=True,
use_row_loading=True if input_dim is not None else False,
weight_loader=weight_loader)

weight_shape = vLLMParameter(data=torch.empty(2, dtype=torch.int64),
weight_loader=weight_loader)
@@ -71,6 +78,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
),
input_dim=0,
output_dim=1,
use_col_loading=True,
use_row_loading=True,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
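In the scales hunk above, `input_dim` and `use_row_loading` follow directly from whether the scheme uses quantization groups. A small sketch of just that shape-and-flag derivation, with hypothetical sizes (not the vLLMParameter API itself):

import torch
from typing import Optional, Tuple

def scale_layout(input_size_per_partition: int,
                 output_size_per_partition: int,
                 group_size: Optional[int]) -> Tuple[torch.Tensor, Optional[int], bool]:
    # One scale row per quantization group; a single row when there is no grouping.
    input_groups = (1 if group_size is None
                    else input_size_per_partition // group_size)
    input_dim = None if input_groups == 1 else 0
    use_row_loading = input_dim is not None
    scales = torch.empty(input_groups, output_size_per_partition,
                         dtype=torch.float16)
    return scales, input_dim, use_row_loading

scales, input_dim, row_loading = scale_layout(4096, 11008, group_size=128)
assert scales.shape == (32, 11008) and input_dim == 0 and row_loading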
Original file line number Diff line number Diff line change
@@ -6,42 +6,7 @@
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from functools import partial
from vllm.model_executor.parameter import vLLMParameter


class SomeParam(vLLMParameter):

def __init__(self, logical_widths, **kwargs):
self.func = partial(self._col_shard_splitte_with_weights,
logical_widths=logical_widths)
super().__init__(**kwargs)

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def _col_shard_splitte_with_weights(
self,
logical_widths: torch.Tensor,
param_data: torch.Tensor,
loaded_weight: torch.Tensor,
shard_id: Union[str, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param_data[offset:offset + size], loaded_weight

def col_shard_splitter(self, *args, **kwargs):
return self.func(*args, **kwargs)
from vllm.model_executor.parameter import vLLMParameter, ScalerToArrayvLLMParameter


class CompressedTensorsW8A8(CompressedTensorsScheme):
@@ -69,6 +34,8 @@ def create_weights(self, layer: torch.nn.Module,
dtype=torch.int8),
input_dim=1,
output_dim=0,
use_row_loading=True,
use_col_loading=True,
weight_loader=weight_loader)

# Don't need a shard_splitter for channel-wise quantization
@@ -77,12 +44,13 @@ def create_weights(self, layer: torch.nn.Module,
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = vLLMParameter(data=weight_scale_data,
output_dim=0,
use_col_loading=True,
weight_loader=weight_loader)
else:
weight_scale = SomeParam(data=weight_scale_data,
weight_loader=weight_loader,
logical_widths=output_partition_sizes)
weight_scale.use_col_shard_splitting = True
weight_scale = ScalerToArrayvLLMParameter(
data=weight_scale_data,
weight_loader=weight_loader,
logical_widths=output_partition_sizes)

layer.register_parameter("weight", weight)
layer.register_parameter("weight_scale", weight_scale)
Original file line number Diff line number Diff line change
@@ -52,6 +52,8 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,

weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
use_row_loading=True,
use_col_loading=True,
weight_loader=weight_loader,
packed_factor=pack_factor,
packed_dim=1,
@@ -62,14 +64,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
dtype=torch.int32,
))

weight_scale = vLLMParameter(input_dim=weight_scale_dim,
output_dim=0,
weight_loader=weight_loader,
data=torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
))
weight_scale = vLLMParameter(
input_dim=weight_scale_dim,
output_dim=0,
use_col_loading=True,
use_row_loading=True if weight_scale_dim is not None else False,
weight_loader=weight_loader,
data=torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
))

# A 2D array defining the original shape of the weights
# before packing
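The packed weight above stores pack_factor low-bit values per int32 entry along packed_dim. A toy sketch of that packing arithmetic for 4-bit values with pack_factor = 8; this only illustrates the storage layout implied by the shapes in the hunk, not the quantization kernel's actual bit order:

import torch

def pack_int4(unpacked: torch.Tensor, pack_factor: int = 8) -> torch.Tensor:
    # unpacked: [out, in] values in [0, 15]; packs pack_factor values per int32
    # along dim 1, mirroring packed_dim=1 and input_size // pack_factor above.
    out_size, in_size = unpacked.shape
    assert in_size % pack_factor == 0
    packed = torch.zeros(out_size, in_size // pack_factor, dtype=torch.int32)
    for i in range(pack_factor):
        packed |= (unpacked[:, i::pack_factor].to(torch.int32) & 0xF) << (4 * i)
    return packed

w = torch.randint(0, 16, (16, 64))
assert pack_int4(w).shape == (16, 64 // 8)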