From f04f3b2a865c00d35985aca2433e30a9d8cbe564 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 13 Jan 2025 14:07:38 -0800 Subject: [PATCH] Clean up linear_int8_dynamic_activation_intx_weight_subclass (#1553) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/1553 Cleans up layout and quantization API: ``` int8_dynamic_activation_intx_weight( group_size: int = 128, bit_width: int = 4, has_weight_zeros: bool = False, weight_mapping_type=MappingType.ASYMMETRIC, act_mapping_type=MappingType.ASYMMETRIC, layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), ) ``` int8_dynamic_activation_intx_weight is now very similar to int8_dynamic_activation_int4_weight. By passing bit_width=4, has_weight_zeros=false, and layout=PlainLayout(), it should be numerically identical (but slower). The fallback option is removed and instead relies on using PlainLayout(). Reviewed By: jerryzh168 Differential Revision: D67821939 --- torchao/_models/llama/generate.py | 22 +- torchao/dtypes/uintx/plain_layout.py | 18 +- torchao/dtypes/utils.py | 6 +- .../_linear_8bit_act_xbit_weight_layout.py | 374 ------------------ torchao/experimental/docs/readme.md | 63 ++- ...8_dynamic_activation_intx_weight_layout.py | 275 +++++++++++++ torchao/experimental/quant_api.py | 172 ++++++-- ..._dynamic_activation_intx_weight_layout.py} | 97 +++-- 8 files changed, 528 insertions(+), 499 deletions(-) delete mode 100644 torchao/experimental/_linear_8bit_act_xbit_weight_layout.py create mode 100644 torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py rename torchao/experimental/tests/{test_linear_int8_dynamic_activation_intx_weight_subclass.py => test_packed_linear_int8_dynamic_activation_intx_weight_layout.py} (56%) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 5635ed8d23..b1d3475601 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -543,32 +543,22 @@ def ffn_or_attn_only(mod, fqn): from torchao.experimental.quant_api import ( int8_dynamic_activation_intx_weight, ) + from torchao.quantization.granularity import PerGroup assert ( precision == torch.float32 - ), "int8_dynamic_activation_intx_weight requires fp32 precision" - - try: - torch.ops.torchao._pack_8bit_act_4bit_weight - except: - print( - "Unable to load experimental torchao kernels. Performance will be slow." 
- ) - print( - "To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU" - ) + ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32" # Quantize model _quant_args = quantization.split("-") - nbit = int(_quant_args[1]) - assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8" - group_size = int(_quant_args[2]) + weight_dtype = getattr(torch, f"int{_quant_args[1]}") + granularity = PerGroup(int(_quant_args[2])) has_weight_zeros = bool(_quant_args[3]) quantize_( model, int8_dynamic_activation_intx_weight( - group_size=group_size, - nbit=nbit, + weight_dtype=weight_dtype, + granularity=granularity, has_weight_zeros=has_weight_zeros, ), ) diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index 502e3c13e9..f47757fb77 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -38,7 +38,7 @@ def __new__( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): kwargs = {} @@ -55,7 +55,7 @@ def __init__( self, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): self.int_data = int_data @@ -64,6 +64,8 @@ def __init__( self._layout = _layout def __tensor_flatten__(self): + if self.zero_point is None: + return ["int_data", "scale"], [self._layout] return ["int_data", "scale", "zero_point"], [self._layout] @classmethod @@ -73,7 +75,7 @@ def __tensor_unflatten__( int_data, scale, zero_point = ( tensor_data_dict["int_data"], tensor_data_dict["scale"], - tensor_data_dict["zero_point"], + tensor_data_dict.get("zero_point", None), ) (_layout,) = tensor_attributes return cls(int_data, scale, zero_point, _layout) @@ -83,7 +85,9 @@ def to(self, *args, **kwargs): return self.__class__( self.int_data.to(kwargs["device"]), self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), + self.zero_point.to(kwargs["device"]) + if self.zero_point is not None + else None, self._layout, ) @@ -91,7 +95,7 @@ def _apply_fn_to_data(self, fn): return self.__class__( fn(self.int_data), fn(self.scale), - fn(self.zero_point), + fn(self.zero_point) if self.zero_point is not None else None, self._layout, ) @@ -134,7 +138,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return PlainAQTTensorImpl( aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), - self.zero_point.view(-1), + self.zero_point.view(-1) if self.zero_point is not None else None, self._layout, ) else: @@ -148,7 +152,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): __torch_function__ = torch._C._disabled_torch_function_impl - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return self.int_data, self.scale, self.zero_point def get_layout(self) -> Layout: diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 774071f856..0952b2a4bf 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch @@ -87,7 +87,7 @@ class AQTTensorImpl(TorchAOBaseTensor): the underlying implementation of a AQT based on layout """ - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 
"""Get the plain (unpacked) Tensor for the tensor impl Returns data, scale and zero_point @@ -103,7 +103,7 @@ def from_plain( cls, data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): """Construct a TensorImpl from data, scale, zero_point and the _layout""" diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py deleted file mode 100644 index 1f24c91ed2..0000000000 --- a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from enum import Enum, auto -from typing import Optional, Tuple - -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.dtypes.affine_quantized_tensor import ( - register_layout, -) -from torchao.dtypes.affine_quantized_tensor_ops import ( - register_aqt_quantized_linear_dispatch, -) -from torchao.dtypes.utils import AQTTensorImpl, Layout -from torchao.quantization.quant_api import to_affine_quantized_intx -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) - -logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) - -import sys - -handler = logging.StreamHandler(sys.stdout) -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -handler.setFormatter(formatter) -logger.addHandler(handler) - - -class Target(Enum): - """Enum that indicates the backend target""" - - NATIVE = auto() - FALLBACK = auto() - - -def target_from_str(target: str) -> Target: - if target.lower() == "native": - return Target.NATIVE - elif target.lower() == "fallback": - return Target.FALLBACK - else: - raise ValueError(f"Invalid target: {target}") - - -# This format is intended for use with int8 dynamic quantization -class Linear8BitActXBitWeightLayout(Layout): - nbit: int - group_size: int - - # The target platform for the layout, either 'native' or 'fallback'. 
- target: Target - - def __init__( - self, - nbit: int, - group_size: int, - target: str, - ): - assert nbit <= 8 - self.nbit = nbit - self.group_size = group_size - self.target = target_from_str(target) - - def extra_repr(self): - return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}" - - -def _pack_weights_native( - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout: Layout, -): - assert isinstance(layout, Linear8BitActXBitWeightLayout) - assert layout.target == Target.NATIVE - nbit = layout.nbit - group_size = layout.group_size - has_weight_zeros = zero_point is not None - - if has_weight_zeros: - args = [ - int_data.to(torch.int8), - scale.reshape(-1), - zero_point.reshape(-1).to(torch.int8), - torch.empty(0, group_size, dtype=torch.int8), - ] - else: - args = [ - int_data.to(torch.int8), - scale.reshape(-1), - torch.empty(0, group_size, dtype=torch.int8), - ] - - wzp_suffix = "" if has_weight_zeros else "0zp" - return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")( - *args - ) - - -@register_layout(Linear8BitActXBitWeightLayout) -class Linear8BitActXBitWeightAQTTensorImpl(AQTTensorImpl): - def __new__( - cls, - packed_weight: torch.Tensor, - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["dtype"] = packed_weight.dtype - assert not packed_weight.requires_grad - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, Linear8BitActXBitWeightLayout) - - # In the native case, scale and zero_point information is inside - # the packed_weight - if _layout.target == Target.NATIVE: - assert scale is None - assert zero_point is None - - self.packed_weight = packed_weight - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __repr__(self): - layout = self.get_layout() - return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout={layout})" - - def get_layout(self) -> Layout: - return self._layout - - def get_plain( - self, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - if self.get_layout().target == Target.FALLBACK: - return self.packed_weight, self.scale, self.zero_point - raise NotImplementedError( - "get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback" - ) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout: Layout, - ): - assert isinstance(layout, Linear8BitActXBitWeightLayout) - - try: - if layout.target == Target.NATIVE: - packed_weight = _pack_weights_native( - int_data, scale, zero_point, layout - ) - scale = None - zero_point = None - return cls(packed_weight, scale, zero_point, layout) - except Exception as e: - logger.warning( - f"A failure occurred when packing weights with Linear8BitActXBitWeightLayout.target={layout.target}: {e}\n" - + "Falling back to **slow** implementation Linear8BitActXBitWeightLayout.target=fallback." 
- ) - layout.target = Target.FALLBACK - - # Fallback - assert layout.target == Target.FALLBACK - packed_weight = int_data.to(torch.int32) - return cls(packed_weight, scale, zero_point, layout) - - def _apply_fn_to_data(self, fn): - self.packed_weight = fn(self.packed_weight) - if self.scale is not None: - self.scale = fn(self.scale) - - if self.zero_point is not None: - self.zero_point = fn(self.zero_point) - return self - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is torch.ops.aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - if func is torch.ops.aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - raise NotImplementedError( - f"Linear8BitActXBitWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - if self.get_layout().target == Target.NATIVE: - return ["packed_weight"], [self.get_layout()] - - # fallback - assert self.get_layout().target == Target.FALLBACK - if self.zero_point is None: - return ["packed_weight", "scale"], [self.get_layout()] - return ["packed_weight", "scale", "zero_point"], [self.get_layout()] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scale, zero_point = ( - tensor_data_dict["packed_weight"], - tensor_data_dict.get("scale", None), - tensor_data_dict.get("zero_point", None), - ) - (layout,) = tensor_attributes - return cls(packed_weight, scale, zero_point, layout) - - -def _linear_int8_dynamic_activation_intx_weight_check( - input_tensor, weight_tensor, bias -): - layout = weight_tensor.tensor_impl.get_layout() - return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None - - -def _linear_int8_dynamic_activation_intx_weight_fallback_impl( - input_tensor, weight_tensor, bias -): - assert weight_tensor.tensor_impl.get_layout().target == Target.FALLBACK - assert bias is None - - def _impl_2d(input_tensor, weight_tensor): - assert input_tensor.dim() == 2 - assert weight_tensor.dim() == 2 - - m, k = input_tensor.shape - n, k_ = weight_tensor.shape - assert k_ == k - - weights_dequantized = weight_tensor.dequantize() - - # Quantize activations - activations_dequantized = to_affine_quantized_intx( - input_tensor, - mapping_type=MappingType.ASYMMETRIC, - block_size=(1, k), - target_dtype=torch.int32, - quant_min=-128, - quant_max=127, - eps=0.0, - zero_point_dtype=torch.int32, - preserve_zero=True, - zero_point_domain=ZeroPointDomain.INT, - use_hqq=False, - ).dequantize() - - return torch.matmul( - activations_dequantized, weights_dequantized.transpose(1, 0) - ) - - if input_tensor.dim() == 2: - return _impl_2d(input_tensor, weight_tensor) - - assert input_tensor.dim() >= 3 - lead_shape = input_tensor.shape[0:-2] - m, k = input_tensor.shape[-2], input_tensor.shape[-1] - n, k_ = weight_tensor.shape - assert k_ == k - - res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) - res = res.reshape(*lead_shape, m, n) - - return res - - -def _linear_int8_dynamic_activation_intx_weight_native_impl( - input_tensor, weight_tensor, bias -): - assert weight_tensor.tensor_impl.get_layout().target == Target.NATIVE - assert bias is None - - def _impl_2d(input_tensor, weight_tensor): - assert input_tensor.dim() == 2 - assert weight_tensor.dim() == 2 - - m, k = input_tensor.shape - n, k_ = 
weight_tensor.shape - assert k_ == k - group_size = weight_tensor.tensor_impl.get_layout().group_size - packed_weight = weight_tensor.tensor_impl.packed_weight - - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - args = ( - input_tensor, - packed_weight, - torch.empty(0, group_size, dtype=torch.int8), - torch.empty(0, n, dtype=torch.int8), - torch.empty(0, k, dtype=torch.int8), - ) - - has_weight_zeros = weight_tensor.zero_point_domain != ZeroPointDomain.NONE - - assert len(weight_tensor.block_size) == 2 - assert weight_tensor.block_size[0] == 1 - group_size = weight_tensor.block_size[1] - assert group_size == weight_tensor.tensor_impl.get_layout().group_size - nbit = weight_tensor.tensor_impl.get_layout().nbit - - n, k = weight_tensor.shape - m, k_ = input_tensor.shape - assert k_ == k - - packed_weight = weight_tensor.tensor_impl.packed_weight - wzp_suffix = "" if has_weight_zeros else "0zp" - return getattr( - torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight" - )(*args) - - if input_tensor.dim() == 2: - return _impl_2d(input_tensor, weight_tensor) - - assert input_tensor.dim() >= 3 - lead_shape = input_tensor.shape[0:-2] - m, k = input_tensor.shape[-2], input_tensor.shape[-1] - n, k_ = weight_tensor.shape - assert k_ == k - - res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) - res = res.reshape(*lead_shape, m, n) - return res - - -def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias): - target = weight_tensor.tensor_impl.get_layout().target - if target == Target.NATIVE: - return _linear_int8_dynamic_activation_intx_weight_native_impl( - input_tensor, weight_tensor, bias - ) - - if target == Target.FALLBACK: - return _linear_int8_dynamic_activation_intx_weight_fallback_impl( - input_tensor, weight_tensor, bias - ) - - assert False, f"Unknown target {target}" - - -register_aqt_quantized_linear_dispatch( - _linear_int8_dynamic_activation_intx_weight_check, - _linear_int8_dynamic_activation_intx_weight_impl, -) diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md index c1bfa5c32a..7f0970f792 100644 --- a/torchao/experimental/docs/readme.md +++ b/torchao/experimental/docs/readme.md @@ -1,21 +1,29 @@ # TorchAO experimental -TorchAO experimental contains lowbit ARM CPU and Metal kernels for linear and embedding ops. +TorchAO experimental contains lowbit ARM CPU and Metal kernels for linear and +embedding ops. ## Building ARM CPU kernels -To build torch ops that use the lowbit kernels, run `sh build_torchao_ops.sh ` from torchao/experimental. +To build torch ops that use the lowbit kernels, run +`sh build_torchao_ops.sh ` from torchao/experimental. -For example, to build ATen ops, run `sh build_torchao_ops.sh aten` (this requires PyTorch). Similarly, to build the ExecuTorch ops, run `sh build_torchao_ops executorch` (this requires ExecuTorch). +For example, to build ATen ops, run `sh build_torchao_ops.sh aten` (this +requires PyTorch). Similarly, to build the ExecuTorch ops, run +`sh build_torchao_ops executorch` (this requires ExecuTorch). After running the script, the op libraries will be in + ``` cmake-out/lib/libtorchao_ops_aten.{dylib|so} # ATen op library cmake-out/lib/libtorchao_ops_executorch.a # ExecuTorch op library ``` ## Quantizing models -Once the ATen ops are built, you can quantize PyTorch models with them. The quantized models can be run in eager model, compiled, used with AOTI, or exported. 
The exported models can be lowered to ExecuTorch.
+
+Once the ATen ops are built, you can quantize PyTorch models with them. The
+quantized models can be run in eager mode, compiled, used with AOTI, or
+exported. The exported models can be lowered to ExecuTorch.

 ```python
 import torch
@@ -43,33 +51,60 @@ linear_quantizer = Int8DynActIntxWeightLinearQuantizer(
 quantized_model = linear_quantizer.quantize(quantized_model)
 ```

-If you get stuck on the above steps, working examples for both linear and embedding are in torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py and torchao/experimental/tests/test_embedding_xbit_quantizer.py. For example, running `python tests/test_linear_8bit_act_xbit_weight_quantizer.py` loads the ops, creates a toy model, quantizes the model, and runs it in eager, compile, AOTI, and exports the model.
+If you get stuck on the above steps, working examples for both linear and
+embedding are in
+torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py and
+torchao/experimental/tests/test_embedding_xbit_quantizer.py. For example,
+running `python tests/test_linear_8bit_act_xbit_weight_quantizer.py` loads the
+ops, creates a toy model, quantizes the model, and runs it in eager, compile,
+AOTI, and exports the model.

 ### Subclass API

-For linear, you can also use the new subclass API in torchao.
+For linear, you can also use the new subclass API in torchao. First install the
+kernels by running the following command from the ao directory. (Note: this
+will only install the kernels if run on a Mac with Apple Silicon.)
+
+```
+USE_CPP=1 pip install .
+```
+
+Once the kernels are installed, you can quantize your model as follows:

 ```python
-import torch
-torch.ops.load_library("cmake-out/lib/libtorchao_ops_aten.dylib") # make sure this path is correct on your machine
+from torchao.dtypes import PlainLayout
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.quantization.quant_api import quantize_

 my_model = Model()
-from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
-from torchao.quantization.quant_api import quantize_

 quantize_(
     my_model,
     int8_dynamic_activation_intx_weight(
-        group_size=256,
-        nbit=4,
+        weight_dtype=torch.int4,
+        granularity=PerGroup(256),  # PerRow() is also supported
         has_weight_zeros=False,
+        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),  # PlainLayout() is also supported, but much slower on CPU
     ),
 )
 ```

 If you get stuck, consult
-`tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py`.
+`torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
+for a working example.

 ## Available in torchchat

-TorchAO experimental kernels are [available in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels), PyTorch's solution for running LLMs locally. Torchchat integration uses similar steps to above.
+TorchAO experimental kernels are
+[available in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels),
+PyTorch's solution for running LLMs locally. Torchchat integration uses steps
+similar to those above.
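After installing the kernels, a quick way to sanity-check the packed layout numerically is to compare it against a `PlainLayout()` reference, as the accuracy test included in this change does. The sketch below assumes a placeholder model, shapes, and group size; small elementwise differences between the two layouts are expected, as in the test.

```python
import copy

import torch

from torchao.dtypes import PlainLayout
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

# Placeholder model, shapes, and group size (not part of this change).
model = torch.nn.Sequential(torch.nn.Linear(4096, 1024, bias=False))
reference = copy.deepcopy(model)
activations = torch.randn(1, 4096)

# Fast path: packed layout backed by the experimental ARM CPU kernels.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(256),
        has_weight_zeros=False,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)

# Reference path: same quantization scheme with the generic (much slower) PlainLayout.
quantize_(
    reference,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(256),
        has_weight_zeros=False,
        layout=PlainLayout(),
    ),
)

# The two layouts should agree up to small rounding differences.
print(torch.max(torch.abs(model(activations) - reference(activations))))
```

With 4-bit weights, `has_weight_zeros=False`, and `PlainLayout()`, this reference configuration is the one the summary above describes as numerically matching `int8_dynamic_activation_int4_weight`.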
diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py new file mode 100644 index 0000000000..7b2b1da145 --- /dev/null +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.affine_quantized_tensor import ( + register_layout, +) +from torchao.dtypes.affine_quantized_tensor_ops import ( + register_aqt_quantized_linear_dispatch, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +import sys + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): + bit_width: Optional[int] + group_size: Optional[int] + has_weight_zeros: Optional[bool] + + def __init__( + self, + bit_width: Optional[int] = None, + group_size: Optional[int] = None, + has_weight_zeros: Optional[bool] = None, + ): + if bit_width is not None: + assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8" + if group_size is not None: + assert group_size >= 1, f"group_size must be positive, got {group_size}" + + self.bit_width = bit_width + self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + if not self.has_params_set(): + assert ( + self.bit_width is None + and self.group_size is None + and self.has_weight_zeros is None + ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False" + + def extra_repr(self): + return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}" + + def has_params_set(self) -> bool: + return ( + (self.bit_width is not None) + and (self.group_size is not None) + and (self.has_weight_zeros is not None) + ) + + +@register_layout(PackedLinearInt8DynamicActivationIntxWeightLayout) +class PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl(AQTTensorImpl): + def __new__( + cls, + packed_weight: torch.Tensor, + _layout: Layout, + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + group_size_tensor: torch.Tensor, + n_tensor: torch.Tensor, + k_tensor: torch.Tensor, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["dtype"] = packed_weight.dtype + assert not packed_weight.requires_grad + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + _layout: Layout, + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + group_size_tensor: torch.Tensor, + n_tensor: torch.Tensor, + k_tensor: torch.Tensor, + ): + assert isinstance(_layout, PackedLinearInt8DynamicActivationIntxWeightLayout) + self.packed_weight = packed_weight + self._layout = _layout + 
self.group_size_tensor = group_size_tensor + self.n_tensor = n_tensor + self.k_tensor = k_tensor + + def __repr__(self): + return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, layout={self.get_layout()})" + + def get_layout(self) -> Layout: + return self._layout + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError( + "get_plain is not implemented for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl" + ) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + layout: Layout, + ): + assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) + assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" + + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + n, k = int_data.shape + group_size_tensor = torch.empty(0, layout.group_size, dtype=torch.int8) + n_tensor = torch.empty(0, n, dtype=torch.int8) + k_tensor = torch.empty(0, k, dtype=torch.int8) + + if layout.has_weight_zeros: + args = [ + int_data.to(torch.int8), + scale.reshape(-1), + zero_point.reshape(-1).to(torch.int8), + group_size_tensor, + ] + else: + args = [ + int_data.to(torch.int8), + scale.reshape(-1), + group_size_tensor, + ] + + wzp_suffix = "" if layout.has_weight_zeros else "0zp" + packed_weight = getattr( + torch.ops.torchao, + f"_pack_8bit_act_{layout.bit_width}bit{wzp_suffix}_weight", + )(*args) + + return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) + + def _apply_fn_to_data(self, fn): + self.packed_weight = fn(self.packed_weight) + + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + self.group_size_tensor = fn(self.group_size_tensor) + self.n_tensor = fn(self.n_tensor) + self.k_tensor = fn(self.k_tensor) + return self + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is torch.ops.aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is torch.ops.aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + raise NotImplementedError( + f"PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + return ["packed_weight", "group_size_tensor", "n_tensor", "k_tensor"], [ + self.get_layout() + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight = tensor_data_dict["packed_weight"] + + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + group_size_tensor = tensor_data_dict["group_size_tensor"] + n_tensor = tensor_data_dict["n_tensor"] + k_tensor = tensor_data_dict["k_tensor"] + + (layout,) = tensor_attributes + return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) + + +def _linear_check(input_tensor, weight_tensor, bias): + layout = weight_tensor.tensor_impl.get_layout() + return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and ( + bias is None + ) + + +def _linear_impl(input_tensor, 
weight_tensor, bias): + assert ( + bias is None + ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl" + + def _impl_2d(input_tensor, weight_tensor): + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + group_size = weight_tensor.tensor_impl.get_layout().group_size + + assert group_size == weight_tensor.tensor_impl.group_size_tensor.shape[1] + assert n == weight_tensor.tensor_impl.n_tensor.shape[1] + assert k == weight_tensor.tensor_impl.k_tensor.shape[1] + + # TODO(T200095131): convert self.n, self.k, self.group_size to + # int when supported by AOTI + args = ( + input_tensor, + weight_tensor.tensor_impl.packed_weight, + weight_tensor.tensor_impl.group_size_tensor, + weight_tensor.tensor_impl.n_tensor, + weight_tensor.tensor_impl.k_tensor, + ) + + has_weight_zeros = weight_tensor.zero_point_domain != ZeroPointDomain.NONE + + assert len(weight_tensor.block_size) == 2 + assert weight_tensor.block_size[0] == 1 + assert group_size == weight_tensor.block_size[1] + bit_width = weight_tensor.tensor_impl.get_layout().bit_width + + wzp_suffix = "" if has_weight_zeros else "0zp" + return getattr( + torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight" + )(*args) + + if input_tensor.dim() == 2: + return _impl_2d(input_tensor, weight_tensor) + + assert input_tensor.dim() >= 3 + lead_shape = input_tensor.shape[0:-2] + m, k = input_tensor.shape[-2], input_tensor.shape[-1] + n, k_ = weight_tensor.shape + assert k_ == k + + res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) + res = res.reshape(*lead_shape, m, n) + return res + + +register_aqt_quantized_linear_dispatch( + _linear_check, + _linear_impl, +) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index ce99e250ef..4e0906d0a0 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import logging -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn @@ -14,6 +14,11 @@ quantize_per_channel_group, ) +from torchao.quantization.granularity import ( + PerGroup, + PerRow, +) + logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -482,58 +487,139 @@ def quantize(self, model: nn.Module) -> nn.Module: return model +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) +from torchao.quantization.linear_activation_quantized_tensor import ( + to_linear_activation_quantized, +) +from torchao.quantization.quant_api import ( + MappingType, + ZeroPointDomain, + _get_linear_subclass_inserter, + to_affine_quantized_intx, +) +from torchao.quantization.utils import _get_per_token_block_size + + def int8_dynamic_activation_intx_weight( - group_size: int = 128, - nbit: int = 4, + weight_dtype: torch.dtype = torch.int4, + granularity: Union[PerRow, PerGroup] = PerGroup(128), has_weight_zeros: bool = False, - target: str = "native", + weight_mapping_type=MappingType.ASYMMETRIC, + act_mapping_type=MappingType.ASYMMETRIC, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow ): - from torchao.experimental._linear_8bit_act_xbit_weight_layout import ( - Linear8BitActXBitWeightLayout, - ) - from torchao.quantization.quant_api import ( - MappingType, - ZeroPointDomain, - _get_linear_subclass_inserter, - to_affine_quantized_intx, - ) + """ + Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers. + More specifically, activations are dynamically quantized to 8-bits in a channelwise manner with scales and zeros. + Weights are quantized with scales and optionally zeros (controlled by has_weight_zeros) in a groupwise or channelwise + manner using the number of bits specified by weight_dtype. + + args: + weight_dtype: The dtype to use for weight quantization. Must be torch.intx, where 1 <= x <= 8. + granularity: The granularity to use for weight quantization. Must be PerGroup or PerRow. + has_weight_zeros: Whether or not to include zeros in the weight quantization. + weight_mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. + act_mapping_type: The type of mapping to use for the activation quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. + layout: The layout to use for the packed weight tensor. Must be PackedLinearInt8DynamicActivationIntxWeightLayout (default) or PlainLayout. + The layout does not affect the quantization numerically and both layouts will give the same results. PlainLayout is a generic layout + that works on all devices, but it is much slower than PackedLinearInt8DynamicActivationIntxWeightLayout on CPU. + PackedLinearInt8DynamicActivationIntxWeightLayout is a specialized layout for CPU performance. + When using PackedLinearInt8DynamicActivationIntxWeightLayout, + - The weight tensor must have device=CPU + - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32) + - act_mapping_type must be MappingType.ASYMMETRIC + """ + try: + torch.ops.torchao._pack_8bit_act_4bit_weight + except AttributeError: + raise Exception( + "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." 
+            + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
+        )
+
+    dtype_to_bit_width = {
+        torch.int1: 1,
+        torch.int2: 2,
+        torch.int3: 3,
+        torch.int4: 4,
+        torch.int5: 5,
+        torch.int6: 6,
+        torch.int7: 7,
+        torch.int8: 8,
+    }
+    if weight_dtype not in dtype_to_bit_width:
+        raise ValueError(
+            f"weight_dtype must be one of {list(dtype_to_bit_width.keys())}, got {weight_dtype}"
+        )
+    bit_width = dtype_to_bit_width[weight_dtype]
+    layout_arg = layout

     def apply(weight):
+        if isinstance(granularity, PerGroup):
+            group_size = granularity.group_size
+        elif isinstance(granularity, PerRow):
+            group_size = weight.shape[-1]
+        else:
+            raise ValueError(
+                f"granularity must be PerGroup or PerRow, got {granularity}"
+            )
+
         assert weight.shape[-1] % group_size == 0
-        assert weight.device == torch.device("cpu"), "Only CPU is supported"
-        use_hqq = False
-        layout = Linear8BitActXBitWeightLayout(
-            nbit=nbit, group_size=group_size, target=target
-        )
-        mapping_type = MappingType.ASYMMETRIC
-        eps = torch.finfo(torch.float32).eps
-        block_size = (1, group_size)
-        target_dtype = torch.int32
-        quant_min = -(1 << (nbit - 1))
-        quant_max = (1 << (nbit - 1)) - 1
-        zero_point_dtype = torch.int8
-        preserve_zero = has_weight_zeros
-        zero_point_domain = (
-            ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE
-        )
-        # Note: this works differently than other quantizers because the dynamic
-        # activation quantization is fused with the kernel/op (and static activation quantization
-        # is not supported).
-        return to_affine_quantized_intx(
+
+        layout = layout_arg
+        if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
+            assert (
+                weight.device == torch.device("cpu")
+            ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.device=CPU"
+            assert (
+                weight.dtype == torch.float32
+            ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.dtype=float32"
+            assert (
+                act_mapping_type == MappingType.ASYMMETRIC
+            ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
+            assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
+            layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
+                bit_width=bit_width,
+                group_size=group_size,
+                has_weight_zeros=has_weight_zeros,
+            )
+
+        quant_min = -(1 << (bit_width - 1))
+        quant_max = (1 << (bit_width - 1)) - 1
+        weight = to_affine_quantized_intx(
             weight,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
+            mapping_type=weight_mapping_type,
+            block_size=(1, group_size),
+            target_dtype=torch.int32,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=torch.finfo(torch.float32).eps,
+            zero_point_dtype=torch.int8,
+            preserve_zero=has_weight_zeros,
+            zero_point_domain=ZeroPointDomain.INT
+            if has_weight_zeros
+            else ZeroPointDomain.NONE,
             _layout=layout,
-            use_hqq=use_hqq,
         )
+        # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
+        # with the kernel and it should not be applied separately
+        if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
+            activation_quant_func = lambda x: to_affine_quantized_intx(
+                x,
+                mapping_type=act_mapping_type,
+                block_size=_get_per_token_block_size(x),
+                target_dtype=torch.int32,
+                quant_min=-128,
# lower bound of int8 + quant_max=127, # upper bound of int8 + scale_dtype=torch.float32, + zero_point_dtype=torch.int32, + ) + weight = to_linear_activation_quantized(weight, activation_quant_func) + return weight + return _get_linear_subclass_inserter(apply) diff --git a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py similarity index 56% rename from torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py rename to torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py index 61f6c6cc01..284ef4b2a8 100644 --- a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py +++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -10,33 +10,56 @@ import torch +from torchao.dtypes import PlainLayout +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) from torchao.experimental.quant_api import ( - _Int8DynActIntxWeightQuantizedLinearFallback, int8_dynamic_activation_intx_weight, ) +from torchao.quantization.granularity import ( + PerGroup, + PerRow, +) from torchao.quantization.quant_api import quantize_ from torchao.utils import unwrap_tensor_subclass -class TestInt8DynamicActivationIntxWeight(unittest.TestCase): +class TestPackedLinearInt8DynamicActivationIntxWeightLayout(unittest.TestCase): def test_accuracy(self): - group_size = 128 + """ + Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing + its results to the results of a reference model that uses PlainLayout() + """ + granularity = PerGroup(128) m = 1 n = 1071 k = 4096 - activations = torch.randn(m, k, dtype=torch.float32) + activations = torch.randn(m, k) model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) - for nbit in [1, 2, 3, 4, 5, 6, 7, 8]: + for weight_dtype in [ + torch.int1, + torch.int2, + torch.int3, + torch.int4, + torch.int5, + torch.int6, + torch.int7, + torch.int8, + ]: for has_weight_zeros in [True, False]: - print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}") + print( + f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}" + ) quantized_model = copy.deepcopy(model) quantize_( quantized_model, int8_dynamic_activation_intx_weight( - group_size=group_size, - nbit=nbit, + weight_dtype=weight_dtype, + granularity=granularity, has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # default ), ) @@ -44,10 +67,10 @@ def test_accuracy(self): quantize_( quantized_model_reference, int8_dynamic_activation_intx_weight( - group_size=group_size, - nbit=nbit, + weight_dtype=weight_dtype, + granularity=granularity, has_weight_zeros=has_weight_zeros, - target="fallback", + layout=PlainLayout(), ), ) @@ -55,44 +78,30 @@ def test_accuracy(self): result = quantized_model(activations) expected_result = quantized_model_reference(activations) - # TODO: remove expected_result2 checks when we deprecate non-subclass API - reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback() - reference_impl.quantize_and_pack_weights( - model[0].weight, nbit, group_size, has_weight_zeros - ) - expected_result2 = reference_impl(activations) - num_mismatch_at_low_tol = 0 - num_mismatch_at_low_tol2 = 0 num_total = result.reshape(-1).shape[0] for i in 
range(num_total): actual_val = result.reshape(-1)[i] expected_val = expected_result.reshape(-1)[i] - expected_val2 = expected_result2.reshape(-1)[i] self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) if not torch.allclose(actual_val, expected_val): num_mismatch_at_low_tol += 1 - self.assertTrue( - torch.allclose( - expected_val, expected_val2, atol=1e-2, rtol=1e-1 - ) - ) - if not torch.allclose(expected_val, expected_val2): - num_mismatch_at_low_tol2 += 1 - # Assert at most 5% of entries are not close at a low tolerance self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) - self.assertTrue(num_mismatch_at_low_tol2 / num_total <= 0.01) def test_export_compile_aoti(self): - group_size = 32 + """ + Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with + torch.export.export, torch.compile, and AOTI. + """ + granularity = PerRow() m = 3 k0 = 512 k1 = 256 k2 = 128 k3 = 1024 - nbit = 4 + weight_dtype = torch.int4 has_weight_zeros = True layers = [ torch.nn.Linear(k0, k1, bias=False), @@ -106,35 +115,39 @@ def test_export_compile_aoti(self): quantize_( model, int8_dynamic_activation_intx_weight( - group_size=group_size, - nbit=nbit, + weight_dtype=weight_dtype, + granularity=granularity, has_weight_zeros=has_weight_zeros, - target="native", + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), ), ) + eager_results = model(activations) unwrapped_model = copy.deepcopy(model) unwrap_tensor_subclass(model) print("Exporting quantized model") - torch.export.export(model, (activations,), strict=True) + exported = torch.export.export(model, (activations,), strict=True) + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) print("Compiling quantized model") compiled = torch.compile(unwrapped_model) with torch.no_grad(): - compiled(activations) + compiled_results = compiled(activations) + self.assertTrue(torch.allclose(eager_results, compiled_results)) with tempfile.TemporaryDirectory() as tmpdirname: + package_path = f"{tmpdirname}/model.pt2" print("Exporting quantized model with AOTI") - torch._export.aot_compile( - model, - (activations,), - options={"aot_inductor.output_path": f"{tmpdirname}/model.so"}, + torch._inductor.aoti_compile_and_package( + exported, package_path=package_path ) print("Running quantized model in AOTI") - fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu") - fn(activations) + fn = torch._inductor.aoti_load_package(package_path) + aoti_results = fn(activations) + self.assertTrue(torch.allclose(eager_results, aoti_results)) if __name__ == "__main__":