Relax dtype requirements for int4 and float8 quants in autoquant (#1571)
* Relax dtype requirements for int4 quants in autoquant

Summary:
Some of the int4 quantization options only work with bfloat16/float16. Previously we required
the model to already be in a compatible dtype before applying these options in autoquant; this PR
relaxes that constraint by converting the weight and activation to compatible dtypes.
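
A minimal sketch of what this relaxation allows (not from the commit; the class name and import
path follow the test imports below, assumed to come from torchao.quantization.autoquant):

import torch
import torchao
from torchao.quantization.autoquant import AQInt4G32WeightOnlyQuantizedLinearWeight

# The model and its inputs stay in float32; before this change the model had to be
# cast to bfloat16/float16 first for the int4 weight-only options to be applicable.
model = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=True)).to("cuda")
example_input = torch.randn(16, 128, device="cuda")

torchao.autoquant(model, qtensor_class_list=[AQInt4G32WeightOnlyQuantizedLinearWeight])
out = model(example_input)  # weight/activation are converted to compatible dtypes internally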

Test Plan:
python test/integration/test_integration.py -k test_autoquant_int4wo

Reviewers:

Subscribers:

Tasks:

Tags:

* remove prints

* add float8

* run pre-commit

* run pre-commit

* manual format

* enable bias=True test

* remove print
jerryzh168 authored Jan 17, 2025
1 parent f520c91 commit cf45336
Showing 3 changed files with 207 additions and 69 deletions.
125 changes: 100 additions & 25 deletions test/integration/test_integration.py
@@ -25,6 +25,9 @@
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
     AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G32WeightOnlyQuantizedLinearWeight,
+    AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
@@ -1751,37 +1754,109 @@ def test_autoquant_min_sqnr(self, device, dtype):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+."
     )
-    def test_autoquant_float(self):
+    def test_autoquant_hp_float(self):
         device = "cuda"
         dtype = torch.float32
         m, k, n = 128, 128, 128
         example_input = torch.randn(m, k, device=device, dtype=dtype)
-        model = (
-            torch.nn.Sequential(
-                torch.nn.ReLU(),
-                torch.nn.Linear(k, n),
-                torch.nn.ReLU(),
-            )
-            .to(device)
-            .to(dtype)
-        )
-        ref = model(example_input)
-        torchao.autoquant(
-            model,
-            qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
-        )
-        out = model(example_input)
-        from torchao.quantization.autoquant import (
-            BFloat16Tensor,
-            Float16Tensor,
-            Float32Tensor,
-        )
-
-        self.assertIn(
-            type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor]
-        )
-        print(compute_error(out, ref))
-        self.assertGreater(compute_error(out, ref), 60)
+        for qclass in torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+            self.assertIn(
+                type(model[1].weight),
+                qtensor_class_list,
+            )
+            self.assertGreater(compute_error(out, ref), 40)
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
+    )
+    @unittest.skipIf(not has_gemlite, "gemlite not available")
+    def test_autoquant_int4wo(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"int4wo is for cuda, not {device}")
+
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
+        for qclass in [
+            AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+            AQInt4G32WeightOnlyQuantizedLinearWeight,
+            AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
+    )
+    def test_autoquant_float8(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"int4wo is for cuda, not {device}")
+
+        # note: marlin sparse layout failed when scale_t has a dimension of 1d
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
+        for qclass in [
+            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8WeightOnlyQuantizedLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
+

     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
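
The assertions above gate on compute_error, which in these tests is an SQNR-style metric in
decibels; a minimal sketch of that computation (assuming the usual definition in torchao's test
utilities, with the helper name sqnr_db chosen here for illustration):

import torch

def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means the quantized output
    # is closer to the reference; 20 dB is a loose bound, 40 dB a fairly tight one.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - out))

ref = torch.randn(128, 128)
out = ref + 0.01 * torch.randn(128, 128)  # small "quantization" error
print(sqnr_db(ref, out))                  # roughly 40 dB for this noise level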
5 changes: 5 additions & 0 deletions torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -227,6 +227,11 @@ def from_plain(
         # Linear layers are (in_features, out_features) but the int_data that is reaching this point
         # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
         q_w_24 = int_data.t()
+        # addressing the case when scale has dimension 1, happens when
+        # weight_shape[-1] == group_size == 128
+        if scale.ndim == 1:
+            scale = scale.reshape(scale.shape[0], -1)
+
         scale_t = scale.t()

         if not torch.cuda.get_device_capability()[0] >= 8:
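
A minimal sketch of the edge case this guard handles (shapes chosen for illustration): when
weight_shape[-1] == group_size == 128 there is a single quantization group per output channel,
so the scale can arrive as a 1-D tensor, and the reshape restores the 2-D layout the transpose
and the downstream Marlin packing path expect.

import torch

out_features = 256
scale = torch.rand(out_features)               # 1-D scale: shape (256,)
if scale.ndim == 1:
    scale = scale.reshape(scale.shape[0], -1)  # -> (256, 1)
scale_t = scale.t()                            # -> (1, 256); without the reshape, .t() on a
                                               # 1-D tensor is a no-op and stays shape (256,)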