[Feat]: Add support for kleidiai quantization schemes #1447

Open · wants to merge 3 commits into base: main
@@ -5,12 +5,15 @@
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum, auto
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
get_tensor_impl_constructor,
register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -19,6 +22,13 @@
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
MappingType,
choose_qparams_affine,
quantize_affine,
)

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
@@ -31,17 +41,33 @@
handler.setFormatter(formatter)
logger.addHandler(handler)

class Target(Enum):
"""Enum that indicates the backend target"""

NATIVE = auto()
ATEN = auto()
Contributor: What makes it ATEN specific? Should it be Kleidi? I am thinking from a longer-term perspective, where we will use this layout arg to differentiate how to pack weights.

Author: I decided to stick with aten so that more aten ops can be added in the future if needed. This lets the layout work with both torchao ops and aten ops; dyn_quant_matmul_4bit is not meant to be KleidiAI-specific.
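To make the distinction concrete, here is a minimal sketch (illustration only, not part of the diff; it assumes the layout class from this file is importable, and the bit_width/group_size values are arbitrary examples) of how the two backends are selected purely through the `target` argument:

```python
# aten target: routes linear through torch.ops.aten._dyn_quant_matmul_4bit and
# allows bias to be packed together with the weights.
aten_layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
    bit_width=4,
    group_size=32,
    has_weight_zeros=False,
    target="aten",
)

# native target: routes linear through the torchao lowbit kernels; bias must be
# handled outside of the packed weights.
native_layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
    bit_width=4,
    group_size=32,
    has_weight_zeros=False,
    target="native",
)
```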


def target_from_str(target: str) -> Target:
if target.lower() == "native":
return Target.NATIVE
elif target.lower() == "aten":
return Target.ATEN
else:
raise ValueError(f"Invalid target: {target}")

class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
bit_width: Optional[int]
group_size: Optional[int]
has_weight_zeros: Optional[bool]
Contributor: The packed weights from Kleidi have bias packed with them, right? If so, let's add has_bias: Optional[bool] here to the layout.

Author: We can detect bias propagation from target == "aten" or from has_bias; I don't think we need both if we are making target part of PackedLinearInt8DynamicActivationIntxWeightLayout(). Where do you want to use has_bias, and what is the use case?

# The target platform for the layout, 'native' or 'aten'
target: Optional[Target]

def __init__(
self,
bit_width: Optional[int] = None,
group_size: Optional[int] = None,
has_weight_zeros: Optional[bool] = None,
target: Optional[str] = "native",
):
if bit_width is not None:
assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
@@ -51,6 +77,7 @@ def __init__(
self.bit_width = bit_width
self.group_size = group_size
self.has_weight_zeros = has_weight_zeros
self.target = target_from_str(target)

if not self.has_params_set():
assert (
@@ -60,13 +87,14 @@ def __init__(
), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"

def extra_repr(self):
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}"
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"

def has_params_set(self) -> bool:
return (
(self.bit_width is not None)
and (self.group_size is not None)
and (self.has_weight_zeros is not None)
and (self.target is not None)
)


@@ -125,9 +153,11 @@ def from_plain(
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
layout: Layout,
bias: Optional[torch.Tensor] = None,
):
assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
Contributor: Do we want to assert bias is None when layout.target != ATEN?

Author: If bias is None and layout.target != ATEN, execution will never reach this point; it will go via the "native" target, which has the bias assert. Please check here: https://github.com/pytorch/ao/pull/1447/files#diff-3e4ffa192fe9921999dd6a798fc3fa620377896ef9ba65245b1e5ab8c0d7d344R593


# TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
# when AOTI supports int
@@ -136,6 +166,13 @@ def from_plain(
n_tensor = torch.empty(0, n, dtype=torch.int8)
k_tensor = torch.empty(0, k, dtype=torch.int8)

if layout.target == Target.ATEN:
assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version >= 2.6.0"
# Shift int4 values from [-8, 7] to unsigned [0, 15], then pack two values per byte.
int_data = int_data.add(8)
int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)

if layout.has_weight_zeros:
args = [
int_data.to(torch.int8),
@@ -211,16 +248,13 @@
def _linear_check(input_tensor, weight_tensor, bias):
layout = weight_tensor.tensor_impl.get_layout()
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
bias is None
bias is None or layout.target == Target.ATEN # Aten target allows bias
)


def _linear_impl(input_tensor, weight_tensor, bias):
assert (
bias is None
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"

def _impl_2d(input_tensor, weight_tensor):
def _impl_2d_native(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2

@@ -255,6 +289,31 @@ def _impl_2d(input_tensor, weight_tensor):
torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
)(*args)

def _impl_2d_aten(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2

m, k = input_tensor.shape
n, k_ = weight_tensor.shape
assert k_ == k
group_size = weight_tensor.tensor_impl.get_layout().group_size
packed_weight = weight_tensor.tensor_impl.packed_weight
return torch.ops.aten._dyn_quant_matmul_4bit(
input_tensor, packed_weight, group_size, k_, n)

target = weight_tensor.tensor_impl.get_layout().target

if target == Target.ATEN:
assert (
TORCH_VERSION_AT_LEAST_2_6
), "Target.ATEN requires torch >= 2.6.0"
_impl_2d = _impl_2d_aten
elif target == Target.NATIVE:
_impl_2d = _impl_2d_native
assert (
bias is None
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' "

if input_tensor.dim() == 2:
return _impl_2d(input_tensor, weight_tensor)

@@ -268,8 +327,82 @@ def _impl_2d(input_tensor, weight_tensor):
res = res.reshape(*lead_shape, m, n)
return res


register_aqt_quantized_linear_dispatch(
_linear_check,
_linear_impl,
)


class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
Contributor: Why do we need a subclass of AffineQuantizedTensor? What is wrong with using the existing to_affine_quantized_intx method?

Author: This was suggested by one of the PyTorch reviewers: #1447 (comment). The from_plain method needs specialization to accommodate bias. AffineQuantizedTensor provides a generic flow where specialization was possible, but Jerry suggested separating the flow.

Contributor: Curious: (1) why isn't adding bias as an optional argument to to_affine_quantized_intx an option? Or alternatively, (2) for native kernels we will add bias support sooner rather than later, so should we use this for both native and aten?

"""
PackedLinearInt8DynamicActivationIntxWeightAtenTensor is a quantized tensor subclass that inherits from AffineQuantizedTensor.
"""

@classmethod
def from_hp_to_intx(
cls,
input_float: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
use_hqq: bool = False,
bias: Optional[torch.Tensor] = None
):
assert not use_hqq, "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
assert isinstance(
_layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)

scale, zero_point = choose_qparams_affine(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
# choose_qparams_affine is a custom op that does not support returning optional Tensors, so we set zero_point to None here when its domain is None or NONE
# TODO should probably consolidate ZeroPointDomain.NONE and None
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
data = quantize_affine(
input_float,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
# Note: output will be a uint8 tensor for sub-byte dtypes for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
return cls(
tensor_impl,
block_size,
original_shape,
quant_min,
quant_max,
zero_point_domain,
dtype=input_float.dtype,
)

to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
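As a rough usage sketch of the new constructor alias (illustration only, not part of the diff; the layer shape, group size, and quantization parameter choices below are assumptions, while the other names come from this file):

```python
import torch

# Hypothetical float linear layer to quantize; shapes are illustrative.
linear = torch.nn.Linear(128, 256, bias=True)
group_size = 32

quantized_weight = to_packedlinearint8dynamicactivationintxweight_quantized_intx(
    input_float=linear.weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, group_size),
    target_dtype=torch.int32,
    quant_min=-8,  # 4-bit signed range expected by the aten packing path
    quant_max=7,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.int8,
    preserve_zero=True,
    zero_point_domain=ZeroPointDomain.NONE,
    _layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
        bit_width=4, group_size=group_size, has_weight_zeros=False, target="aten"
    ),
    bias=linear.bias,
)

# Linear calls on the resulting tensor dispatch through _linear_impl above,
# which invokes torch.ops.aten._dyn_quant_matmul_4bit for the aten target.
```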