From 9dd5401169213c43619504c39b6c9246795b2f38 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 5 Nov 2024 17:15:07 -0800 Subject: [PATCH] Fix for weights-only load stack-info: PR: https://github.com/pytorch/ao/pull/1228, branch: drisspg/stack/19 --- ruff.toml | 4 +- test/prototype/test_low_bit_optim.py | 104 +++++++++++++----- .../prototype/low_bit_optim/subclass_4bit.py | 47 ++++++-- .../prototype/low_bit_optim/subclass_8bit.py | 43 ++++++-- .../prototype/low_bit_optim/subclass_fp8.py | 31 ++++-- 5 files changed, 172 insertions(+), 57 deletions(-) diff --git a/ruff.toml b/ruff.toml index 1a4a5ff097..4b6c81385c 100644 --- a/ruff.toml +++ b/ruff.toml @@ -11,6 +11,8 @@ include = [ "torchao/quantization/linear_activation_weight_observer.py", "test/quantization/test_observer.py", "test/dtypes/test_affine_quantized_float.py", - "torchao/quantization/weight_tensor_linear_activation_quantization.py" + "torchao/quantization/weight_tensor_linear_activation_quantization.py", + "torchao/prototype/low_bit_optim/**.py", + "test/prototype/low_bit_optim/**.py", ] diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 39f97896bf..e5053bdf4a 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -19,7 +19,11 @@ quantize_4bit_with_qmap, _fp32_to_bf16_sr, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_6, +) try: import bitsandbytes as bnb @@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile): x_rep = x.view(-1, 1).repeat(1, 100_000) if compile: - x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep) + x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)( + x_rep + ) else: x_rep_bf16 = _fp32_to_bf16_sr(x_rep) @@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile): class TestOptim(TestCase): - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") - @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"]) + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" + ) + @parametrize( + "optim_name", + ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"], + ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) def test_optim_smoke(self, optim_name, dtype, device): @@ -120,7 +131,7 @@ def test_optim_smoke(self, optim_name, dtype, device): # test serialization. 
also test the case CUDA optim loads CPU state dict with tempfile.NamedTemporaryFile() as f: torch.save(optim.state_dict(), f.name) - state_dict = torch.load(f.name, map_location="cpu") + state_dict = torch.load(f.name, map_location="cpu", weights_only=True) model2 = copy.deepcopy(model) optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) @@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device): torch.testing.assert_close(p2, p1) @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available") - @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="bitsandbytes 8-bit Adam only works for CUDA", + ) + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" + ) @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) model2 = copy.deepcopy(model1) # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0 block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048 optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) - optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size) + optim2 = getattr(low_bit_optim, optim_name)( + model2.parameters(), block_size=block_size + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name): # this will not run in CI because we can't install lpmm @pytest.mark.skipif(lpmm is None, reason="lpmm is not available") - @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") + @pytest.mark.skipif( + not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA" + ) + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" + ) @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) def test_optim_4bit_correctness(self, optim_name): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) model2 = copy.deepcopy(model1) # lpmm doesn't have Adam. use AdamW with no weight decay instead. 
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA") + @pytest.mark.skipif( + not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + ) @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)]) def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) - model1[0].requires_grad_(False) # make sure it can work in the presence of non-trainable params + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) + model1[0].requires_grad_( + False + ) # make sure it can work in the presence of non-trainable params model2 = copy.deepcopy(model1) optim1 = torch.optim.AdamW(model1.parameters()) optim2 = low_bit_optim.CPUOffloadOptimizer( - model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad, + model2.parameters(), + torch.optim.AdamW, + offload_gradients=offload_grad, ) for _ in range(2): @@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA") + @pytest.mark.skipif( + not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + ) def test_optim_cpu_offload_save_load(self): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) - optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) + optim1 = low_bit_optim.CPUOffloadOptimizer( + model1.parameters(), torch.optim.AdamW + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -249,11 +289,13 @@ def test_optim_cpu_offload_save_load(self): # save checkpoint. 
make sure it can be serialized by torch.save() with tempfile.NamedTemporaryFile() as file: torch.save(optim1.state_dict(), file.name) - state_dict = torch.load(file.name, map_location="cpu") + state_dict = torch.load(file.name, map_location="cpu", weights_only=True) # resume training model2 = copy.deepcopy(model1) - optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW) + optim2 = low_bit_optim.CPUOffloadOptimizer( + model2.parameters(), torch.optim.AdamW + ) optim2.load_state_dict(state_dict) for _ in range(2): @@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self): def test_optim_bf16_stochastic_round_correctness(self): device = "cuda" if torch.cuda.is_available() else "cpu" torch.manual_seed(2024) - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) model2 = copy.deepcopy(model1).bfloat16() # small LR so that weight update is small # when bf16_stochastic_round=False, the test will fail after 1 iteration optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5) - optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True) + optim2 = low_bit_optim._AdamW( + model2.parameters(), lr=1e-5, bf16_stochastic_round=True + ) # overfit on this sample x = torch.randn(4, 32, device=device) @@ -299,7 +345,9 @@ def test_optim_bf16_stochastic_round_correctness(self): optim2.step() optim2.zero_grad() - torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}") + torch.testing.assert_close( + loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}" + ) class TestFSDP2(FSDPTest): @@ -307,7 +355,9 @@ class TestFSDP2(FSDPTest): def world_size(self) -> int: return 2 - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.") + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required." 
+ ) @skip_if_lt_x_gpu(2) def test_fsdp2(self): optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit] @@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls): base_loss.backward() for param in base_model.parameters(): if param.grad is not None: - torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce( + param.grad, op=torch.distributed.ReduceOp.AVG + ) base_optim.step() self.assertEqual(fsdp_loss, base_loss) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 759d816a6e..257e03afca 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -3,9 +3,18 @@ import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import ( + TorchAOBaseTensor, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, +) -from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap +from .quant_utils import ( + create_dynamic_map, + scale_tensor, + quantize_4bit_with_qmap, + dequant_with_qmap, +) aten = torch.ops.aten @@ -55,8 +64,12 @@ def __tensor_flatten__(self): return self.tensor_attrs, [self.signed, self._shape] @classmethod - def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): - return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes + ) def dequantize(self, output_dtype=None): codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1) # unpack @@ -85,6 +98,7 @@ def __repr__(self): # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when # dtype is the same but device is different. thus, we must override .to() method instead. 
 if not TORCH_VERSION_AT_LEAST_2_4:
+
     def _to(self, *args, **kwargs):
         # ignore other args/kwargs
         device = kwargs.pop("device", None)
@@ -158,16 +172,20 @@ def _(func, types, args, kwargs):
     if len(shape) == 1 and shape[0] == -1:
         return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
 
-    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+    raise ValueError(
+        f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
+    )
 
 
 # this is needed for DTensor.full_tensor()
-@OptimState4bit.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimState4bit.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState4bit):
@@ -181,3 +199,10 @@ def _(func, types, args, kwargs):
 
     # assume tensors from all ranks have the same signedness
     return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Needed to load OptimState4bit with weights_only = True
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimState4bit])
diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py
index f5374a3480..99b92cd88b 100644
--- a/torchao/prototype/low_bit_optim/subclass_8bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -1,9 +1,18 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import (
+    TorchAOBaseTensor,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
 
-from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap
+from .quant_utils import (
+    create_dynamic_map,
+    scale_tensor,
+    quantize_8bit_with_qmap,
+    dequant_with_qmap,
+)
 
 aten = torch.ops.aten
 
@@ -46,8 +55,12 @@ def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed]
 
     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
-        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+        )
 
     def dequantize(self, output_dtype=None):
         float_data = dequant_with_qmap(self.codes, self.qmap, self.scale)
@@ -72,6 +85,7 @@ def __repr__(self):
 # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
 # dtype is the same but device is different. thus, we must override .to() method instead.
 if not TORCH_VERSION_AT_LEAST_2_4:
+
     def _to(self, *args, **kwargs):
         # ignore other args/kwargs
         device = kwargs.pop("device", None)
@@ -136,12 +150,14 @@ def _(func, types, args, kwargs):
 
 
 # this is needed for DTensor.full_tensor()
-@OptimState8bit.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimState8bit.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimState8bit):
@@ -154,3 +170,10 @@ def _(func, types, args, kwargs):
         x.qmap.clone(),
         x.signed,
     )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Needed to load OptimState8bit with weights_only = True
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimState8bit])
diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py
index eabe8b5051..fb86a1dd42 100644
--- a/torchao/prototype/low_bit_optim/subclass_fp8.py
+++ b/torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.utils import TorchAOBaseTensor
+from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5
 
 aten = torch.ops.aten
 
@@ -51,8 +51,12 @@ def __tensor_flatten__(self):
         return self.tensor_attrs, []
 
     @classmethod
-    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
-        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+        )
 
     def dequantize(self, output_dtype=None):
         float_data = self.codes.float()
@@ -121,12 +125,14 @@ def _(func, types, args, kwargs):
 
 
 # this is needed for DTensor.full_tensor()
-@OptimStateFp8.implements([
-    c10d_functional.all_gather_into_tensor.default,
-    _c10d_functional.all_gather_into_tensor.default,
-    c10d_functional.wait_tensor.default,
-    _c10d_functional.wait_tensor.default,
-])
+@OptimStateFp8.implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+        c10d_functional.wait_tensor.default,
+        _c10d_functional.wait_tensor.default,
+    ]
+)
 def _(func, types, args, kwargs):
     x = args[0]
     if not isinstance(x, OptimStateFp8):
@@ -137,3 +143,10 @@ def _(func, types, args, kwargs):
         func(x.codes, *args[1:], **kwargs),
         func(x.scale, *args[1:], **kwargs),
     )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Needed to load OptimStateFp8 with weights_only = True
+    from torch.serialization import add_safe_globals
+
+    add_safe_globals([OptimStateFp8])
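
Usage note: the add_safe_globals registrations above are what allow checkpoints containing these quantized optimizer-state subclasses to pass PyTorch's restricted weights-only unpickler, which the updated tests exercise via torch.load(..., weights_only=True). The snippet below is a minimal sketch of that round trip, not new API: it assumes PyTorch >= 2.5 (so the guarded registration code path is active) and a torchao build that includes these prototype optimizers, and it mirrors test_optim_smoke using AdamW8bit on CPU.

import tempfile

import torch
from torch import nn

from torchao.prototype import low_bit_optim

# Importing low_bit_optim runs the add_safe_globals() calls above, so
# OptimState4bit / OptimState8bit / OptimStateFp8 are registered with the
# weights-only unpickler before any checkpoint is loaded.
model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
optim = low_bit_optim.AdamW8bit(model.parameters())

# Take one step so the quantized optimizer state actually exists.
model(torch.randn(4, 32)).sum().backward()
optim.step()
optim.zero_grad()

with tempfile.NamedTemporaryFile() as f:
    torch.save(optim.state_dict(), f.name)
    # Without the registrations, weights_only=True rejects this state dict
    # because the optimizer-state tensor subclasses are unknown to it.
    state_dict = torch.load(f.name, map_location="cpu", weights_only=True)

optim2 = low_bit_optim.AdamW8bit(model.parameters())
optim2.load_state_dict(state_dict)

The TORCH_VERSION_AT_LEAST_2_5 guard means older PyTorch builds simply skip the registration, while on newer versions it happens as a side effect of importing the optimizers, so callers do not need to register anything themselves.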