Fix for weights-only load
stack-info: PR: #1228, branch: drisspg/stack/19
drisspg committed Nov 6, 2024
1 parent 6fd77d5 commit 9dd5401
Showing 5 changed files with 172 additions and 57 deletions.
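
Context for the change: with weights_only=True, torch.load only unpickles allowlisted types, so the low-bit optimizer state subclasses (OptimState4bit, OptimState8bit) are registered with torch.serialization.add_safe_globals on PyTorch >= 2.5, and the tests now pass weights_only=True when restoring checkpoints. A minimal sketch of the pattern follows; MyOptimState is a hypothetical stand-in, not one of the torchao classes.

# Illustrative sketch only; MyOptimState stands in for OptimState4bit/8bit.
import tempfile

import torch
from torch.serialization import add_safe_globals


class MyOptimState(torch.Tensor):
    pass


# Without this registration, loading with weights_only=True is expected to fail
# with an UnpicklingError because the subclass is not on the allowlist.
add_safe_globals([MyOptimState])

with tempfile.NamedTemporaryFile() as f:
    torch.save({"exp_avg": torch.zeros(4).as_subclass(MyOptimState)}, f.name)
    loaded = torch.load(f.name, map_location="cpu", weights_only=True)
    assert isinstance(loaded["exp_avg"], MyOptimState)
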
4 changes: 3 additions & 1 deletion ruff.toml
@@ -11,6 +11,8 @@ include = [
"torchao/quantization/linear_activation_weight_observer.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"torchao/quantization/weight_tensor_linear_activation_quantization.py"
"torchao/quantization/weight_tensor_linear_activation_quantization.py",
"torchao/prototype/low_bit_optim/**.py",
"test/prototype/low_bit_optim/**.py",

]
104 changes: 78 additions & 26 deletions test/prototype/test_low_bit_optim.py
@@ -19,7 +19,11 @@
quantize_4bit_with_qmap,
_fp32_to_bf16_sr,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_6,
)

try:
import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
x_rep = x.view(-1, 1).repeat(1, 100_000)

if compile:
x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
x_rep
)
else:
x_rep_bf16 = _fp32_to_bf16_sr(x_rep)

@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):


class TestOptim(TestCase):
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize(
"optim_name",
["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
)
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
def test_optim_smoke(self, optim_name, dtype, device):
@@ -120,7 +131,7 @@ def test_optim_smoke(self, optim_name, dtype, device):
# test serialization. also test the case CUDA optim loads CPU state dict
with tempfile.NamedTemporaryFile() as f:
torch.save(optim.state_dict(), f.name)
state_dict = torch.load(f.name, map_location="cpu")
state_dict = torch.load(f.name, map_location="cpu", weights_only=True)

model2 = copy.deepcopy(model)
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="bitsandbytes 8-bit Adam only works for CUDA",
)
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1)

# https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
optim2 = getattr(low_bit_optim, optim_name)(
model2.parameters(), block_size=block_size
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):

# this will not run in CI because we can't install lpmm
@pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
)
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
)
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1)

# lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
)
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1[0].requires_grad_(False) # make sure it can work in the presence of non-trainable params
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model1[0].requires_grad_(
False
) # make sure it can work in the presence of non-trainable params
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
model2.parameters(),
torch.optim.AdamW,
offload_gradients=offload_grad,
)

for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
)
def test_optim_cpu_offload_save_load(self):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
optim1 = low_bit_optim.CPUOffloadOptimizer(
model1.parameters(), torch.optim.AdamW
)

for _ in range(2):
x = torch.randn(4, 32, device=device)
@@ -249,11 +289,13 @@ def test_optim_cpu_offload_save_load(self):
# save checkpoint. make sure it can be serialized by torch.save()
with tempfile.NamedTemporaryFile() as file:
torch.save(optim1.state_dict(), file.name)
state_dict = torch.load(file.name, map_location="cpu")
state_dict = torch.load(file.name, map_location="cpu", weights_only=True)

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW
)
optim2.load_state_dict(state_dict)

for _ in range(2):
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
def test_optim_bf16_stochastic_round_correctness(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(2024)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
device
)
model2 = copy.deepcopy(model1).bfloat16()

# small LR so that weight update is small
# when bf16_stochastic_round=False, the test will fail after 1 iteration
optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
optim2 = low_bit_optim._AdamW(
model2.parameters(), lr=1e-5, bf16_stochastic_round=True
)

# overfit on this sample
x = torch.randn(4, 32, device=device)
@@ -299,15 +345,19 @@ def test_optim_bf16_stochastic_round_correctness(self):
optim2.step()
optim2.zero_grad()

torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
torch.testing.assert_close(
loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
)


class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return 2

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
)
@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
base_loss.backward()
for param in base_model.parameters():
if param.grad is not None:
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
torch.distributed.all_reduce(
param.grad, op=torch.distributed.ReduceOp.AVG
)
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)

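
A usage sketch distilled from the updated test_optim_smoke above: save a low-bit optimizer state dict and restore it with weights_only=True. The import path and optimizer name come from the diff; runnability assumes PyTorch >= 2.5 and a torchao build that includes this change.

import copy
import tempfile

import torch
from torch import nn
from torchao.prototype import low_bit_optim

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
optim = low_bit_optim.AdamW8bit(model.parameters())

# one training step so quantized optimizer state (OptimState8bit) exists
model(torch.randn(4, 32)).sum().backward()
optim.step()
optim.zero_grad()

with tempfile.NamedTemporaryFile() as f:
    torch.save(optim.state_dict(), f.name)
    # weights_only=True works because the subclass is allowlisted at import time
    state_dict = torch.load(f.name, map_location="cpu", weights_only=True)

model2 = copy.deepcopy(model)
optim2 = low_bit_optim.AdamW8bit(model2.parameters())
optim2.load_state_dict(state_dict)
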
47 changes: 36 additions & 11 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -3,9 +3,18 @@
import torch
from torch import Tensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
)

from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap
from .quant_utils import (
create_dynamic_map,
scale_tensor,
quantize_4bit_with_qmap,
dequant_with_qmap,
)


aten = torch.ops.aten
@@ -55,8 +64,12 @@ def __tensor_flatten__(self):
return self.tensor_attrs, [self.signed, self._shape]

@classmethod
def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
):
return cls(
*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
)

def dequantize(self, output_dtype=None):
codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1) # unpack
@@ -85,6 +98,7 @@ def __repr__(self):
# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
# dtype is the same but device is different. thus, we must override .to() method instead.
if not TORCH_VERSION_AT_LEAST_2_4:

def _to(self, *args, **kwargs):
# ignore other args/kwargs
device = kwargs.pop("device", None)
@@ -158,16 +172,20 @@ def _(func, types, args, kwargs):
if len(shape) == 1 and shape[0] == -1:
return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))

raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
raise ValueError(
f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
)


# this is needed for DTensor.full_tensor()
@OptimState4bit.implements([
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
@OptimState4bit.implements(
[
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
]
)
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimState4bit):
Expand All @@ -181,3 +199,10 @@ def _(func, types, args, kwargs):

# assume tensors from all ranks have the same signedness
return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)


if TORCH_VERSION_AT_LEAST_2_5:
    # Needed to load OptimState4bit with weights_only=True
from torch.serialization import add_safe_globals

add_safe_globals([OptimState4bit])
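
The dequantize path above unpacks two 4-bit codes from each stored byte (codes >> 4 and codes & 0b1111). A standalone sketch of that packing scheme, independent of torchao:

import torch

# pack pairs of 4-bit values (0..15) into single uint8 bytes, then unpack them
codes = torch.randint(0, 16, (8,), dtype=torch.uint8)
packed = (codes[::2] << 4) | codes[1::2]
unpacked = torch.stack([packed >> 4, packed & 0b1111], dim=-1).view(-1)
assert torch.equal(unpacked, codes)
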
43 changes: 33 additions & 10 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -1,9 +1,18 @@
import torch
from torch import Tensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
)

from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap
from .quant_utils import (
create_dynamic_map,
scale_tensor,
quantize_8bit_with_qmap,
dequant_with_qmap,
)


aten = torch.ops.aten
@@ -46,8 +55,12 @@ def __tensor_flatten__(self):
return self.tensor_attrs, [self.signed]

@classmethod
def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
):
return cls(
*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
)

def dequantize(self, output_dtype=None):
float_data = dequant_with_qmap(self.codes, self.qmap, self.scale)
@@ -72,6 +85,7 @@ def __repr__(self):
# in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
# dtype is the same but device is different. thus, we must override .to() method instead.
if not TORCH_VERSION_AT_LEAST_2_4:

def _to(self, *args, **kwargs):
# ignore other args/kwargs
device = kwargs.pop("device", None)
@@ -136,12 +150,14 @@ def _(func, types, args, kwargs):


# this is needed for DTensor.full_tensor()
@OptimState8bit.implements([
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
@OptimState8bit.implements(
[
c10d_functional.all_gather_into_tensor.default,
_c10d_functional.all_gather_into_tensor.default,
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
]
)
def _(func, types, args, kwargs):
x = args[0]
if not isinstance(x, OptimState8bit):
@@ -154,3 +170,10 @@ def _(func, types, args, kwargs):
x.qmap.clone(),
x.signed,
)


if TORCH_VERSION_AT_LEAST_2_5:
    # Needed to load OptimState8bit with weights_only=True
from torch.serialization import add_safe_globals

add_safe_globals([OptimState8bit])
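
Both subclasses store quantized codes, a scale, and a qmap lookup table, and dequantize via dequant_with_qmap(codes, qmap, scale). A simplified, self-contained sketch of that codebook idea, using a linear codebook and one tensor-wide scale instead of torchao's dynamic map and per-block scales:

import torch

qmap = torch.linspace(-1.0, 1.0, 256)           # illustrative codebook, not torchao's dynamic map
x = torch.randn(1024)

scale = x.abs().amax().clamp(min=1e-12)         # single scale for simplicity
codes = (x / scale - qmap[:, None]).abs().argmin(dim=0).to(torch.uint8)  # nearest entry
x_dequant = qmap[codes.to(torch.long)] * scale  # lookup + rescale

print((x - x_dequant).abs().max())              # reconstruction error stays small
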