diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index d7cd517650..e22bcb1253 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,4 +1,5 @@
 from .nf4tensor import NF4Tensor, to_nf4
+
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uintx import UInt4Tensor
 from .affine_quantized_tensor import (
@@ -9,12 +10,11 @@
     to_affine_quantized_fpx,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
-    PlainAQTTensorImpl,
 )
-from .affine_quantized_tensor_ops import *
+
+from . import affine_quantized_tensor_ops
 from .utils import (
     Layout,
-    MarlinSparseLayout,
     PlainLayout,
 )
 from .floatx import (
@@ -22,10 +22,20 @@
     Float8AQTTensorImpl,
 )
 from .uintx import (
+    UintxTensor,
+    UintxLayout,
+    UintxAQTTensorImpl,
+    to_uintx,
+    _DTYPE_TO_BIT_WIDTH,
+    _BIT_WIDTH_TO_DTYPE,
+    UInt4Tensor,
     SemiSparseLayout,
     TensorCoreTiledLayout,
     MarlinSparseLayout,
+    PlainAQTTensorImpl,
+    BlockSparseLayout,
 )
+
 __all__ = [
     "NF4Tensor",
     "to_nf4",
@@ -43,4 +53,14 @@
     "Float8Layout",
     "Float8AQTTensorImpl",
     "MarlinSparseLayout",
+    "PlainAQTTensorImpl",
+    "affine_quantized_tensor_ops",
+    "BlockSparseLayout",
+    "to_uintx",
+    "UintxTensor",
+    "UintxLayout",
+    "UintxAQTTensorImpl",
+    "_DTYPE_TO_BIT_WIDTH",
+    "_BIT_WIDTH_TO_DTYPE",
+    "UInt4Tensor",
 ]
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index cb2527076d..54c0c5a9c7 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1,10 +1,8 @@
-from dataclasses import dataclass
 import logging
 import math
 from typing import Optional, Tuple, Union
 import torch
-from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.dtypes.utils import Layout, PlainLayout
 from torchao.quantization.quant_primitives import (
     FP8_TYPES,
@@ -29,6 +27,7 @@
 logger = logging.getLogger(__name__)
 aten = torch.ops.aten
+
 ##############################
 # Tensor Subclass Definition #
 ##############################
@@ -445,151 +444,6 @@ def _apply_fn_to_data(self, fn):
 register_layout = AffineQuantizedTensor.register_layout
 get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor
-
-@register_layout(PlainLayout)
-class PlainAQTTensorImpl(AQTTensorImpl):
-    """
-    TensorImpl for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
-    tensors directly as plain tensors. 
- - fields: - int_data (torch.Tensor): the quantized integer data Tensor - scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor - zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - """ - - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data, scale, zero_point = ( - tensor_data_dict["int_data"], - tensor_data_dict["scale"], - tensor_data_dict["zero_point"], - ) - (_layout,) = tensor_attributes - return cls(int_data, scale, zero_point, _layout) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), - self._layout, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.scale), - fn(self.zero_point), - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - elif func is aten.t.default: - tensor = args[0] - new = tensor.__class__( - tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - elif func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: aten.slice.Tensor(x, dim, start, end, step) - ), - ) - elif dim == 1: - assert ( - len(self.scale.shape) == 1 - ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTTensorImpl( - aten.slice.Tensor(self.int_data, dim, start, end, step), - self.scale.view(-1), - self.zero_point.view(-1), - self._layout, - ) - else: - raise NotImplementedError( - f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - - raise NotImplementedError( - f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.int_data, self.scale, self.zero_point - - def get_layout(self) -> Layout: - return self._layout - - @classmethod - def 
from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, PlainLayout) - return cls(int_data, scale, zero_point, _layout) - - ##################################################### # torch functional and aten operator implementation # ##################################################### diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index a3d995d653..ea62a77065 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -1,9 +1,12 @@ import logging -from typing import Optional, Tuple import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.affine_quantized_tensor import * +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + dequantize_affine, +) +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl # from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, @@ -31,7 +34,7 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) -from torchao.dtypes.uintx.uint8_layout import ( +from torchao.dtypes.uintx.plain_layout import ( _linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl, _linear_fp_act_int8_weight_check, @@ -71,11 +74,14 @@ def deregister_aqt_quantized_linear_dispatch(dispatch_condition): if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] else: - logger.warn(f"Attempting to remove non-existant dispatch condition {dispatch_condition}") + logger.warn( + f"Attempting to remove non-existant dispatch condition {dispatch_condition}" + ) class QuantizedLinearNotImplementedError(NotImplementedError): - """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """ + """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table""" + pass @@ -84,14 +90,15 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): if dispatch_condition(input_tensor, weight_tensor, bias): return impl(input_tensor, weight_tensor, bias) - raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") + raise QuantizedLinearNotImplementedError( + "No specialized dispatch found for quantized linear op" + ) # Attach the _quantized_linear_op to the AffineQuantizedTensor class AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op - # # following are a list of (dispatch_condition, implementation) functions that takes the following args: # # input_tensor: dimension is (M1, M2, ..., in_features) # # weight_tensor: dimension is (out_features, in_features) @@ -100,14 +107,26 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), - (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), - (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), + ( + _linear_int8_act_int8_weight_semi_structured_sparse_check, + _linear_int8_act_int8_weight_semi_structured_sparse_impl, + ), + 
( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, + ), (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl), - (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), + ( + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, + ), + ( + _linear_fp_act_int4_weight_sparse_marlin_check, + _linear_fp_act_int4_weight_sparse_marlin_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) @@ -125,7 +144,9 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -134,7 +155,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -148,19 +173,31 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): # new_arg1 = args[1].dequantize() # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) - assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" - assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"]==2.0 + assert isinstance( + args[1].tensor_impl, PlainAQTTensorImpl + ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" + assert ( + kwargs["padding_idx"] is None + and kwargs["max_norm"] is None + and not kwargs["scale_grad_by_freq"] + and not kwargs["sparse"] + and kwargs["norm_type"] == 2.0 + ) idx = args[0] int_data, scale, zero_point = args[1].tensor_impl.get_plain() - - sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx] + + sliced_data, sliced_scale, sliced_zero_point = ( + int_data[idx], + scale[idx], + zero_point[idx], + ) # Block size is expecting 2 dimensions [1, group size] but - # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so + # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so # we need to increase block size to correct dim - new_blocks = idx.dim()-1 + new_blocks = idx.dim() - 1 return 
dequantize_affine( sliced_data, - new_blocks*[1]+list(args[1].block_size), + new_blocks * [1] + list(args[1].block_size), sliced_scale, sliced_zero_point, sliced_data.dtype, @@ -179,7 +216,9 @@ def _(func, types, args, kwargs): args[0], ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -189,7 +228,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -201,20 +244,22 @@ def _(func, types, args, kwargs): @implements(aten.mm.default) def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - None - ) + input_tensor, weight_tensor, bias = (args[0], args[1], None) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) try: weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -256,7 +301,14 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + tensor.tensor_impl.t(), + transposed_block_size, + shape, + tensor.quant_min, + tensor.quant_max, + tensor.zero_point_domain, + dtype=tensor.dtype, + strides=tensor.stride(), ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -271,11 +323,22 @@ def _(func, types, args, kwargs): shape = list(self.shape) shape[dim] = end - start block_size = self.block_size - assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" + assert ( + len(block_size) == 2 + ), f"Slice only works for 2d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure 
there is no overflow block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + new = self.__class__( + aten.slice.Tensor(self.tensor_impl, dim, start, end, step), + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -285,11 +348,31 @@ def _(func, types, args, kwargs): self, shape = args if tuple(self.shape) == tuple(shape): - return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__( + self.tensor_impl, + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) if len(shape) == 1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 block_size = (self.block_size[1],) - return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) - - raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") + return self.__class__( + self.tensor_impl, + block_size, + (self.numel(),), + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) + + raise ValueError( + f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]" + ) diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 34b7ec1f91..ddfa9e3669 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,2 +1,18 @@ -from .floatx_tensor_core_layout import FloatxTensorCoreLayout, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx_tensor_core_layout import ( + FloatxTensorCoreLayout, + FloatxTensorCoreAQTTensorImpl, + to_scaled_tc_floatx, + from_scaled_tc_floatx, + _SPLIT_K_MAP, +) from .float8_layout import Float8AQTTensorImpl, Float8Layout + +__all__ = [ + "FloatxTensorCoreLayout", + "FloatxTensorCoreAQTTensorImpl", + "to_scaled_tc_floatx", + "from_scaled_tc_floatx", + "_SPLIT_K_MAP", + "Float8AQTTensorImpl", + "Float8Layout", +] diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 1c3c046497..bf3f96dca3 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -1,9 +1,9 @@ import torch -from torchao.utils import _is_float8_type -from torchao.dtypes.utils import Layout, AQTTensorImpl +from torchao.utils import _is_float8_type, fill_defaults +from torchao.dtypes.utils import Layout, AQTTensorImpl, get_out_shape from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, - register_layout + register_layout, ) from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -11,12 +11,13 @@ preprocess_data, Float8MMConfig, addmm_float8_unwrapped_inference, - _is_rowwise_scaled + _is_rowwise_scaled, ) from torch.utils._python_dispatch import ( return_and_correct_aliasing, is_traceable_wrapper_subclass, ) + aten = torch.ops.aten @@ -24,6 +25,7 @@ class Float8Layout(Layout): mm_config: Optional[Float8MMConfig] = 
None + @register_layout(Float8Layout) class Float8AQTTensorImpl(AQTTensorImpl): """ @@ -32,6 +34,7 @@ class Float8AQTTensorImpl(AQTTensorImpl): Note: technically we should not create a new layout for float8 we should merge this into plain layout """ + float8_data: torch.Tensor scale: torch.Tensor transposed: bool @@ -66,7 +69,7 @@ def __init__( self._layout = _layout def _apply_fn_to_data(self, fn): - """ Applys a fn to all tensor components stored on this class""" + """Applys a fn to all tensor components stored on this class""" return self.__class__( fn(self.float8_data), fn(self.scale), @@ -91,7 +94,10 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - transposed, _layout, = tensor_attributes + ( + transposed, + _layout, + ) = tensor_attributes return cls(float8_data, scale, transposed, _layout) @classmethod @@ -115,23 +121,50 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: - #TODO: scale replecation should be dependent on block size + # TODO: scale replecation should be dependent on block size if self.scale.ndim == 1: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), ) elif self.scale.ndim == 0: return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor(self.float8_data, dim, start, end, step), + self.scale, + None, + self._layout, + ), ) else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported") + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported" + ) elif dim == 1: return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout) + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor( + self.float8_data, dim, start, end, step + ).contiguous(), + self.scale, + None, + self._layout, + ), ) else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) else: raise NotImplementedError( f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -153,42 +186,50 @@ def from_plain( zero_point: Optional[torch.Tensor], _layout: Layout, ): - """ Main entrypoint for constructing Float8TensorImpl""" - assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(_layout, Float8Layout), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" + """Main entrypoint for constructing Float8TensorImpl""" + assert _is_float8_type( + data.dtype + ), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" + assert isinstance( + _layout, Float8Layout 
+ ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" return cls(data, scale, False, _layout) def __repr__(self): float8_data, scale, _ = self.get_plain() _layout = self.get_layout() - return (f"{self.__class__.__name__}(\n" - f"float8_data={float8_data},\n" - f"scale={scale},\n" - f"transposed={self.transposed}, " - f"_layout={_layout})") + return ( + f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"_layout={_layout})" + ) ########################## # Float8 Dispatch Kernels ########################## + def _linear_fp8_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], - weight_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], bias: Optional[torch.Tensor], ) -> bool: def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( - isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt._layout, Float8Layout) + isinstance(aqt, AffineQuantizedTensor) + and isinstance(aqt._layout, Float8Layout) and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) + return check_aqt(input_tensor) and check_aqt(weight_tensor) def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): - """ Ensures input tensor is correctly formated for _scaled_mm """ + """Ensures input tensor is correctly formated for _scaled_mm""" input_scale = input_scale.unsqueeze(-1) if input_scale.dim() > 2: @@ -196,9 +237,10 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): return input_scale + def _linear_fp8_act_fp8_weight_impl( - input_tensor: 'AffineQuantizedTensor', - weight_tensor: 'AffineQuantizedTensor', + input_tensor: "AffineQuantizedTensor", + weight_tensor: "AffineQuantizedTensor", bias: Optional[torch.Tensor], ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" @@ -219,7 +261,9 @@ def _linear_fp8_act_fp8_weight_impl( # Handle rowwise case if _is_rowwise_scaled(weight_tensor): - assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size" + assert _is_rowwise_scaled( + input_tensor + ), "Input tensor must be rowwise block size" w_scale = w_scale.unsqueeze(-1).T input_scale = preprocess_scale(input_scale, input_tensor.shape) @@ -237,25 +281,31 @@ def _linear_fp8_act_fp8_weight_impl( use_fast_accum=scaled_mm_config.use_fast_accum, ).reshape(out_shape) + def _linear_fp_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], - weight_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], bias: Optional[torch.Tensor], ) -> bool: return ( # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and # weight is float8 quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, Float8Layout) + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, Float8Layout) and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (weight_tensor.shape == 
weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) + and ( + weight_tensor.shape == weight_tensor.block_size + or _is_rowwise_scaled(weight_tensor) + ) ) + def _linear_fp_act_fp8_weight_impl( input_tensor: torch.Tensor, - weight_tensor: 'AffineQuantizedTensor', + weight_tensor: "AffineQuantizedTensor", bias: Optional[torch.Tensor], ): return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index cfdb566279..b23010878e 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -7,7 +7,11 @@ return_and_correct_aliasing, is_traceable_wrapper_subclass, ) -from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, + _n_ones, +) from torchao.dtypes.utils import ( Layout, AQTTensorImpl, @@ -24,11 +28,23 @@ def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) + return reduce( + torch.bitwise_or, + [ + x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ], + ) def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) + return torch.stack( + [ + (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits) + ], + dim=-1, + ).flatten(-2) # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 @@ -42,8 +58,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: if not undo: bit_order = { - 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, - 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 1: [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 3, + 7, + 11, + 15, + 19, + 23, + 27, + 31, + 0, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + ], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] @@ -52,8 +100,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is inverse of the above, obtained by running # [v.index(i) for i in range(len(v))] bit_order = { - 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, - 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], + 1: [ + 16, + 0, + 24, + 8, + 17, + 1, + 25, + 9, + 18, + 2, + 26, + 10, + 19, + 3, + 27, + 11, + 20, + 4, + 28, + 12, + 21, + 5, + 29, + 13, + 22, + 6, + 30, + 14, + 23, + 7, + 31, + 15, + ], 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], 4: [4, 0, 6, 2, 5, 1, 7, 3], }[n_bits] @@ -89,8 +169,12 @@ def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) - tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code - tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code + tensor_ybit = ( + tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + ) # Pass 2 from original code + tensor_ybit = _bit_interleave( + tensor_ybit.flatten(), y + ) # Pass 3 from original code 
fragments.append(tensor_ybit) used_bits += y @@ -124,7 +208,9 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: return _pack_tc_floatx(tensor, nbits) -def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: +def to_scaled_tc_floatx( + tensor: Tensor, ebits: int, mbits: int +) -> Tuple[Tensor, Tensor]: # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -132,7 +218,9 @@ def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, # workaround: global lookup table exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( + _ONES_TABLE[mbits + 1] / (2**mbits) + ) dtype = tensor.dtype tensor = tensor.float() @@ -159,8 +247,10 @@ def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = ( + tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + ) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) tensor_ybit = tensor_ybit << (nbits - used_bits - y) @@ -231,7 +321,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 7 + 57344: 7, }, { # tokens: [65:128] 3072: 9, @@ -242,7 +332,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 6 + 57344: 6, }, { # tokens: [129:192] 3072: 6, @@ -253,7 +343,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 5, 28672: 5, - 57344: 4 + 57344: 4, }, { # tokens: [193:256] 3072: 9, @@ -264,7 +354,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 4, 14336: 8, 28672: 6, - 57344: 4 + 57344: 4, }, { # tokens: [257:320] 3072: 7, @@ -275,7 +365,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 3, 28672: 3, - 57344: 4 + 57344: 4, }, { # tokens: [321:384] 3072: 3, @@ -286,7 +376,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 8, 14336: 3, 28672: 4, - 57344: 3 + 57344: 3, }, { # tokens: [385:448] 3072: 5, @@ -297,7 +387,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 1, 28672: 1, - 57344: 3 + 57344: 3, }, { # tokens: [449:512] 3072: 2, @@ -308,7 +398,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 6, 28672: 4, - 57344: 1 + 57344: 1, }, { # tokens: [513:576] 3072: 2, @@ -319,7 +409,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 3, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [577:640] 3072: 5, @@ -330,7 +420,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [641:704] 3072: 3, @@ -341,7 +431,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [705:768] 3072: 3, 
@@ -352,17 +442,18 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 - } + 57344: 1, + }, ] # quantization api integrations + @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl - """ + """Layout type for FloatxTensorCoreAQTTensorImpl""" + ebits: int mbits: int @@ -390,6 +481,7 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ + def __new__( cls, packed_floatx_data: torch.Tensor, @@ -398,11 +490,16 @@ def __new__( ): assert packed_floatx_data.ndim == 2 assert packed_floatx_data.dtype == torch.uint8 - shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8) + shape = ( + packed_floatx_data.shape[0], + packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, + ) kwargs = {} kwargs["device"] = packed_floatx_data.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_floatx_data.layout ) kwargs["dtype"] = packed_floatx_data.dtype kwargs["requires_grad"] = False @@ -425,12 +522,17 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"] - _layout, = tensor_attributes + packed_floatx_data, scale = ( + tensor_data_dict["packed_floatx_data"], + tensor_data_dict["scale"], + ) + (_layout,) = tensor_attributes return cls(packed_floatx_data, scale, _layout) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits) + unpacked_floatx_data = unpack_tc_floatx( + self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits + ) return unpacked_floatx_data, self.scale @classmethod @@ -449,7 +551,9 @@ def from_plain( bit, M is mantissa bit """ assert isinstance(_layout, FloatxTensorCoreLayout) - packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits) + packed_floatx_data = pack_tc_floatx( + unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits + ) return cls(packed_floatx_data, scale, _layout) def __repr__(self): @@ -487,7 +591,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif func is aten._to_copy.default: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: x.to(device=kwargs.pop("device", None)) + ), ) raise NotImplementedError( @@ -502,28 +611,28 @@ def get_layout(self) -> Layout: def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import FloatxTensorCoreLayout + return ( # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - input_tensor.dtype in (torch.float16, torch.bfloat16) and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and # weight is floatx Tensor - 
isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and - ( + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) + and ( # weight is using fp6 quantization - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 3) or + (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) + or # weight is using fp5 quantization - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 1) + (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) ) ) + def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 7b2c4e9028..e9eca3a011 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,30 @@ -from .uintx import UintxTensor, UintxLayout, UintxAQTTensorImpl, to_uintx, _DTYPE_TO_BIT_WIDTH +from .uintx_layout import ( + UintxTensor, + UintxLayout, + UintxAQTTensorImpl, + to_uintx, + _DTYPE_TO_BIT_WIDTH, + _BIT_WIDTH_TO_DTYPE, +) from .uint4 import UInt4Tensor from .block_sparse_layout import BlockSparseLayout from .semi_sparse_layout import SemiSparseLayout from .marlin_sparse_layout import MarlinSparseLayout from .tensor_core_tiled_layout import TensorCoreTiledLayout +from .plain_layout import PlainAQTTensorImpl + + +__all__ = [ + "UintxTensor", + "UintxLayout", + "UintxAQTTensorImpl", + "to_uintx", + "UInt4Tensor", + "BlockSparseLayout", + "SemiSparseLayout", + "MarlinSparseLayout", + "TensorCoreTiledLayout", + "_DTYPE_TO_BIT_WIDTH", + "_BIT_WIDTH_TO_DTYPE", + "PlainAQTTensorImpl", +] diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 4f6358fae5..8355149cf1 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -13,18 +13,20 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, register_layout, - PlainAQTTensorImpl ) -from torchao.dtypes.uintx.uint8 import _aqt_is_int8_reduced_range +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl +from torchao.dtypes.uintx.plain_layout import _aqt_is_int8_reduced_range logger = logging.getLogger(__name__) aten = torch.ops.aten + @dataclass(frozen=True) class BlockSparseLayout(Layout): blocksize: int = 64 + @register_layout(BlockSparseLayout) class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): bsr_crow_indices: Optional[torch.Tensor] @@ -33,7 +35,13 @@ class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): scale: Optional[torch.Tensor] zero_point: Optional[torch.Tensor] - __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] + __slots__ = [ + "bsr_crow_indices", + "bsr_col_indices", + "bsr_values", + "scale", + "zero_point", + ] @staticmethod def __new__( # noqa: PYI034 @@ -115,17 +123,23 @@ def from_plain(cls, int_data, scale, zero_point, _layout): bsr_values=bsr_tensor.values(), scale=scale, zero_point=zero_point, - _layout = _layout, + _layout=_layout, 
requires_grad=False, ) def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense(self.crow_indices(), self.col_indices(), self.values(), self.shape[0], self.shape[1]) + int_data_expanded = torch.ops.blocksparse.bsr_to_dense( + self.crow_indices(), + self.col_indices(), + self.values(), + self.shape[0], + self.shape[1], + ) return int_data_expanded, self.scale, self.zero_point def _apply_fn_to_data(self, func): return self.__class__( - shape = self.shape, + shape=self.shape, bsr_crow_indices=func(self.bsr_crow_indices), bsr_col_indices=func(self.bsr_col_indices), bsr_values=func(self.bsr_values), @@ -166,16 +180,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) - def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, BlockSparseLayout) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, BlockSparseLayout) ) @@ -187,12 +200,14 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) tmp_t = tmp.t() - y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1)) + y = torch.ops.blocksparse.int_addmm( + w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1), + ) y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) y = y.reshape(*y_shape) diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index 5483e4d7a3..cac2c70f5c 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -2,24 +2,33 @@ from torchao.dtypes.utils import Layout, AQTTensorImpl from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, - register_layout + register_layout, ) import torch from torchao.dtypes.uintx.tensor_core_tiled_layout import _aqt_is_tensor_core_tile_uint4 +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, +) + +aten = torch.ops.aten def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): return ( - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - input_tensor.dtype == torch.float16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, MarlinSparseLayout) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and input_tensor.dtype == torch.float16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, MarlinSparseLayout) ) + def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): - from 
torchao.sparsity.marlin import marlin_24_workspace, const + from torchao.sparsity.marlin import marlin_24_workspace from torchao.ops import marlin_24_gemm assert isinstance(weight_tensor, AffineQuantizedTensor) @@ -39,8 +48,15 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b workspace_24 = marlin_24_workspace(original_shape[1]) out = marlin_24_gemm( - input_2d, sparse_w_int4, meta, scale, - workspace_24, num_bits, size_m, size_n, size_k + input_2d, + sparse_w_int4, + meta, + scale, + workspace_24, + num_bits, + size_m, + size_n, + size_k, ) # Unfold the batch dimension @@ -50,9 +66,9 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b out += bias.to(out.dtype) return out + @dataclass(frozen=True) class MarlinSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. - 1ยบ: the input tensor is transposed since the linear layer keeps the weights in a transposed format @@ -66,10 +82,12 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: the preprocessed tensor """ from torchao.sparsity.marlin import inject_24 # avoid circular import + input_t = input.t() w_24, _ = inject_24(input_t, *input_t.shape) return w_24.t() + @register_layout(MarlinSparseLayout) class MarlinSparseAQTTensorImpl(AQTTensorImpl): """ @@ -88,6 +106,7 @@ class MarlinSparseAQTTensorImpl(AQTTensorImpl): group_size (int): the group size used to pack the tensor num_bits (int): the number of bits used to quantize the tensor """ + @staticmethod def __new__( cls, @@ -144,7 +163,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [self._layout, self.original_shape, self.group_size, self.num_bits] + return ["int_data", "scale", "zero_point", "meta"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] @classmethod def __tensor_unflatten__( @@ -155,10 +179,22 @@ def __tensor_unflatten__( zero_point = tensor_data_dict["zero_point"] meta = tensor_data_dict["meta"] _layout, original_shape, group_size, num_bits = tensor_attributes - return cls(int_data, scale, zero_point, meta, _layout, original_shape, group_size, num_bits) + return cls( + int_data, + scale, + zero_point, + meta, + _layout, + original_shape, + group_size, + num_bits, + ) def get_plain(self): - from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import + from torchao.sparsity.marlin import ( + unpack_from_marlin_24, + ) # avoid circular import + int_data_expanded, scales_expanded = unpack_from_marlin_24( self.int_data, self.scale, @@ -179,7 +215,11 @@ def from_plain( zero_point: torch.Tensor, _layout: Layout, ): - from torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import + from torchao.sparsity.marlin import ( + pack_to_marlin_24, + const, + ) # avoid circular import + assert isinstance(_layout, MarlinSparseLayout) # Linear layers are (in_features, out_features) but the int_data that is reaching this point @@ -189,7 +229,7 @@ def from_plain( if not torch.cuda.get_device_capability()[0] >= 8: raise ValueError( - f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.' 
+ f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." ) if q_w_24.dtype != torch.int32: @@ -206,14 +246,14 @@ def from_plain( # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main num_bits = 4 if torch.max(q_w_24) < 16 else -1 if num_bits not in [4]: - raise ValueError( - f"Only {[4]} bits are supported, got {num_bits}." - ) + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") group_size = in_features // scale_t.shape[0] if group_size == 0: group_size = in_features - assert group_size <= in_features, "Group size must be less than or equal to in_features." + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." if group_size not in const.SUPPORTED_GROUP_SIZES: raise ValueError( @@ -221,12 +261,19 @@ def from_plain( ) # Compress quantized weight to marlin 2:4 format - marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( + q_w_24, scale_t, num_bits, group_size + ) return cls( - marlin_24_q_w_comp, marlin_24_s, zero_point, - meta, _layout, q_w_24.shape, - group_size, num_bits + marlin_24_q_w_comp, + marlin_24_s, + zero_point, + meta, + _layout, + q_w_24.shape, + group_size, + num_bits, ) def get_layout(self) -> Layout: diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py new file mode 100644 index 0000000000..9ce448edaa --- /dev/null +++ b/torchao/dtypes/uintx/plain_layout.py @@ -0,0 +1,264 @@ +import torch +from torchao.dtypes.utils import PlainLayout, AQTTensorImpl, Layout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.utils import fill_defaults +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) +from torchao.quantization.quant_primitives import ( + int_scaled_matmul, + ZeroPointDomain, +) +from typing import Optional, Tuple + +aten = torch.ops.aten + + +@register_layout(PlainLayout) +class PlainAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for plain layout for affine quantized tensor, it stores int_data, scale, zero_point + tensors directly as plain tensors. 
+ + fields: + int_data (torch.Tensor): the quantized integer data Tensor + scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor + zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data, scale, zero_point = ( + tensor_data_dict["int_data"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + (_layout,) = tensor_attributes + return cls(int_data, scale, zero_point, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.int_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.zero_point.to(kwargs["device"]), + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.int_data), + fn(self.scale), + fn(self.zero_point), + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + elif func is aten.t.default: + tensor = args[0] + new = tensor.__class__( + tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + elif func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), + ) + elif dim == 1: + assert ( + len(self.scale.shape) == 1 + ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" + return PlainAQTTensorImpl( + aten.slice.Tensor(self.int_data, dim, start, end, step), + self.scale.view(-1), + self.zero_point.view(-1), + self._layout, + ) + else: + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.int_data, self.scale, self.zero_point + + def get_layout(self) -> Layout: + return self._layout + + @classmethod + def 
from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, PlainLayout) + return cls(int_data, scale, zero_point, _layout) + + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and (aqt.quant_min is None or aqt.quant_min == -128) + and (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _aqt_is_int8_reduced_range(aqt): + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + # input is native float tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and + # weight is int8 per channel quantized affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int8(weight_tensor) + and len(weight_tensor.shape) == 2 + and len(weight_tensor.block_size) == 2 + and weight_tensor.block_size[0] == 1 + and weight_tensor.block_size[1] == weight_tensor.shape[1] + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # TODO: enable cpu and mps efficient path + # is_cpu and is_mps only, some issue with is_contiguous() currently + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) + + # per channel int8 weight only quantizated mm + w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() + scale = weight_tensor.tensor_impl.scale + m = torch.mm( + input_tensor.reshape(-1, input_tensor.shape[-1]), + w_vals_int8_t.to(input_tensor.dtype), + ) + y = m * scale.to(m.dtype) + y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias.to(m.dtype) + return y + + +def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. 
rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast fp16 scale to float to avoid overflow in int_scaled_matmul + intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype + y_dot_scaled = int_scaled_matmul( + tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + ) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index 31252701b5..1ac66d4fb2 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -3,25 +3,35 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, register_layout, - PlainAQTTensorImpl ) import torch from typing import Optional -from torchao.dtypes.uintx.uint8 import _aqt_is_int8_reduced_range +from torchao.dtypes.uintx.plain_layout import _aqt_is_int8_reduced_range +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl + +aten = torch.ops.aten -def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weight_tensor, bias): + +def _linear_int8_act_int8_weight_semi_structured_sparse_check( + input_tensor, weight_tensor, bias +): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, SemiSparseLayout) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, SemiSparseLayout) ) -def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): +def _linear_int8_act_int8_weight_semi_structured_sparse_impl( + input_tensor, weight_tensor, bias +): x_vals_int8 = input_tensor.tensor_impl.int_data x_scales = input_tensor.tensor_impl.scale w_vals_int8 = weight_tensor.tensor_impl.int_data @@ -29,7 +39,10 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, + w_vals_int8, + tmp.t(), + alpha=w_scales.to(torch.float32), + 
out_dtype=torch.bfloat16, ).t() y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] @@ -41,9 +54,9 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh y += bias return y + @dataclass(frozen=True) class SemiSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already temp = input.detach() @@ -52,12 +65,12 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return temp - @register_layout(SemiSparseLayout) class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): """ TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor """ + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -75,10 +88,10 @@ def get_plain(self): # Currently we don't have cuSPARSELt expansion routines, so we matmul by # the identity matrix to get the original dense matrix. This is slow though. cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) - int_data_expanded = torch._cslt_sparse_mm(self.int_data, - torch.eye(cols, - dtype=self.int_data.dtype, - device=self.int_data.device).t()) + int_data_expanded = torch._cslt_sparse_mm( + self.int_data, + torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t(), + ) return int_data_expanded, self.scale, self.zero_point @classmethod diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 1f6bb92179..f6dfb9a4d2 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -1,45 +1,51 @@ import torch -from torchao.utils import find_multiple, TORCH_VERSION_AT_LEAST_2_5 -from torchao.dtypes.utils import Layout, AQTTensorImpl -from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor, register_layout +from torchao.utils import find_multiple, TORCH_VERSION_AT_LEAST_2_5, fill_defaults +from torchao.dtypes.utils import Layout, AQTTensorImpl, is_device +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) from dataclasses import dataclass from typing import Optional, Tuple from torch.utils._python_dispatch import ( return_and_correct_aliasing, is_traceable_wrapper_subclass, ) -from torchao.quantization.quant_primitives import ( - ZeroPointDomain -) +from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params aten = torch.ops.aten + def _aqt_is_tensor_core_tile_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.tensor_impl.dtype == torch.int32 and - aqt.quant_min == 0 and - aqt.quant_max == 15 + aqt.tensor_impl.dtype == torch.int32 + and aqt.quant_min == 0 + and aqt.quant_max == 15 ) + def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.dtype == torch.bfloat16 and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.dtype == torch.bfloat16 + and # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - weight_tensor.dtype == torch.bfloat16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and - 
isinstance(weight_tensor._layout, TensorCoreTiledLayout) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and weight_tensor.dtype == torch.bfloat16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and isinstance(weight_tensor._layout, TensorCoreTiledLayout) ) def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): - assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" assert input_tensor.shape[-1] == weight_tensor.shape[1], ( f"need input_tensor shape: {input_tensor.shape} final" f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " @@ -63,24 +69,27 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero) + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] y = y.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: y += bias return y.to(orig_dtype) + @dataclass(frozen=True) class TensorCoreTiledLayout(Layout): """ inner_k_tiles is an internal argument for packing function of tensor core tiled layout that can affect the performance of the matmul kernel """ + inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: @@ -93,14 +102,25 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: ) return input - def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, block_size: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def pre_process_static( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input = self.pre_process(input) orig_qparam_shape = scale.shape - new_qparam_shape, reduction_dims = _get_reduction_params(block_size, input.size()) + new_qparam_shape, reduction_dims = _get_reduction_params( + block_size, input.size() + ) for dim in reduction_dims: new_qparam_shape.pop(dim) - change_in_qparam_shape = [new_dim_size-orig_dim_size for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape)] - padding_changes=[] + change_in_qparam_shape = [ + new_dim_size - orig_dim_size + for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape) + ] + padding_changes = [] for dim_change in change_in_qparam_shape: padding_changes = [0, dim_change] + padding_changes scale = torch.nn.functional.pad(scale, padding_changes) @@ -155,7 +175,9 @@ def __new__( kwargs = {} kwargs["device"] = packed_weight.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_weight.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout ) kwargs["dtype"] = packed_weight.dtype kwargs["requires_grad"] = False @@ -181,8 +203,14 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_weight, scale_and_zero = 
tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] - transposed, _layout, = tensor_attributes + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + _layout, + ) = tensor_attributes return cls(packed_weight, scale_and_zero, transposed, _layout) @classmethod @@ -191,20 +219,26 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - _layout: Layout + _layout: Layout, ): - assert isinstance(_layout, TensorCoreTiledLayout) if TORCH_VERSION_AT_LEAST_2_5: int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" else: - assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, _layout.inner_k_tiles) + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) return cls(packed_weight, scale_and_zero, False, _layout) @@ -215,7 +249,9 @@ def to(self, *args, **kwargs): # between these two devices, in the future we should not use the same layout for # cpu and cuda device: https://github.com/pytorch/ao/issues/1117 if not is_device(torch.device(self.device).type, device): - raise ValueError(f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}") + raise ValueError( + f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}" + ) return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -252,7 +288,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose """ - transposed = TensorCoreTiledAQTTensorImpl(args[0].packed_weight, args[0].scale_and_zero, not args[0].transposed, args[0]._layout) + transposed = TensorCoreTiledAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + ) return return_and_correct_aliasing(func, args, kwargs, transposed) if func is aten.slice.Tensor: @@ -277,11 +318,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs): # this is to handle padding int_data = self._layout.post_process(int_data) scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - zero_point = aten.slice.Tensor(zero_point, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) sliced = self.from_plain(int_data, scale, zero_point, self._layout) return sliced else: - raise NotImplementedError(f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, 
that is not supported" + ) raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -295,6 +340,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quantize_affine, ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape @@ -311,12 +357,26 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) dequantized = dequantized.t().contiguous() # TODO: move this to `unpack_tinygemm_scales_and_zeros`? scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) return int_data, scale, zero def get_layout(self) -> Layout: diff --git a/torchao/dtypes/uintx/uint8.py b/torchao/dtypes/uintx/uint8.py deleted file mode 100644 index 8d53e93e74..0000000000 --- a/torchao/dtypes/uintx/uint8.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -from torchao.dtypes.utils import PlainLayout -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, -) -from dataclasses import dataclass -from typing import Optional, Tuple, Union -from torchao.float8.inference import ( - preprocess_data, - Float8MMConfig, - addmm_float8_unwrapped_inference, - _is_rowwise_scaled -) -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, - is_traceable_wrapper_subclass, -) - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 and - (aqt.quant_min is None or aqt.quant_min == -128) and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.tensor_impl.dtype == torch.int8 and - aqt.quant_min == -127 and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - - -def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_int8(weight_tensor) and - len(weight_tensor.shape) == 2 and - len(weight_tensor.block_size) == 2 and - weight_tensor.block_size[0] == 1 and - weight_tensor.block_size[1] == weight_tensor.shape[1] and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, PlainLayout) - ) - - -def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # TODO: enable cpu and mps efficient path - # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) 
- - # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() - scale = weight_tensor.tensor_impl.scale - orig_dtype = input_tensor.dtype - m = torch.mm( - input_tensor.reshape(-1, input_tensor.shape[-1]), - w_vals_int8_t.to(input_tensor.dtype), - ) - y = m * scale.to(m.dtype) - y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) - if bias is not None: - y += bias.to(m.dtype) - return y - - -def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, PlainLayout) - ) - -def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # - # 1. do the matrix form of dot(X_i, W_j) - # - # - # 2. rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype - # Cast fp16 scale to float to avoid overflow in int_scaled_matmul - intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype - y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - - y = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y diff --git a/torchao/dtypes/uintx/uint8_layout.py b/torchao/dtypes/uintx/uint8_layout.py deleted file mode 100644 index 6be1bc25ee..0000000000 --- a/torchao/dtypes/uintx/uint8_layout.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -from torchao.dtypes.utils import PlainLayout -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, -) -from dataclasses import dataclass -from typing import Optional, Tuple, Union -from torchao.float8.inference import ( - preprocess_data, - Float8MMConfig, - addmm_float8_unwrapped_inference, - _is_rowwise_scaled -) -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, - is_traceable_wrapper_subclass, -) -from torchao.quantization.quant_primitives import ( - int_scaled_matmul, - ZeroPointDomain, -) - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 and - (aqt.quant_min is None or aqt.quant_min == -128) and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.tensor_impl.dtype == torch.int8 and - aqt.quant_min == -127 and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - - -def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - 
input_tensor.is_floating_point() and - # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_int8(weight_tensor) and - len(weight_tensor.shape) == 2 and - len(weight_tensor.block_size) == 2 and - weight_tensor.block_size[0] == 1 and - weight_tensor.block_size[1] == weight_tensor.shape[1] and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, PlainLayout) - ) - - -def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # TODO: enable cpu and mps efficient path - # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) - - # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() - scale = weight_tensor.tensor_impl.scale - orig_dtype = input_tensor.dtype - m = torch.mm( - input_tensor.reshape(-1, input_tensor.shape[-1]), - w_vals_int8_t.to(input_tensor.dtype), - ) - y = m * scale.to(m.dtype) - y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) - if bias is not None: - y += bias.to(m.dtype) - return y - - -def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, PlainLayout) - ) - -def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # - # 1. do the matrix form of dot(X_i, W_j) - # - # - # 2. rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype - # Cast fp16 scale to float to avoid overflow in int_scaled_matmul - intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype - y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - - y = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx_layout.py similarity index 97% rename from torchao/dtypes/uintx/uintx.py rename to torchao/dtypes/uintx/uintx_layout.py index b47862a7e1..11bf2f88c9 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -3,16 +3,13 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout +from .bitpacking import pack, unpack +from torchao.dtypes.affine_quantized_tensor import register_layout +from 
torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl from torchao.dtypes.utils import ( Layout, ) from torchao.utils import TorchAOBaseTensor -from torchao.dtypes.affine_quantized_tensor import ( - register_layout, - PlainAQTTensorImpl -) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 aten = torch.ops.aten diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 1704fdb61f..38976176d8 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,7 @@ from dataclasses import dataclass from torchao.utils import TorchAOBaseTensor +import torch +from typing import Tuple, Union """ Base class for different layout, following the same design of PyTorch layout @@ -82,6 +84,7 @@ class AQTTensorImpl(TorchAOBaseTensor): Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct the underlying implementation of a AQT based on layout """ + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the plain (unpacked) Tensor for the tensor impl @@ -101,7 +104,7 @@ def from_plain( zero_point: torch.Tensor, _layout: Layout, ): - """ Construct a TensorImpl from data, scale, zero_point and the _layout""" + """Construct a TensorImpl from data, scale, zero_point and the _layout""" pass def __repr__(self): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e021556ed3..611d11287e 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -23,7 +23,7 @@ from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple import types -from torchao.dtypes.uintx.uintx import UintxLayout +from torchao.dtypes.uintx.uintx_layout import UintxLayout from torchao.dtypes import ( to_affine_quantized_intx, to_affine_quantized_floatx,