[Fix]: Use Custom Tensor for dyn_quant_matmul_4bit aten op
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
ng-05 committed Jan 17, 2025
1 parent 983215d commit bac01a9
Showing 4 changed files with 107 additions and 30 deletions.
7 changes: 1 addition & 6 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -206,7 +206,6 @@ def from_hp_to_intx(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
use_hqq: bool = False,
bias: Optional[torch.Tensor] = None
):
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
@@ -279,11 +278,7 @@ def from_hp_to_intx(

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
args = [data, scale, zero_point, _layout]
# Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias
if bias is not None:
args.append(bias)
tensor_impl = tensor_impl_ctr(*args)
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
return cls(
tensor_impl,
block_size,
81 changes: 81 additions & 0 deletions torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py
@@ -12,6 +12,8 @@
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
get_tensor_impl_constructor,
register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -20,7 +22,11 @@
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
MappingType,
choose_qparams_affine,
quantize_affine,
)

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)
@@ -325,3 +331,78 @@ def _impl_2d_aten(input_tensor, weight_tensor):
_linear_check,
_linear_impl,
)


class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
"""
PackedLinearInt8DynamicActivationIntxWeightAtenTensor is a quantized tensor subclass that inherits from AffineQuantizedTensor.
"""

@classmethod
def from_hp_to_intx(
cls,
input_float: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
use_hqq: bool = False,
bias: Optional[torch.Tensor] = None
):
assert not use_hqq, "PackedLinearInt8DynamicActivationIntxWeightTensor does not support HQQ optimization"
assert isinstance(
_layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
assert _layout.target == Target.ATEN, "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)

scale, zero_point = choose_qparams_affine(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
# choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is None
# TODO should probably consolidate ZeroPointDomain.NONE and None
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
data = quantize_affine(
input_float,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
return cls(
tensor_impl,
block_size,
original_shape,
quant_min,
quant_max,
zero_point_domain,
dtype=input_float.dtype,
)

to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
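
For illustration, a minimal sketch of how the new constructor alias could be invoked directly. The parameter values, the target="aten" layout constructor, and the environment (ARM CPU build of PyTorch >= 2.6 with KleidiAI kernels available) are assumptions for this sketch, not taken from the commit:

import torch
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
    to_packedlinearint8dynamicactivationintxweight_quantized_intx,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(64, 128, dtype=torch.float32)  # hypothetical linear weight
bias = torch.randn(64, dtype=torch.float32)

packed_weight = to_packedlinearint8dynamicactivationintxweight_quantized_intx(
    weight,
    MappingType.SYMMETRIC,           # mapping_type
    (1, 32),                         # block_size: per-group quantization, group size 32
    torch.int32,                     # target_dtype
    -8,                              # quant_min for 4-bit weights
    7,                               # quant_max for 4-bit weights
    torch.finfo(torch.float32).eps,  # eps
    torch.bfloat16,                  # scale_dtype used by the KleidiAI per-group path
    torch.int8,                      # zero_point_dtype
    False,                           # preserve_zero (no weight zeros)
    ZeroPointDomain.NONE,            # zero_point_domain
    PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    False,                           # use_hqq (must stay False per the assert above)
    bias,                            # bias is handed to the tensor impl constructor
)

Handing the bias to the tensor impl constructor is what lets the "aten" target pack it together with the quantized weights and scales instead of carrying it separately.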
44 changes: 23 additions & 21 deletions torchao/experimental/quant_api.py
@@ -494,6 +494,7 @@ def quantize(self, model: nn.Module) -> nn.Module:

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
PackedLinearInt8DynamicActivationIntxWeightLayout,
to_packedlinearint8dynamicactivationintxweight_quantized_intx,
)
from torchao.quantization.linear_activation_quantized_tensor import (
to_linear_activation_quantized,
@@ -576,6 +577,7 @@ def int8_dynamic_activation_intx_weight(
)
bit_width = dtype_to_bit_width[weight_dtype]
layout_arg = layout
propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten"

def apply(weight, bias: Optional[torch.Tensor] = None):
if isinstance(granularity, PerGroup):
@@ -591,6 +593,10 @@ def apply(weight, bias: Optional[torch.Tensor] = None):

layout = layout_arg
scale_dtype = None
tensor_quantizer = to_affine_quantized_intx
quant_min = -(1 << (bit_width - 1))
quant_max = (1 << (bit_width - 1)) - 1

if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
assert (
weight.device == torch.device("cpu")
@@ -613,26 +619,23 @@
if torch.backends.kleidiai.is_available():
if isinstance(granularity, PerGroup):
scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype

quant_min = -(1 << (bit_width - 1))
quant_max = (1 << (bit_width - 1)) - 1
weight = to_affine_quantized_intx(
weight,
mapping_type=weight_mapping_type,
block_size=(1, group_size),
target_dtype=torch.int32,
quant_min=quant_min,
quant_max=quant_max,
eps=torch.finfo(torch.float32).eps,
scale_dtype=scale_dtype,
zero_point_dtype=torch.int8,
preserve_zero=has_weight_zeros,
zero_point_domain=ZeroPointDomain.INT
if has_weight_zeros
else ZeroPointDomain.NONE,
_layout=layout,
bias=bias
)
tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx

quantizer_args = [weight,
weight_mapping_type,
(1, group_size),
torch.int32,
quant_min,
quant_max,
torch.finfo(torch.float32).eps,
scale_dtype,
torch.int8,
has_weight_zeros,
ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE,
layout,
False] + ([bias] if propagate_bias else [])  # False -> use_hqq; bias only for the aten target

weight = tensor_quantizer(*quantizer_args)

# Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
# with the kernel and it should not be applied separately
@@ -650,7 +653,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
weight = to_linear_activation_quantized(weight, activation_quant_func)
return weight

propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten"
return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)


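For illustration, a sketch of the intended end-to-end usage of this path through torchao's quantize_ API. Module paths, parameter values, and the environment (ARM CPU with KleidiAI support, PyTorch >= 2.6) are assumptions for this sketch:

import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

model = torch.nn.Sequential(torch.nn.Linear(128, 64, bias=True))

# With the "aten" target, propagate_bias is True, so each Linear's bias is
# passed to the weight quantizer and packed together with the 4-bit weights
# for the dyn_quant_matmul_4bit aten op.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=False,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)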
5 changes: 2 additions & 3 deletions torchao/quantization/quant_api.py
@@ -457,11 +457,10 @@ def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, pro

def insert_subclass(lin):
requires_grad = allow_requires_grad and lin.weight.requires_grad
args = [lin.weight]
if propagate_bias == True:
args.append(lin.bias)
kwargs["bias"] = lin.bias
lin.weight = torch.nn.Parameter(
constructor(*args, **kwargs), requires_grad=requires_grad
constructor(lin.weight, **kwargs), requires_grad=requires_grad
)
lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
return lin
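For context, a condensed sketch of what the updated inserter does when propagate_bias is True; names are simplified, and the real constructor is the apply closure from torchao/experimental/quant_api.py shown above:

import torch

def insert_subclass(lin: torch.nn.Linear, constructor, propagate_bias: bool, **kwargs):
    # Simplified form of _get_linear_subclass_inserter's inner function: the bias
    # is now forwarded as a keyword argument instead of being appended to a
    # positional argument list.
    if propagate_bias:
        kwargs["bias"] = lin.bias
    lin.weight = torch.nn.Parameter(constructor(lin.weight, **kwargs), requires_grad=False)
    return lin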
