diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index bcd8af7ad..1087db8cf 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -25,6 +25,9 @@ AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, AQFloat8WeightOnlyQuantizedLinearWeight, + AQGemliteInt4G64WeightOnlyQuantizedLinearWeight, + AQInt4G32WeightOnlyQuantizedLinearWeight, + AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight, AQInt8DynamicallyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight2, @@ -1751,37 +1754,109 @@ def test_autoquant_min_sqnr(self, device, dtype): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." ) - def test_autoquant_float(self): + def test_autoquant_hp_float(self): device = "cuda" dtype = torch.float32 m, k, n = 128, 128, 128 example_input = torch.randn(m, k, device=device, dtype=dtype) - model = ( - torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k, n), - torch.nn.ReLU(), + for qclass in torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST: + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n, bias=True), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) ) - .to(device) - .to(dtype) - ) - ref = model(example_input) - torchao.autoquant( - model, - qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, - ) - out = model(example_input) - from torchao.quantization.autoquant import ( - BFloat16Tensor, - Float16Tensor, - Float32Tensor, - ) + ref = model(example_input) + qtensor_class_list = [qclass] + torchao.autoquant( + model, + qtensor_class_list=qtensor_class_list, + ) + out = model(example_input) + self.assertIn( + type(model[1].weight), + qtensor_class_list, + ) + self.assertGreater(compute_error(out, ref), 40) - self.assertIn( - type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor] - ) - print(compute_error(out, ref)) - self.assertGreater(compute_error(out, ref), 60) + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." + ) + @unittest.skipIf(not has_gemlite, "gemlite not available") + def test_autoquant_int4wo(self, device, dtype): + if device == "cpu": + self.skipTest(f"int4wo is for cuda, not {device}") + + m, k, n = 128, 128, 128 + example_input = torch.randn(m, k, device=device, dtype=dtype) + + for qclass in [ + AQGemliteInt4G64WeightOnlyQuantizedLinearWeight, + AQInt4G32WeightOnlyQuantizedLinearWeight, + AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight, + ]: + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n, bias=True), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) + ref = model(example_input) + qtensor_class_list = [qclass] + torchao.autoquant( + model, + qtensor_class_list=qtensor_class_list, + ) + out = model(example_input) + + self.assertIn(type(model[1].weight), qtensor_class_list) + self.assertGreater(compute_error(ref, out), 20) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." + ) + def test_autoquant_float8(self, device, dtype): + if device == "cpu": + self.skipTest(f"int4wo is for cuda, not {device}") + + # note: marlin sparse layout failed when scale_t has a dimension of 1d + m, k, n = 128, 128, 128 + example_input = torch.randn(m, k, device=device, dtype=dtype) + + for qclass in [ + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, + AQFloat8WeightOnlyQuantizedLinearWeight, + ]: + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n, bias=True), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) + ref = model(example_input) + qtensor_class_list = [qclass] + torchao.autoquant( + model, + qtensor_class_list=qtensor_class_list, + ) + out = model(example_input) + + self.assertIn(type(model[1].weight), qtensor_class_list) + self.assertGreater(compute_error(ref, out), 20) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index e37623182..2a84dd181 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -227,6 +227,11 @@ def from_plain( # Linear layers are (in_features, out_features) but the int_data that is reaching this point # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. q_w_24 = int_data.t() + # addressing the case when scale has dimension 1, happens when + # weight_shape[-1] == group_size == 128 + if scale.ndim == 1: + scale = scale.reshape(scale.shape[0], -1) + scale_t = scale.t() if not torch.cuda.get_device_capability()[0] >= 8: diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index d506d2b65..d49e84e06 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -16,6 +16,7 @@ from torchao.kernel import safe_int_mm from torchao.quantization.linear_activation_quantized_tensor import ( LinearActivationQuantizedTensor, + to_linear_activation_quantized, ) from torchao.quantization.quant_primitives import ( MappingType, @@ -370,6 +371,18 @@ def _is_interpolate_mode(mode): return False +def _to_float16(x: torch.Tensor) -> torch.Tensor: + return x.to(torch.float16) + + +def _to_bfloat16(x: torch.Tensor) -> torch.Tensor: + return x.to(torch.bfloat16) + + +def _identity(x: torch.Tensor) -> torch.Tensor: + return x + + class AQMixin: """ Tests and benchmarks the autoquantization process for the given activation matrix, weight, and bias. @@ -610,9 +623,11 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): return y -class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): +class AQInt4G32WeightOnlyQuantizedLinearWeight( + LinearActivationQuantizedTensor, AQMixin +): """ - AutoQuantizable version of Int4WeightOnlyQuantizedLinearWeight + AutoQuantizable version of int4_weight_only """ group_size: int = 32 @@ -621,20 +636,30 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): @classmethod def from_float(cls, weight): + from torchao.dtypes import to_affine_quantized_intx + group_size = cls.group_size _layout = cls.aq_layout if weight.shape[-1] % group_size != 0: return weight + input_quant_func = None + + # NOTE: we only convert activation dtype and weight dtype here + # because the kernel implementation for both TensorCoreTiledLayout and MarlinSparseLayout + # can work with multiple bias dtypes (by converting bias to the dtype of activation) if ( isinstance(_layout, TensorCoreTiledLayout) and weight.dtype != torch.bfloat16 ): - return weight - - if isinstance(_layout, MarlinSparseLayout) and weight.dtype != torch.float16: - return weight + weight = weight.to(torch.bfloat16) + input_quant_func = _to_bfloat16 + elif isinstance(_layout, MarlinSparseLayout) and weight.dtype != torch.float16: + weight = weight.to(torch.float16) + input_quant_func = _to_float16 + else: + input_quant_func = _identity use_hqq = True mapping_type = MappingType.ASYMMETRIC @@ -653,7 +678,7 @@ def from_float(cls, weight): zero_point_domain = ZeroPointDomain.INT use_hqq = False - return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx( + weight = to_affine_quantized_intx( weight, mapping_type, block_size, @@ -668,6 +693,10 @@ def from_float(cls, weight): use_hqq=use_hqq, ) + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_float( + weight, input_quant_func + ) + class AQInt4G64WeightOnlyQuantizedLinearWeight( AQInt4G32WeightOnlyQuantizedLinearWeight @@ -694,16 +723,19 @@ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight( aq_layout: Layout = MarlinSparseLayout() -class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): +class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight( + LinearActivationQuantizedTensor, AQMixin +): group_size: int = 32 @classmethod def from_float(cls, weight): - if weight.dtype != torch.float16: - return weight - + from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs + if weight.dtype != torch.float16: + weight = weight.to(torch.float16) + bit_width = 4 packing_bitwidth = 32 contiguous = None @@ -711,9 +743,12 @@ def from_float(cls, weight): aqt_kwargs = get_gemlite_aqt_kwargs( weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq ) - return super( - AQGemliteInt4G32WeightOnlyQuantizedLinearWeight, cls - ).from_hp_to_intx(weight, **aqt_kwargs) + weight = to_affine_quantized_intx(weight, **aqt_kwargs) + input_quant_func = _to_float16 + + return super(AQGemliteInt4G32WeightOnlyQuantizedLinearWeight, cls).from_float( + weight, input_quant_func + ) class AQGemliteInt4G64WeightOnlyQuantizedLinearWeight( @@ -755,11 +790,24 @@ def from_float(cls, weight): return weight +# TODO: remove skip_weight_conversion arg class Float32Tensor(TorchAOBaseTensor): """Tensor subclass tensor for fp32 dtype""" - def __init__(self, weight): - self.weight = weight.to(torch.float32) + @staticmethod + def __new__(cls, weight, skip_weight_conversion=False): + kwargs = {} + kwargs["device"] = weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else weight.layout + ) + kwargs["dtype"] = weight.dtype + kwargs["requires_grad"] = False + shape = weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, weight, skip_weight_conversion=False): + self.weight = weight if skip_weight_conversion else weight.to(torch.float32) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): @@ -778,7 +826,7 @@ def _apply_fn_to_data(self, fn): @classmethod def from_float(cls, weight): - return Float32Tensor(weight) + return cls(weight) @Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) @@ -816,8 +864,8 @@ def _(func, types, args, kwargs): class BFloat16Tensor(Float32Tensor): - def __init__(self, weight): - self.weight = weight.to(torch.bfloat16) + def __init__(self, weight, skip_weight_conversion=False): + self.weight = weight if skip_weight_conversion else weight.to(torch.bfloat16) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): @@ -830,13 +878,13 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): ).to(dtype=orig_dtype) @classmethod - def from_float(cls, weight): - return BFloat16Tensor(weight) + def from_float(cls, weight, skip_weight_conversion=False): + return cls(weight, skip_weight_conversion) class Float16Tensor(Float32Tensor): - def __init__(self, weight): - self.weight = weight.to(torch.float16) + def __init__(self, weight, skip_weight_conversion=False): + self.weight = weight if skip_weight_conversion else weight.to(torch.float16) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): @@ -849,8 +897,8 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): ).to(dtype=orig_dtype) @classmethod - def from_float(cls, weight): - return Float16Tensor(weight) + def from_float(cls, weight, skip_weight_conversion=False): + return cls(weight, skip_weight_conversion) class AQFloat32LinearWeight(Float32Tensor, AQMixin): @@ -911,9 +959,7 @@ def from_float(cls, weight): ) -class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight( - AQMixin, LinearActivationQuantizedTensor -): +class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Tensor): """ AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling """ @@ -942,12 +988,13 @@ def get_per_token_block_size(x): input_target_dtype = torch.float8_e4m3fn _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)) # TODO: make this serializable - input_quant_func = lambda x: _input_activation_quant_func_fp8( - x=x, - activation_granularity=cls.activation_granularity, - activation_dtype=input_target_dtype, - ) + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": cls.activation_granularity, + "activation_dtype": input_target_dtype, + } block_size = get_weight_block_size(weight) + weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -955,10 +1002,15 @@ def get_per_token_block_size(x): _layout=_layout, scale_dtype=torch.float32, ) - weight = super( + weight = to_linear_activation_quantized( + weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) + # at inference time, + # we first convert the input, weight and bias to bfloat16, and then quantize activation + # and then dispatch to the quantized ops + return super( AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls - ).from_float(weight, input_quant_func) - return weight + ).from_float(weight, skip_weight_conversion=True) class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight( @@ -982,15 +1034,14 @@ def get_weight_block_size(x): return x.shape target_dtype = torch.float8_e4m3fn - input_target_dtype = torch.float8_e4m3fn _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)) - # TODO: make this serializable - input_quant_func = lambda x: _input_activation_quant_func_fp8( - x=x, - activation_granularity=cls.activation_granularity, - activation_dtype=input_target_dtype, - ) + # TODO: test serializable + input_quant_func = _input_activation_quant_func_fp8 + input_quant_args = { + "activation_granularity": cls.activation_granularity, + "activation_dtype": input_target_dtype, + } block_size = get_weight_block_size(weight) weight = to_affine_quantized_floatx( input_float=weight, @@ -1001,7 +1052,7 @@ def get_weight_block_size(x): ) weight = super( AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls - ).from_float(weight, input_quant_func) + ).from_float(weight, input_quant_func, input_quant_args) return weight @@ -1299,3 +1350,10 @@ def finalize_autoquant(): if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) + torch.serialization.add_safe_globals( + [ + _to_float16, + _to_bfloat16, + _identity, + ] + )