diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md
index 7f0970f792..a178c9b328 100644
--- a/torchao/experimental/docs/readme.md
+++ b/torchao/experimental/docs/readme.md
@@ -98,6 +98,37 @@ quantize_(
 )
 ```
 
+KleidiAI int4 kernels can be used on Arm platforms with PyTorch 2.6.0 or later by adjusting the quantization parameters as follows:
+
+```python
+from torchao.dtypes import PlainLayout
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.quantization.quant_primitives import MappingType
+
+my_model = Model()
+
+quantize_(
+    my_model,
+    int8_dynamic_activation_intx_weight(
+        weight_dtype=torch.int4,
+        granularity=PerGroup(32), # PerRow() is also supported
+        has_weight_zeros=True, # must be True for the KleidiAI int4 kernels
+        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR, # MappingType.SYMMETRIC can also be used but increases error
+        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
+    ),
+)
+```
+
 
 If you get stuck, consult
 `torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
 for a working example.
diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py
index 7b2b1da145..9d42596793 100644
--- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py
+++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py
@@ -5,12 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
import logging +from enum import Enum, auto from typing import Optional, Tuple import torch from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.affine_quantized_tensor_ops import ( @@ -19,6 +22,13 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout from torchao.quantization.quant_primitives import ( ZeroPointDomain, + MappingType, + choose_qparams_affine, + quantize_affine, +) + +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, ) logger = logging.getLogger(__name__) @@ -31,17 +41,33 @@ handler.setFormatter(formatter) logger.addHandler(handler) +class Target(Enum): + """Enum that indicates the backend target""" + + NATIVE = auto() + ATEN = auto() + +def target_from_str(target: str) -> Target: + if target.lower() == "native": + return Target.NATIVE + elif target.lower() == "aten": + return Target.ATEN + else: + raise ValueError(f"Invalid target: {target}") class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): bit_width: Optional[int] group_size: Optional[int] has_weight_zeros: Optional[bool] + # The target platform for the layout, 'native' or 'aten' + target: Optional[Target] def __init__( self, bit_width: Optional[int] = None, group_size: Optional[int] = None, has_weight_zeros: Optional[bool] = None, + target: Optional[str] = "native", ): if bit_width is not None: assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8" @@ -51,6 +77,7 @@ def __init__( self.bit_width = bit_width self.group_size = group_size self.has_weight_zeros = has_weight_zeros + self.target = target_from_str(target) if not self.has_params_set(): assert ( @@ -60,13 +87,14 @@ def __init__( ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False" def extra_repr(self): - return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}" + return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}" def has_params_set(self) -> bool: return ( (self.bit_width is not None) and (self.group_size is not None) and (self.has_weight_zeros is not None) + and (self.target is not None) ) @@ -125,9 +153,11 @@ def from_plain( scale: torch.Tensor, zero_point: Optional[torch.Tensor], layout: Layout, + bias: Optional[torch.Tensor] = None, ): assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" + assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}" # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor # when AOTI supports int @@ -136,6 +166,13 @@ def from_plain( n_tensor = torch.empty(0, n, dtype=torch.int8) k_tensor = torch.empty(0, k, dtype=torch.int8) + if layout.target == Target.ATEN: + assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + int_data = int_data.add(8) + int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8) + packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n) + return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) + if layout.has_weight_zeros: args = [ int_data.to(torch.int8), @@ -211,16 +248,13 @@ def __tensor_unflatten__( def _linear_check(input_tensor, 
weight_tensor, bias): layout = weight_tensor.tensor_impl.get_layout() return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and ( - bias is None + bias is None or layout.target == Target.ATEN # Aten target allows bias ) def _linear_impl(input_tensor, weight_tensor, bias): - assert ( - bias is None - ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl" - def _impl_2d(input_tensor, weight_tensor): + def _impl_2d_native(input_tensor, weight_tensor): assert input_tensor.dim() == 2 assert weight_tensor.dim() == 2 @@ -255,6 +289,31 @@ def _impl_2d(input_tensor, weight_tensor): torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight" )(*args) + def _impl_2d_aten(input_tensor, weight_tensor): + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + group_size = weight_tensor.tensor_impl.get_layout().group_size + packed_weight = weight_tensor.tensor_impl.packed_weight + return torch.ops.aten._dyn_quant_matmul_4bit( + input_tensor, packed_weight, group_size, k_, n) + + target = weight_tensor.tensor_impl.get_layout().target + + if target == Target.ATEN: + assert ( + TORCH_VERSION_AT_LEAST_2_6 == 1 + ), "Target.ATEN requires torch >= 2.6.0" + _impl_2d = _impl_2d_aten + elif target == Target.NATIVE: + _impl_2d = _impl_2d_native + assert ( + bias is None + ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' " + if input_tensor.dim() == 2: return _impl_2d(input_tensor, weight_tensor) @@ -268,8 +327,82 @@ def _impl_2d(input_tensor, weight_tensor): res = res.reshape(*lead_shape, m, n) return res - register_aqt_quantized_linear_dispatch( _linear_check, _linear_impl, ) + + +class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor): + """ + PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class. + """ + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(), + use_hqq: bool = False, + bias: Optional[torch.Tensor] = None + ): + assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization" + assert isinstance( + _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}" + assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'." + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + + scale, zero_point = choose_qparams_affine( + input_float, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + # choose_qparams_affine is a custom op that does support returning optional Tensors. 
We thus set the zero_point to None if its domain is None + # TODO should probably consolidate ZeroPointDomain.NONE and None + if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: + zero_point = None + data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + # Note: output will be uint8 tensor for sub byte tensors for now + + data = _layout.post_process(data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + +to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 4e0906d0a0..e77d09d98b 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import sys import logging from typing import Optional, Union @@ -18,14 +19,18 @@ PerGroup, PerRow, ) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, +) +from torchao.dtypes import PlainLayout logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) -import sys handler = logging.StreamHandler(sys.stdout) -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) @@ -489,6 +494,8 @@ def quantize(self, model: nn.Module) -> nn.Module: from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( PackedLinearInt8DynamicActivationIntxWeightLayout, + to_packedlinearint8dynamicactivationintxweight_quantized_intx, + Target, ) from torchao.quantization.linear_activation_quantized_tensor import ( to_linear_activation_quantized, @@ -508,7 +515,7 @@ def int8_dynamic_activation_intx_weight( has_weight_zeros: bool = False, weight_mapping_type=MappingType.ASYMMETRIC, act_mapping_type=MappingType.ASYMMETRIC, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"), # PlainLayout() also works, but will be slow ): """ Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers. @@ -531,13 +538,25 @@ def int8_dynamic_activation_intx_weight( - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32) - act_mapping_type must be MappingType.ASYMMETRIC """ - try: - torch.ops.torchao._pack_8bit_act_4bit_weight - except AttributeError: - raise Exception( - "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." - + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance." 
- ) + + def is_torchao_op_skippable(layout): + return ( + isinstance(layout, PlainLayout) or + ( + isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and + layout.target == Target.ATEN + ) + ) + + if not is_torchao_op_skippable(layout): + try: + torch.ops.torchao._pack_8bit_act_4bit_weight + except AttributeError: + raise Exception( + "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." + + " You can also set target to 'aten' if you are using ARM CPU." + + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance." + ) dtype_to_bit_width = { torch.int1: 1, @@ -555,8 +574,9 @@ def int8_dynamic_activation_intx_weight( ) bit_width = dtype_to_bit_width[weight_dtype] layout_arg = layout + propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN - def apply(weight): + def apply(weight, bias: Optional[torch.Tensor] = None): if isinstance(granularity, PerGroup): group_size = granularity.group_size elif isinstance(granularity, PerRow): @@ -569,6 +589,11 @@ def apply(weight): assert weight.shape[-1] % group_size == 0 layout = layout_arg + scale_dtype = None + tensor_quantizer = to_affine_quantized_intx + quant_min = -(1 << (bit_width - 1)) + quant_max = (1 << (bit_width - 1)) - 1 + if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout): assert ( weight.device == torch.device("cpu") @@ -584,25 +609,40 @@ def apply(weight): bit_width=bit_width, group_size=group_size, has_weight_zeros=has_weight_zeros, + target="aten" if layout.target == Target.ATEN else "native", ) - - quant_min = -(1 << (bit_width - 1)) - quant_max = (1 << (bit_width - 1)) - 1 - weight = to_affine_quantized_intx( - weight, - mapping_type=weight_mapping_type, - block_size=(1, group_size), - target_dtype=torch.int32, - quant_min=quant_min, - quant_max=quant_max, - eps=torch.finfo(torch.float32).eps, - zero_point_dtype=torch.int8, - preserve_zero=has_weight_zeros, - zero_point_domain=ZeroPointDomain.INT - if has_weight_zeros - else ZeroPointDomain.NONE, - _layout=layout, - ) + if layout.target == Target.ATEN: + if weight_dtype != torch.int4 or \ + has_weight_zeros != True or \ + weight_mapping_type == MappingType.ASYMMETRIC: + raise NotImplementedError( + f"target 'aten' requires:\n" + f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n" + f"- has_weight_zeros to be True,\n" + f"- weight_dtype to be torch.int4,\n" + f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR" + ) + assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + if torch.backends.kleidiai.is_available(): + if isinstance(granularity, PerGroup): + scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype + tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx + + quantizer_args = [weight, + weight_mapping_type, + (1, group_size), + torch.int32, + quant_min, + quant_max, + torch.finfo(torch.float32).eps, + scale_dtype, + torch.int8, + has_weight_zeros, + ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE, + layout, + False] + ([bias] if propagate_bias else []) + + weight = tensor_quantizer(*quantizer_args) # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused # with the kernel and it should not be 
applied separately @@ -620,7 +660,7 @@ def apply(weight): weight = to_linear_activation_quantized(weight, activation_quant_func) return weight - return _get_linear_subclass_inserter(apply) + return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias) class UIntxWeightOnlyQuantizedLinear(nn.Module): diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py new file mode 100644 index 0000000000..c1c5ed771e --- /dev/null +++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch + +from torchao.dtypes import PlainLayout +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) +from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, +) +from torchao.quantization.granularity import ( + PerGroup, + PerRow, +) +from torchao.quantization.quant_api import quantize_ +from torchao.utils import unwrap_tensor_subclass +from torchao.quantization.quant_primitives import MappingType + + +class TestPackedLinearInt8DynamicActivationIntxWeightLayoutAten(unittest.TestCase): + def test_accuracy(self): + """ + Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing + its results to the results of a reference model that uses PlainLayout() + """ + granularities = [PerRow()] + m = 32 + n = 128 + k = 256 + activations = torch.randn(m, k) + weight_mapping_type = MappingType.SYMMETRIC_NO_CLIPPING_ERR + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + + for weight_dtype in [ + torch.int4, + ]: + for has_weight_zeros in [True]: + for granularity in granularities: + print( + f"Testing weight_dtype={weight_dtype}, has_weight_zeros={ + has_weight_zeros}, granularity={granularity}" + ) + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + weight_mapping_type=weight_mapping_type, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="aten"), # default + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PlainLayout(), + ), + ) + + with torch.no_grad(): + res = quantized_model(activations) + ref = quantized_model_reference(activations) + + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.04) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 02af4ced91..bbe9b1cb6b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -450,13 +450,15 @@ def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, 
**kwargs):
+def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs):
     """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
     to the weight of linear module
     """
 
     def insert_subclass(lin):
         requires_grad = allow_requires_grad and lin.weight.requires_grad
+        if propagate_bias:
+            kwargs["bias"] = lin.bias
         lin.weight = torch.nn.Parameter(
             constructor(lin.weight, **kwargs), requires_grad=requires_grad
         )
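
For readers unfamiliar with the inserter pattern touched in the last hunk, here is a minimal, self-contained sketch (not the torchao implementation) of what the new `propagate_bias` flag is intended to do: hand the linear layer's bias to the weight-quantizing constructor so the aten/KleidiAI path can pack weight, scales, and bias together. The names `insert_weight_subclass` and `fake_quantize_weight` below are hypothetical stand-ins for `_get_linear_subclass_inserter`'s inner closure and the quantization `apply` function.

```python
from typing import Callable, Optional

import torch
import torch.nn as nn


def insert_weight_subclass(
    lin: nn.Linear,
    constructor: Callable[..., torch.Tensor],
    propagate_bias: bool = False,
) -> nn.Linear:
    # When propagate_bias is set, forward the module's bias to the constructor
    # (mirroring the patch, where kwargs["bias"] = lin.bias before construction).
    kwargs = {"bias": lin.bias} if propagate_bias else {}
    lin.weight = nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=False)
    return lin


def fake_quantize_weight(
    weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Hypothetical constructor: a real one would return a quantized tensor
    # subclass; here we just copy the weight to keep the sketch runnable.
    return weight.detach().clone()


layer = nn.Linear(16, 8, bias=True)
insert_weight_subclass(layer, fake_quantize_weight, propagate_bias=True)
```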