[Feat]: Add support for kleidiai quantization schemes #1447
base: main
@@ -5,12 +5,15 @@
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum, auto
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    get_tensor_impl_constructor,
    register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -19,6 +22,13 @@
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
@@ -31,17 +41,33 @@
handler.setFormatter(formatter)
logger.addHandler(handler)

class Target(Enum):
    """Enum that indicates the backend target"""

    NATIVE = auto()
    ATEN = auto()

def target_from_str(target: str) -> Target:
    if target.lower() == "native":
        return Target.NATIVE
    elif target.lower() == "aten":
        return Target.ATEN
    else:
        raise ValueError(f"Invalid target: {target}")

class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
    bit_width: Optional[int]
    group_size: Optional[int]
    has_weight_zeros: Optional[bool]
Review comment: The packed weights from Kleidi have bias packed with them, right? If so, let's add has_bias: Optional[bool] here to the layout.
Reply: We can detect bias propagation from target == "aten" or from has_bias. I don't think we need both if we are making target part of PackedLinearInt8DynamicActivationIntxWeightLayout().
    # The target platform for the layout, 'native' or 'aten'
    target: Optional[Target]

    def __init__(
        self,
        bit_width: Optional[int] = None,
        group_size: Optional[int] = None,
        has_weight_zeros: Optional[bool] = None,
        target: Optional[str] = "native",
    ):
        if bit_width is not None:
            assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
@@ -51,6 +77,7 @@ def __init__(
        self.bit_width = bit_width
        self.group_size = group_size
        self.has_weight_zeros = has_weight_zeros
        self.target = target_from_str(target)

        if not self.has_params_set():
            assert (
@@ -60,13 +87,14 @@ def __init__(
            ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"

    def extra_repr(self):
        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}"
        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"

    def has_params_set(self) -> bool:
        return (
            (self.bit_width is not None)
            and (self.group_size is not None)
            and (self.has_weight_zeros is not None)
            and (self.target is not None)
        )
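A minimal sketch of how the new target plumbing might be exercised, based only on the constructor shown above; the bit width and group size below are illustrative assumptions, not values taken from this PR:

```python
# Illustrative only: construct the layout for the ATen (KleidiAI-backed) path.
# bit_width / group_size values here are assumptions for the example.
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
    bit_width=4,
    group_size=32,
    has_weight_zeros=False,
    target="aten",  # parsed by target_from_str into Target.ATEN
)
assert layout.has_params_set()
```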
@@ -125,9 +153,11 @@ def from_plain(
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        layout: Layout,
        bias: Optional[torch.Tensor] = None,
    ):
        assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
        assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
Review comment: Do we want to assert bias is None when layout.target != ATEN?
Reply: If bias is None and layout.target != ATEN, execution will never reach this point. It will go through the "native" target path, which has the bias assert. Please check here: https://github.com/pytorch/ao/pull/1447/files#diff-3e4ffa192fe9921999dd6a798fc3fa620377896ef9ba65245b1e5ab8c0d7d344R593

        # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
        # when AOTI supports int
@@ -136,6 +166,13 @@ def from_plain(
        n_tensor = torch.empty(0, n, dtype=torch.int8)
        k_tensor = torch.empty(0, k, dtype=torch.int8)

        if layout.target == Target.ATEN:
            assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version >= 2.6.0"
            int_data = int_data.add(8)
            int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
            return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)

        if layout.has_weight_zeros:
            args = [
                int_data.to(torch.int8),
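For reference, a minimal sketch of what the add(8)-and-pack step above does, assuming signed int4 values in [-8, 7]; the tensor below is made up for illustration, and the unpacking half is only a sanity check, not code from this PR:

```python
import torch

# Illustrative input: one row of signed int4 values (k must be even).
int_data = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int32)

# As in from_plain above: shift to [0, 15], then pack two columns per byte
# (odd column in the high nibble, even column in the low nibble).
shifted = int_data.add(8)
packed = (shifted[:, 1::2] << 4 | shifted[:, ::2]).to(torch.uint8)

# Sanity check (not part of the PR): unpacking recovers the original values.
low = (packed & 0x0F).to(torch.int32) - 8
high = (packed >> 4).to(torch.int32) - 8
assert torch.equal(torch.stack([low, high], dim=-1).reshape(int_data.shape), int_data)
```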
@@ -211,16 +248,13 @@ def __tensor_unflatten__(
def _linear_check(input_tensor, weight_tensor, bias):
    layout = weight_tensor.tensor_impl.get_layout()
    return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
        bias is None
        bias is None or layout.target == Target.ATEN  # Aten target allows bias
    )


def _linear_impl(input_tensor, weight_tensor, bias):
    assert (
        bias is None
    ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"

    def _impl_2d(input_tensor, weight_tensor):
    def _impl_2d_native(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2
@@ -255,6 +289,31 @@ def _impl_2d(input_tensor, weight_tensor):
            torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
        )(*args)

    def _impl_2d_aten(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

        m, k = input_tensor.shape
        n, k_ = weight_tensor.shape
        assert k_ == k
        group_size = weight_tensor.tensor_impl.get_layout().group_size
        packed_weight = weight_tensor.tensor_impl.packed_weight
        return torch.ops.aten._dyn_quant_matmul_4bit(
            input_tensor, packed_weight, group_size, k_, n)

    target = weight_tensor.tensor_impl.get_layout().target

    if target == Target.ATEN:
        assert (
            TORCH_VERSION_AT_LEAST_2_6
        ), "Target.ATEN requires torch >= 2.6.0"
        _impl_2d = _impl_2d_aten
    elif target == Target.NATIVE:
        _impl_2d = _impl_2d_native
        assert (
            bias is None
        ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native'"

    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)
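A small sketch of the shape bookkeeping that _impl_2d_aten relies on, with a plain float matmul standing in for torch.ops.aten._dyn_quant_matmul_4bit (an assumption used only to keep the example self-contained):

```python
import torch

def matmul_2d_reference(input_tensor: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Same shape contract as _impl_2d_aten: input is [m, k], weight is [n, k],
    # output is [m, n]. The quantized kernel is replaced by a float matmul here.
    assert input_tensor.dim() == 2 and weight.dim() == 2
    m, k = input_tensor.shape
    n, k_ = weight.shape
    assert k_ == k
    return input_tensor @ weight.t()

out = matmul_2d_reference(torch.randn(3, 64), torch.randn(8, 64))
assert out.shape == (3, 8)
```

Note that on the ATen path the bias was already folded into packed_weight by _dyn_quant_pack_4bit_weight in from_plain, which is why _impl_2d_aten takes no bias argument.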
@@ -268,8 +327,82 @@ def _impl_2d(input_tensor, weight_tensor):
    res = res.reshape(*lead_shape, m, n)
    return res


register_aqt_quantized_linear_dispatch(
    _linear_check,
    _linear_impl,
)


class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
Review comment: Why do we need a subclass of AffineQuantizedTensor? What is wrong with using the existing to_affine_quantized_intx method?
Reply: This was suggested by one of the PyTorch reviewers: #1447 (comment)
Review comment: Curious, (1) why adding bias as optional
    """
    PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits from AffineQuantizedTensor.
    """

    @classmethod
    def from_hp_to_intx(
        cls,
        input_float: torch.Tensor,
        mapping_type: MappingType,
        block_size: Tuple[int, ...],
        target_dtype: torch.dtype,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
        _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
        use_hqq: bool = False,
        bias: Optional[torch.Tensor] = None,
    ):
        assert use_hqq == False, "PackedLinearInt8DynamicActivationIntxWeightTensor cannot support HQQ optimization"
        assert isinstance(
            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
        ), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
        assert _layout.target == Target.ATEN, "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
        original_shape = input_float.shape
        input_float = _layout.pre_process(input_float)
        scale, zero_point = choose_qparams_affine(
            input_float,
            mapping_type,
            block_size,
            target_dtype,
            quant_min,
            quant_max,
            eps,
            scale_dtype,
            zero_point_dtype,
            preserve_zero,
            zero_point_domain,
        )
        # choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is None
        # TODO: should probably consolidate ZeroPointDomain.NONE and None
        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
            zero_point = None
        data = quantize_affine(
            input_float,
            block_size,
            scale,
            zero_point,
            target_dtype,
            quant_min,
            quant_max,
            zero_point_domain,
        )
        # Note: output will be a uint8 tensor for sub-byte tensors for now

        data = _layout.post_process(data)
        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
        return cls(
            tensor_impl,
            block_size,
            original_shape,
            quant_min,
            quant_max,
            zero_point_domain,
            dtype=input_float.dtype,
        )

to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
Review comment: What makes it ATEN specific? Should it be Kleidi? I am thinking from a longer-term perspective, where we will use this layout arg to differentiate how to pack weights.
Reply: I decided to stick with aten so that more aten ops can be added in the future if needed. This enables this layout to work with both torchao ops and aten ops. dyn_quant_matmul_4bit is not supposed to be KleidiAI-specific only.
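For context, a minimal sketch of how the new entry point might be called to quantize a 4-bit, groupwise weight for the ATen/KleidiAI path. The parameter choices (symmetric mapping, group size 32, quant range [-8, 7], int32 intermediate dtype) are illustrative assumptions, not values taken verbatim from this PR:

```python
import torch

# Hypothetical weight and bias for a linear layer (out_features=32, in_features=64).
weight = torch.randn(32, 64, dtype=torch.float32)
bias = torch.randn(32, dtype=torch.float32)
group_size = 32  # assumed groupwise quantization along the input dimension

quantized_weight = to_packedlinearint8dynamicactivationintxweight_quantized_intx(
    input_float=weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, group_size),
    target_dtype=torch.int32,
    quant_min=-8,  # 4-bit signed range
    quant_max=7,
    eps=torch.finfo(torch.float32).eps,
    zero_point_domain=ZeroPointDomain.NONE,
    _layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
        bit_width=4, group_size=group_size, has_weight_zeros=False, target="aten"
    ),
    bias=bias,
)
```

With target="aten", the resulting tensor holds the weight packed by _dyn_quant_pack_4bit_weight (bias folded in), so a linear call on it should route through the dispatch registered above.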