Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up linear_int8_dynamic_activation_intx_weight_subclass #1553

Merged
1 commit merged on Jan 14, 2025
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions torchao/_models/llama/generate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -543,32 +543,22 @@ def ffn_or_attn_only(mod, fqn):
from torchao.experimental.quant_api import (
int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import PerGroup

assert (
precision == torch.float32
), "int8_dynamic_activation_intx_weight requires fp32 precision"

try:
torch.ops.torchao._pack_8bit_act_4bit_weight
except:
print(
"Unable to load experimental torchao kernels. Performance will be slow."
)
print(
"To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU"
)
), "int8_dynamic_activation_intx_weight requires using precision=torch.float32"

# Quantize model
_quant_args = quantization.split("-")
nbit = int(_quant_args[1])
assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8"
group_size = int(_quant_args[2])
weight_dtype = getattr(torch, f"int{_quant_args[1]}")
granularity = PerGroup(int(_quant_args[2]))
has_weight_zeros = bool(_quant_args[3])
quantize_(
model,
int8_dynamic_activation_intx_weight(
group_size=group_size,
nbit=nbit,
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
),
)
Expand Down
18 changes: 11 additions & 7 deletions torchao/dtypes/uintx/plain_layout.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,7 +38,7 @@ def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
kwargs = {}
Expand All @@ -55,7 +55,7 @@ def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
self.int_data = int_data
Expand All @@ -64,6 +64,8 @@ def __init__(
self._layout = _layout

def __tensor_flatten__(self):
if self.zero_point is None:
return ["int_data", "scale"], [self._layout]
return ["int_data", "scale", "zero_point"], [self._layout]

@classmethod
Expand All @@ -73,7 +75,7 @@ def __tensor_unflatten__(
int_data, scale, zero_point = (
tensor_data_dict["int_data"],
tensor_data_dict["scale"],
tensor_data_dict["zero_point"],
tensor_data_dict.get("zero_point", None),
)
(_layout,) = tensor_attributes
return cls(int_data, scale, zero_point, _layout)
Expand All @@ -83,15 +85,17 @@ def to(self, *args, **kwargs):
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
self.zero_point.to(kwargs["device"])
if self.zero_point is not None
else None,
self._layout,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
fn(self.zero_point) if self.zero_point is not None else None,
self._layout,
)

Expand Down Expand Up @@ -134,7 +138,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return PlainAQTTensorImpl(
aten.slice.Tensor(self.int_data, dim, start, end, step),
self.scale.view(-1),
self.zero_point.view(-1),
self.zero_point.view(-1) if self.zero_point is not None else None,
self._layout,
)
else:
Expand All @@ -148,7 +152,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.int_data, self.scale, self.zero_point

def get_layout(self) -> Layout:
Expand Down
6 changes: 3 additions & 3 deletions torchao/dtypes/utils.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -87,7 +87,7 @@ class AQTTensorImpl(TorchAOBaseTensor):
the underlying implementation of a AQT based on layout
"""

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Get the plain (unpacked) Tensor for the tensor impl
Returns data, scale and zero_point
Expand All @@ -103,7 +103,7 @@ def from_plain(
cls,
data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
"""Construct a TensorImpl from data, scale, zero_point and the _layout"""
Expand Down
Loading
Loading