[Feat]: Enable dyn_quant_pack_4bit aten kernels via Linear8BitActXBitWeightLayout

Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
ng-05 committed Jan 11, 2025
1 parent 2a18e60 commit 358d6b4
Showing 3 changed files with 196 additions and 11 deletions.
90 changes: 82 additions & 8 deletions torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
@@ -23,6 +23,9 @@
MappingType,
ZeroPointDomain,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@@ -40,13 +43,16 @@ class Target(Enum):

NATIVE = auto()
FALLBACK = auto()
ATEN = auto()


def target_from_str(target: str) -> Target:
if target.lower() == "native":
return Target.NATIVE
elif target.lower() == "fallback":
return Target.FALLBACK
elif target.lower() == "aten":
return Target.ATEN
else:
raise ValueError(f"Invalid target: {target}")

@@ -56,22 +62,27 @@ class Linear8BitActXBitWeightLayout(Layout):
nbit: int
group_size: int

# The target platform for the layout, either 'native' or 'fallback'.
# The target platform for the layout: 'native', 'fallback', or 'aten'.
target: Target

# Allow bias access via layout
bias: Optional[torch.Tensor] = None

def __init__(
self,
nbit: int,
group_size: int,
target: str,
bias: Optional[torch.Tensor] = None,
):
assert nbit <= 8
self.nbit = nbit
self.group_size = group_size
self.target = target_from_str(target)
self.bias = bias

def extra_repr(self):
return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}"
return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}, bias={self.bias}"


def _pack_weights_native(
@@ -81,7 +92,6 @@ def _pack_weights_native(
layout: Layout,
):
assert isinstance(layout, Linear8BitActXBitWeightLayout)
assert layout.target == Target.NATIVE
nbit = layout.nbit
group_size = layout.group_size
has_weight_zeros = zero_point is not None
@@ -100,6 +110,12 @@
torch.empty(0, group_size, dtype=torch.int8),
]

if TORCH_VERSION_AT_LEAST_2_6 and layout.target == Target.ATEN:
in_features = int_data.shape[-1]
out_features = int_data.shape[-2]
# Shift signed int4 values from [-8, 7] into [0, 15], then pack two 4-bit values per byte (odd columns in the high nibble)
int_data = int_data.add(8)
int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
return torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, layout.bias, group_size, in_features, out_features)
wzp_suffix = "" if has_weight_zeros else "0zp"
return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")(
*args
Expand Down Expand Up @@ -153,7 +169,7 @@ def get_layout(self) -> Layout:
def get_plain(
self,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.get_layout().target == Target.FALLBACK:
if self.get_layout().target == Target.FALLBACK or self.get_layout().target == Target.ATEN:
return self.packed_weight, self.scale, self.zero_point
raise NotImplementedError(
"get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback"
@@ -170,12 +186,17 @@ def from_plain(
assert isinstance(layout, Linear8BitActXBitWeightLayout)

try:
if layout.target == Target.NATIVE:
if layout.target == Target.NATIVE or layout.target == Target.ATEN:
packed_weight = _pack_weights_native(
int_data, scale, zero_point, layout
)
scale = None
zero_point = None
# Avoid storing the bias tensor, but indicate on printing whether the Linear layer had a bias, because:
# 1. the aten dynamic-quant op has already packed it into the weights, or
# 2. it is not needed by any other op
if layout.bias is not None:
layout.bias = True
return cls(packed_weight, scale, zero_point, layout)
except Exception as e:
logger.warning(
Expand Down Expand Up @@ -216,7 +237,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
)

def __tensor_flatten__(self):
if self.get_layout().target == Target.NATIVE:
if self.get_layout().target == Target.NATIVE or self.get_layout().target == Target.ATEN:
return ["packed_weight"], [self.get_layout()]

# fallback
@@ -242,8 +263,11 @@ def _linear_int8_dynamic_activation_intx_weight_check(
input_tensor, weight_tensor, bias
):
layout = weight_tensor.tensor_impl.get_layout()
return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None

is_aten_target = (
isinstance(layout, Linear8BitActXBitWeightLayout) and layout.target == Target.ATEN
)
return isinstance(layout, Linear8BitActXBitWeightLayout) and (bias is None or is_aten_target)

def _linear_int8_dynamic_activation_intx_weight_fallback_impl(
input_tensor, weight_tensor, bias
@@ -353,6 +377,51 @@ def _impl_2d(input_tensor, weight_tensor):
return res


def _linear_int8_dynamic_activation_intx_weight_aten_impl(
input_tensor, weight_tensor, bias
):
assert weight_tensor.tensor_impl.get_layout().target == Target.ATEN
if weight_tensor.zero_point_domain != ZeroPointDomain.NONE:
raise NotImplementedError(
"MappingType.ASSYMETRIC in is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is aten"
)
assert (
weight_tensor.tensor_impl.get_layout().nbit == 4
), f"Only 4 bit is supported"
assert (
TORCH_VERSION_AT_LEAST_2_6 == 1
), "Target.ATEN requires torch >= 2.6.0"
# The aten op supports bias when the KleidiAI backend is available, but not with the default fallback kernel
if not torch.backends.kleidiai.is_available():
assert bias is None, "bias is only supported when the KleidiAI backend is available"

def _impl_2d(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2

m, k = input_tensor.shape
n, k_ = weight_tensor.shape
assert k_ == k
group_size = weight_tensor.tensor_impl.get_layout().group_size
packed_weight = weight_tensor.tensor_impl.packed_weight
return torch.ops.aten._dyn_quant_matmul_4bit(
input_tensor, packed_weight, group_size, k_, n)

if input_tensor.dim() == 2:
return _impl_2d(input_tensor, weight_tensor)

assert input_tensor.dim() >= 3
lead_shape = input_tensor.shape[0:-2]
m, k = input_tensor.shape[-2], input_tensor.shape[-1]
n, k_ = weight_tensor.shape
assert k_ == k

res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
res = res.reshape(*lead_shape, m, n)
return res


def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias):
target = weight_tensor.tensor_impl.get_layout().target
if target == Target.NATIVE:
Expand All @@ -365,6 +434,11 @@ def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor
input_tensor, weight_tensor, bias
)

if target == Target.ATEN:
return _linear_int8_dynamic_activation_intx_weight_aten_impl(
input_tensor, weight_tensor, bias
)

assert False, f"Unknown target {target}"


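For reference, a minimal standalone sketch of the nibble packing used above for Target.ATEN: signed int4 values are shifted from [-8, 7] into [0, 15] and two values are packed into each uint8 byte before being handed to torch.ops.aten._dyn_quant_pack_4bit_weight. The round trip below is purely illustrative, uses made-up shapes, and calls no aten or torchao op.

import torch

int_data = torch.randint(-8, 8, (2, 8), dtype=torch.int8)  # (out_features, in_features)

# Shift signed int4 values into the unsigned range, then pack pairs of columns
# into one byte: odd columns go to the high nibble, even columns to the low nibble.
unsigned = int_data.add(8).to(torch.uint8)
packed = (unsigned[:, 1::2] << 4) | unsigned[:, ::2]
assert packed.shape == (2, 4)

# Unpacking recovers the original tensor, confirming the layout.
low = (packed & 0x0F).to(torch.int8) - 8
high = (packed >> 4).to(torch.int8) - 8
restored = torch.stack([low, high], dim=-1).reshape(int_data.shape)
assert torch.equal(restored, int_data)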
110 changes: 109 additions & 1 deletion torchao/experimental/quant_api.py
@@ -4,15 +4,25 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torchao.quantization.quant_api import (
MappingType,
)
import logging
from typing import Optional
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.ao.quantization.fx._decomposed import (
dequantize_per_channel_group,
quantize_per_channel_group,
)
from torchao.quantization.granularity import (
PerRow,
PerGroup,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@@ -482,6 +492,104 @@ def quantize(self, model: nn.Module) -> nn.Module:
return model


_intx_granularity = Union[PerGroup, PerRow]


def int8_dynamic_activation_intx_weight_v2(
granularity: Optional[_intx_granularity] = PerGroup(32),
nbit: int = 4,
has_weight_zeros: bool = False,
target: str = "native",
mapping_type: MappingType = MappingType.ASYMMETRIC,
has_bias: bool = False,
):
from torchao.experimental._linear_8bit_act_xbit_weight_layout import (
Linear8BitActXBitWeightLayout,
)
from torchao.quantization.quant_api import (
ZeroPointDomain,
_get_linear_subclass_inserter,
to_affine_quantized_intx,
)

def get_quant_params(weight, has_weight_zeros: bool, mapping_type: MappingType, granularity: Optional[_intx_granularity]):
scale_dtype = None
zero_point_dtype = torch.int8
zero_point_domain = (
ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE
)
target_dtype = torch.int32
preserve_zero = has_weight_zeros
if mapping_type == MappingType.ASYMMETRIC:
pass
elif mapping_type == MappingType.SYMMETRIC:
assert (
TORCH_VERSION_AT_LEAST_2_6
), "MappingType.SYMMETRIC requires torch >= 2.6.0"
zero_point_dtype = torch.int8
zero_point_domain = ZeroPointDomain.NONE
preserve_zero = True
# The KleidiAI Groupwise kernel only supports bf16 scales for now
if torch.backends.kleidiai.is_available():
assert weight.dtype == torch.float32, f"Only float32 dtype is supported for KleidiAI int4 kernels. Provided {weight.dtype}"
if isinstance(granularity, PerGroup):
scale_dtype = torch.bfloat16
else:
raise ValueError(
f"Only mapping_type ASYMMETRIC, SYMMETRIC are supported. Provided {mapping_type}"
)

return target_dtype, zero_point_dtype, scale_dtype, preserve_zero, zero_point_domain

def apply(weight, bias: Optional[torch.Tensor] = None):
if isinstance(granularity, PerGroup):
group_size = granularity.group_size
elif isinstance(granularity, PerRow):
group_size = weight.shape[1]
else:
raise ValueError(
f"Only granularity PerGroup(), PerRow() are supported. Provided {granularity}"
)
assert weight.shape[-1] % group_size == 0
assert weight.device == torch.device("cpu"), "Only CPU is supported"
use_hqq = False
layout_args = [nbit, group_size, target]
if bias is not None:
layout_args.append(bias)
layout = Linear8BitActXBitWeightLayout(*layout_args)
# mapping_type = MappingType.ASYMMETRIC
eps = torch.finfo(torch.float32).eps
block_size = (1, group_size)
# target_dtype = torch.int32
quant_min = -(1 << (nbit - 1))
quant_max = (1 << (nbit - 1)) - 1
target_dtype, zero_point_dtype, scale_dtype, preserve_zero, zero_point_domain = get_quant_params(
weight, has_weight_zeros, mapping_type, granularity)
# Note: this works differently than other quantizers because the dynamic
# activation quantization is fused with the kernel/op (and static activation quantization
# is not supported).
return to_affine_quantized_intx(
weight,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=zero_point_domain,
_layout=layout,
use_hqq=use_hqq,
)

return _get_linear_subclass_inserter(
apply,
propagate_bias=has_bias
)


def int8_dynamic_activation_intx_weight(
group_size: int = 128,
nbit: int = 4,
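A hedged usage sketch of the new int8_dynamic_activation_intx_weight_v2 quantizer with the aten target. It assumes torch >= 2.6 on an Arm machine where torch.backends.kleidiai.is_available() returns True (otherwise the bias must be None), and that torchao's quantize_ entry point applies the returned subclass inserter to each nn.Linear; the model shape and settings are illustrative, not prescriptive.

import torch
import torch.nn as nn

from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight_v2
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import MappingType, quantize_

model = nn.Sequential(nn.Linear(256, 128, bias=True)).eval()

quantize_(
    model,
    int8_dynamic_activation_intx_weight_v2(
        granularity=PerGroup(32),            # group_size must divide in_features
        nbit=4,
        has_weight_zeros=False,              # aten target requires ZeroPointDomain.NONE
        target="aten",                       # packs via torch.ops.aten._dyn_quant_pack_4bit_weight
        mapping_type=MappingType.SYMMETRIC,
        has_bias=True,                       # forward lin.bias so it is packed with the weights
    ),
)

with torch.no_grad():
    out = model(torch.randn(4, 256))         # dynamic activation quantization is fused into the op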
7 changes: 5 additions & 2 deletions torchao/quantization/quant_api.py
@@ -450,15 +450,18 @@ def _linear_extra_repr(self):
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs):
def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs):
"""Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
to the weight of a linear module
"""

def insert_subclass(lin):
requires_grad = allow_requires_grad and lin.weight.requires_grad
args = [lin.weight]
if propagate_bias:
args.append(lin.bias)
lin.weight = torch.nn.Parameter(
constructor(lin.weight, **kwargs), requires_grad=requires_grad
constructor(*args, **kwargs), requires_grad=requires_grad
)
lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
return lin
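A toy sketch of the propagate_bias plumbing added above: with the flag set, insert_subclass passes both lin.weight and lin.bias to the constructor, which is how the v2 quantizer receives the bias for aten packing. The record_args constructor below is hypothetical and only reports what it was given.

import torch.nn as nn

from torchao.quantization.quant_api import _get_linear_subclass_inserter


def record_args(weight, bias=None):
    # Stand-in for a real quantizing constructor: report what it received and return the weight unchanged.
    print(f"weight shape: {tuple(weight.shape)}, received bias: {bias is not None}")
    return weight.detach().clone()


lin = nn.Linear(8, 4, bias=True)

# Default behaviour: the constructor only sees the weight.
_get_linear_subclass_inserter(record_args)(lin)                        # received bias: False

# With propagate_bias=True the bias is forwarded as a second positional argument.
_get_linear_subclass_inserter(record_args, propagate_bias=True)(lin)   # received bias: True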
