From 2299126718b6ba52292006118b5ec40aca7e3b2d Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Fri, 17 Jan 2025 13:14:17 +0000 Subject: [PATCH] [Fix]: Use Custom Tensor for dyn_quant_matmul_4bit aten op Signed-off-by: Nikhil Gupta --- torchao/dtypes/affine_quantized_tensor.py | 7 +- ...8_dynamic_activation_intx_weight_layout.py | 81 +++++++++++++++++++ torchao/experimental/quant_api.py | 44 +++++----- 3 files changed, 105 insertions(+), 27 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 63946ce290..e7aca34c5f 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -206,7 +206,6 @@ def from_hp_to_intx( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False, - bias: Optional[torch.Tensor] = None ): original_shape = input_float.shape input_float = _layout.pre_process(input_float) @@ -279,11 +278,7 @@ def from_hp_to_intx( data = _layout.post_process(data) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - args = [data, scale, zero_point, _layout] - # Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias - if bias is not None: - args.append(bias) - tensor_impl = tensor_impl_ctr(*args) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) return cls( tensor_impl, block_size, diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py index fd1d4b7c69..9d42596793 100644 --- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -12,6 +12,8 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.affine_quantized_tensor_ops import ( @@ -20,7 +22,11 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout from torchao.quantization.quant_primitives import ( ZeroPointDomain, + MappingType, + choose_qparams_affine, + quantize_affine, ) + from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_6, ) @@ -325,3 +331,78 @@ def _impl_2d_aten(input_tensor, weight_tensor): _linear_check, _linear_impl, ) + + +class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor): + """ + PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class. 
+    """
+
+    @classmethod
+    def from_hp_to_intx(
+        cls,
+        input_float: torch.Tensor,
+        mapping_type: MappingType,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        eps: Optional[float] = None,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_dtype: Optional[torch.dtype] = None,
+        preserve_zero: bool = True,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
+        use_hqq: bool = False,
+        bias: Optional[torch.Tensor] = None
+    ):
+        assert not use_hqq, "PackedLinearInt8DynamicActivationIntxWeightAtenTensor does not support HQQ optimization"
+        assert isinstance(
+            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightAtenTensor only supports PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided: {_layout}"
+        assert _layout.target == Target.ATEN, "PackedLinearInt8DynamicActivationIntxWeightAtenTensor requires target 'aten'."
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+
+        scale, zero_point = choose_qparams_affine(
+            input_float,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype,
+            zero_point_dtype,
+            preserve_zero,
+            zero_point_domain,
+        )
+        # choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is None
+        # TODO should probably consolidate ZeroPointDomain.NONE and None
+        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+            zero_point = None
+        data = quantize_affine(
+            input_float,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        # Note: output will be uint8 tensor for sub byte tensors for now
+
+        data = _layout.post_process(data)
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
+to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 57b2f66089..8c63874dc0 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -494,6 +494,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
 
 from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
     PackedLinearInt8DynamicActivationIntxWeightLayout,
+    to_packedlinearint8dynamicactivationintxweight_quantized_intx,
 )
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
@@ -576,6 +577,7 @@ def int8_dynamic_activation_intx_weight(
     )
     bit_width = dtype_to_bit_width[weight_dtype]
     layout_arg = layout
+    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == "aten"
 
     def apply(weight, bias: Optional[torch.Tensor] = None):
         if isinstance(granularity, PerGroup):
@@ -591,6 +593,10 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
             layout = layout_arg
         scale_dtype = None
 
+        tensor_quantizer = to_affine_quantized_intx
+        quant_min = -(1 << (bit_width - 1))
+        quant_max = (1 << (bit_width - 1)) - 1
+
         if isinstance(layout, 
PackedLinearInt8DynamicActivationIntxWeightLayout): assert ( weight.device == torch.device("cpu") @@ -613,26 +619,23 @@ def apply(weight, bias: Optional[torch.Tensor] = None): if torch.backends.kleidiai.is_available(): if isinstance(granularity, PerGroup): scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype - - quant_min = -(1 << (bit_width - 1)) - quant_max = (1 << (bit_width - 1)) - 1 - weight = to_affine_quantized_intx( - weight, - mapping_type=weight_mapping_type, - block_size=(1, group_size), - target_dtype=torch.int32, - quant_min=quant_min, - quant_max=quant_max, - eps=torch.finfo(torch.float32).eps, - scale_dtype=scale_dtype, - zero_point_dtype=torch.int8, - preserve_zero=has_weight_zeros, - zero_point_domain=ZeroPointDomain.INT - if has_weight_zeros - else ZeroPointDomain.NONE, - _layout=layout, - bias=bias - ) + tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx + + quantizer_args = [weight, + weight_mapping_type, + (1, group_size), + torch.int32, + quant_min, + quant_max, + torch.finfo(torch.float32).eps, + scale_dtype, + torch.int8, + has_weight_zeros, + ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE, + layout, + False] + ([bias] if propagate_bias else []) + + weight = tensor_quantizer(*quantizer_args) # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused # with the kernel and it should not be applied separately @@ -650,7 +653,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None): weight = to_linear_activation_quantized(weight, activation_quant_func) return weight - propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten" return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)
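
Usage note (not part of the patch): the sketch below shows how this change is exercised end to end. With it, int8_dynamic_activation_intx_weight routes PackedLinearInt8DynamicActivationIntxWeightLayout with the "aten" target through the new custom tensor and packs the module bias together with the quantized weights for the dyn_quant_matmul_4bit aten op named in the subject. The keyword names (weight_dtype, granularity, has_weight_zeros, layout) are taken from this diff; the target="aten" constructor argument, the PerGroup/quantize_ imports, and the torch.int4 weight dtype are assumptions about the surrounding torchao API, and running it requires a CPU build with the KleidiAI-backed 4-bit ops (PyTorch >= 2.6).

import torch
import torch.nn as nn

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

# A model with a biased linear so the new bias-propagation path is taken.
model = nn.Sequential(nn.Linear(256, 512, bias=True))

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=False,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)

# The linear forward now runs through the fused dynamic-activation-quant kernel
# (torch.ops.aten._dyn_quant_matmul_4bit) with the bias packed into the weight tensor.
out = model(torch.randn(1, 256))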