diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index acc7576e56..fd75d73676 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -25,7 +25,7 @@
 )
 from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
 from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
-from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
+from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8, quantize_fp8
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
@@ -42,6 +42,10 @@
 except ImportError:
     lpmm = None
 
+try:
+    import coat
+except ImportError:
+    coat = None
 
 _DEVICES = get_available_devices()
 
@@ -152,6 +156,55 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    @parametrize(
+        "optim_name",
+        ["AdamFp8", "AdamWFp8"],
+    )
+    @parametrize("dtype", [torch.float32, torch.bfloat16])
+    @parametrize("device", _DEVICES)
+    def test_optim_fp8_coat_smoke(self, optim_name, dtype, device):
+        if device == "cuda":
+            if not TORCH_VERSION_AT_LEAST_2_4:
+                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        model = nn.Sequential(nn.Linear(32, 4096), nn.ReLU(), nn.Linear(4096, 4096))
+        model.to(device=device, dtype=dtype)
+
+        optim = getattr(low_bit_optim, optim_name)(
+            model.parameters(), dynamic_range_expansion=True
+        )
+
+        x = torch.randn(4, 32, device=device, dtype=dtype)
+        loss = model(x).sum()
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+
+        # test serialization. also test the case CUDA optim loads CPU state dict
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(optim.state_dict(), f.name)
+            state_dict = torch.load(f.name, map_location="cpu")
+
+        model2 = copy.deepcopy(model)
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
+        optim2.load_state_dict(state_dict)
+
+        for _ in range(2):
+            x = torch.randn(4, 32, device=device, dtype=dtype)
+
+            model(x).sum().backward()
+            optim.step()
+            optim.zero_grad()
+
+            model2(x).sum().backward()
+            optim2.step()
+            optim2.zero_grad()
+
+        for p1, p2 in zip(model.parameters(), model2.parameters()):
+            torch.testing.assert_close(p2, p1)
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.
@@ -216,6 +269,51 @@ def test_optim_8bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
+    @pytest.mark.skipif(coat is None, reason="Coat is not available")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="Coat float8 Adam only works for CUDA",
+    )
+    @parametrize("optim_name", ["AdamWFp8"])
+    def test_optim_float8_correctness(self, optim_name):
+        from coat.activation.models._fp8_quantization_config import QuantizationConfig
+        from coat.optimizer.fp8_adamw import CoatAdamW
+
+        torch.manual_seed(42)
+        device = "cuda"
+
+        model1 = nn.Sequential(nn.Linear(32, 4096), nn.ReLU(), nn.Linear(4096, 4096))
+        model1.to(device)
+        model2 = copy.deepcopy(model1)
+
+        # the official CoatAdamW only supports block_size=128
+        block_size = 128
+        coat_args = QuantizationConfig(
+            first_order_expansion="true", second_order_expansion="true"
+        )
+
+        optim1 = CoatAdamW(coat_args, model1.parameters())
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size, dynamic_range_expansion=True
+        )
+
+        for _ in range(2):
+            x = torch.randn(4, 32, device=device)
+
+            loss1 = model1(x).sum()
+            loss1.backward()
+            optim1.step()
+            optim1.zero_grad()
+
+            loss2 = model2(x).sum()
+            loss2.backward()
+            optim2.step()
+            optim2.zero_grad()
+
+        for p1, p2 in zip(model1.parameters(), model2.parameters()):
+            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
+
     # this will not run in CI because we can't install lpmm
     @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
     @pytest.mark.skipif(
diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md
index 2e0f9725a4..011687f210 100644
--- a/torchao/prototype/README.md
+++ b/torchao/prototype/README.md
@@ -11,8 +11,10 @@
 - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
 - [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
 - [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).
+  * `dynamic_range_expansion` - a heuristic _expand & compress_ step applied before quantization and after dequantization of float8 optimizer states, following [COAT](https://arxiv.org/abs/2410.19313).
 - [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406)
+
 
 #### Roadmap
 - `hqq`, `awq`, `marlin`,`QuaRot`, and other well-researched methodologies for quantized fine-tuning and inference.
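A minimal usage sketch of the new `dynamic_range_expansion` flag, mirroring the smoke test above (illustrative only; FP8 optimizer states require CUDA with compute capability >= 8.9, and the flag defaults to `False`, so existing behavior is unchanged):

```python
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

# FP8 optimizer states need CUDA with compute capability >= 8.9 (e.g. RTX 4090, H100)
model = nn.Sequential(nn.Linear(32, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()

# dynamic_range_expansion=True enables the COAT-style expand & compress step
optim = low_bit_optim.AdamWFp8(
    model.parameters(), block_size=256, dynamic_range_expansion=True
)

for _ in range(3):
    x = torch.randn(4, 32, device="cuda")
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()
```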
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 9cad9777bf..56d5dee4f5 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -51,8 +51,7 @@ def __setstate__(self, state):
             group.setdefault("amsgrad", False)
 
     # bring your own function to create zero-filled subclass
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
+    def _subclass_zeros(self, p: Tensor, signed: bool):
         raise NotImplementedError
 
     def _new_buffer(self, p: Tensor, signed: bool):
@@ -60,7 +59,7 @@ def _new_buffer(self, p: Tensor, signed: bool):
 
         # follow bitsandbytes, only quantize tensors >= 4096 values
         if local_p.numel() >= 4096 and local_p.numel() % self.block_size == 0:
-            out = self._subclass_zeros(local_p, signed, self.block_size)
+            out = self._subclass_zeros(local_p, signed)
         else:
             out = torch.zeros_like(local_p)
 
@@ -216,9 +215,8 @@ def __init__(
             is_adamw=False,
         )
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState8bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState8bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class Adam4bit(_AdamBase):
@@ -246,9 +244,8 @@ def __init__(
             is_adamw=False,
         )
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState4bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class AdamFp8(_AdamBase):
@@ -263,6 +260,7 @@ def __init__(
         *,
         block_size=256,
         bf16_stochastic_round=False,
+        dynamic_range_expansion=False,
     ) -> None:
         super().__init__(
             params,
@@ -275,10 +273,12 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=False,
         )
+        self.dynamic_range_expansion = dynamic_range_expansion
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimStateFp8.zeros(p.shape, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimStateFp8.zeros(
+            p.shape, self.block_size, p.device, self.dynamic_range_expansion
+        )
 
 
 class AdamW8bit(_AdamBase):
@@ -306,9 +306,8 @@ def __init__(
             is_adamw=True,
         )
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState8bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState8bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class AdamW4bit(_AdamBase):
@@ -336,9 +335,8 @@ def __init__(
             is_adamw=True,
        )
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimState4bit.zeros(p.shape, signed, self.block_size, p.device)
 
 
 class AdamWFp8(_AdamBase):
@@ -353,6 +351,7 @@ def __init__(
         *,
         block_size=256,
         bf16_stochastic_round=False,
+        dynamic_range_expansion=False,
     ) -> None:
         super().__init__(
             params,
@@ -365,10 +364,12 @@ def __init__(
             bf16_stochastic_round=bf16_stochastic_round,
             is_adamw=True,
         )
+        self.dynamic_range_expansion = dynamic_range_expansion
 
-    @staticmethod
-    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimStateFp8.zeros(p.shape, block_size, p.device)
+    def _subclass_zeros(self, p: Tensor, signed: bool):
+        return OptimStateFp8.zeros(
+            p.shape, self.block_size, p.device, self.dynamic_range_expansion
+        )
 
 
 class _AdamW(_AdamBase):
diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py
index b5c8af6c83..3df0df701e 100644
--- a/torchao/prototype/low_bit_optim/subclass_fp8.py
+++ b/torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 from torch import Tensor
@@ -13,13 +14,47 @@
 DTYPE = torch.float8_e4m3fn
 
 
-def quantize_fp8(input: Tensor, block_size: int):
+def quantize_fp8(input: Tensor, block_size: int, dynamic_range_expansion: bool):
     shape = input.shape
     input = input.view(-1, block_size)
-    scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
+
+    k = None
+    SqrtMinMax = None
+
+    scale = input.abs().amax(-1).clip(1e-12)
+
+    if dynamic_range_expansion:
+        # NOTE: the calculation follows the COAT paper (https://arxiv.org/abs/2410.19313).
+        # The idea is to align the optimizer state distribution more closely
+        # with the FP8 representable range, reducing quantization error.
+        expand_min = torch.tensor(16.0, device=input.device).view(-1, 1)
+        Rdtype = torch.tensor(
+            torch.finfo(DTYPE).max * torch.finfo(DTYPE).max / 2, device=input.device
+        ).view(-1, 1)  # dynamic range of the FP8 dtype
+
+        MaxValue = input.abs().amax(-1).clip(1e-20).view(-1, 1)
+        MinValue = input.abs().amin(-1).clip(1e-20).view(-1, 1)
+        SqrtMinMax = torch.sqrt(MaxValue * MinValue)  # geometric mean of max and min
+
+        Rx = MaxValue / MinValue  # dynamic range of the input block
+
+        # dynamically calculate the optimal expansion exponent k
+        k = (
+            torch.floor((torch.log2(Rdtype) / torch.log2(Rx)) * expand_min) / expand_min
+        ).view(-1)
+
+        scale = (MaxValue / SqrtMinMax) ** k.view(-1, 1)
+        input = input.sign() * (input.abs().div(SqrtMinMax) ** k.view(-1, 1))
+        SqrtMinMax = SqrtMinMax.view(-1)
+
+    scale = scale / torch.finfo(DTYPE).max
     input = input / scale.view(-1, 1)
     codes = input.to(DTYPE).view(-1)
-    return codes.view(shape), scale
+
+    return codes.view(shape), scale.view(-1), k, SqrtMinMax
 
 
 # NOTE: FP8 sign bit is redundant for unsigned optim state.
@@ -29,10 +64,22 @@ class OptimStateFp8(TorchAOBaseTensor):
     tensor_attrs = ["codes", "scale"]
 
     @staticmethod
-    def __new__(cls, codes: Tensor, scale: Tensor):
+    def __new__(
+        cls,
+        codes: Tensor,
+        scale: Tensor,
+        k: Optional[Tensor] = None,
+        sqrt_minmax_exp: Optional[Tensor] = None,
+    ):
         return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device)
 
-    def __init__(self, codes: Tensor, scale: Tensor):
+    def __init__(
+        self,
+        codes: Tensor,
+        scale: Tensor,
+        k: Optional[Tensor] = None,
+        sqrt_minmax_exp: Optional[Tensor] = None,
+    ):
         """Create quantized FP8 optimizer state.
 
         Args
@@ -47,32 +94,65 @@ def __init__(self, codes: Tensor, scale: Tensor):
         assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
+        self.k = k
         self.block_size = codes.numel() // scale.numel()
+        self.sqrt_minmax_exp = sqrt_minmax_exp
 
     def __tensor_flatten__(self):
-        return self.tensor_attrs, []
+        return self.tensor_attrs + (
+            ["k", "sqrt_minmax_exp"] if self.k is not None else []
+        ), []
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
     ):
         return cls(
-            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
+            *[
+                tensor_data_dict[name]
+                for name in cls.tensor_attrs
+                + (["k", "sqrt_minmax_exp"] if "k" in tensor_data_dict else [])
+            ],
+            *tensor_attributes,
        )
 
     def dequantize(self, output_dtype=None):
         float_data = self.codes.float()
         float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)
 
+        if self.k is not None:
+            # invert the dynamic range expansion: |x|^(1/k) * sqrt_minmax_exp
+            # restores the magnitude that quantize_fp8 expanded
+            float_data = (
+                float_data.sign()
+                * (float_data.abs() ** (1 / self.k.view(-1, 1)))
+                * self.sqrt_minmax_exp.view(-1, 1)
+            )
+
         if output_dtype is not None:
             float_data = float_data.to(output_dtype)
+
         return float_data.view(self.codes.shape)
 
     @classmethod
-    def zeros(cls, shape, block_size: int = 256, device=None):
+    def zeros(
+        cls,
+        shape,
+        block_size: int = 256,
+        device=None,
+        dynamic_range_expansion: bool = False,
+    ):
         codes = torch.zeros(shape, dtype=DTYPE, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
-        return cls(codes, scale)
+        k = (
+            torch.ones(codes.numel() // block_size, device=device)
+            if dynamic_range_expansion
+            else None
+        )
+        sqrt_minmax_exp = (
+            torch.ones(codes.numel() // block_size, device=device)
+            if dynamic_range_expansion
+            else None
+        )
+        return cls(codes, scale, k, sqrt_minmax_exp)
 
     def __repr__(self):
         return (
@@ -90,12 +170,21 @@ def _(func, types, args, kwargs):
         assert dst.block_size == src.block_size
         dst.codes.copy_(src.codes)
         dst.scale.copy_(src.scale)
+        if dst.k is not None:
+            dst.k.copy_(src.k)
+            dst.sqrt_minmax_exp.copy_(src.sqrt_minmax_exp)
 
     elif isinstance(dst, OptimStateFp8):
-        codes, scale = quantize_fp8(src, dst.block_size)
+        codes, scale, k, sqrt_minmax_exp = quantize_fp8(
+            src, dst.block_size, dst.k is not None
+        )
         dst.codes.copy_(codes)
         dst.scale.copy_(scale)
+        # extra state used to invert the dynamic range expansion
+        if dst.k is not None:
+            dst.k.copy_(k)
+            dst.sqrt_minmax_exp.copy_(sqrt_minmax_exp)
 
     else:
         dst.copy_(src.dequantize())
@@ -109,6 +198,8 @@ def _(func, types, args, kwargs):
     out = OptimStateFp8(
         args[0].codes.to(device=device),
         args[0].scale.to(device=device),
+        args[0].k.to(device=device) if args[0].k is not None else None,
+        args[0].sqrt_minmax_exp.to(device=device) if args[0].k is not None else None,
     )
     return return_and_correct_aliasing(func, args, kwargs, out)
 
@@ -123,7 +214,7 @@ def _(func, types, args, kwargs):
 @OptimStateFp8.implements(aten.view.default)
 def _(func, types, args, kwargs):
     x, shape = args
-    return OptimStateFp8(x.codes.view(shape), x.scale)
+    return OptimStateFp8(x.codes.view(shape), x.scale, x.k, x.sqrt_minmax_exp)
 
 
 @OptimStateFp8.implements(
@@ -146,6 +237,8 @@ def _(func, types, args, kwargs):
     return OptimStateFp8(
         func(x.codes, *args[1:], **kwargs),
         func(x.scale, *args[1:], **kwargs),
+        func(x.k, *args[1:], **kwargs) if x.k is not None else None,
+        func(x.sqrt_minmax_exp, *args[1:], **kwargs) if x.k is not None else None,
     )
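For intuition, here is a self-contained sketch of the expand & compress round trip that `quantize_fp8` and `dequantize` implement above, written as a standalone single-block re-derivation rather than an import of the torchao internals; the `EXPAND_MIN` constant mirrors the `expand_min = 16.0` granularity in the diff:

```python
import torch

DTYPE = torch.float8_e4m3fn
EXPAND_MIN = 16.0  # quantization granularity of the exponent k, as in the diff

def expand_compress_roundtrip(x: torch.Tensor) -> torch.Tensor:
    # per-block magnitude range and its geometric mean
    max_val = x.abs().amax().clip(1e-20)
    min_val = x.abs().amin().clip(1e-20)
    sqrt_minmax = (max_val * min_val).sqrt()

    # pick k so the expanded range (max/min)^k roughly fills the FP8 dynamic
    # range R_dtype = max^2 / 2 (see COAT, https://arxiv.org/abs/2410.19313)
    r_dtype = torch.tensor(torch.finfo(DTYPE).max ** 2 / 2)
    r_x = max_val / min_val
    k = torch.floor(torch.log2(r_dtype) / torch.log2(r_x) * EXPAND_MIN) / EXPAND_MIN

    # expand: |x / sqrt_minmax|^k spreads the values symmetrically around 1
    expanded = x.sign() * (x.abs() / sqrt_minmax) ** k

    # ordinary block-wise FP8 quantize/dequantize
    scale = expanded.abs().amax() / torch.finfo(DTYPE).max
    dequant = (expanded / scale).to(DTYPE).float() * scale

    # compress: invert the power and rescale, as OptimStateFp8.dequantize does
    return dequant.sign() * dequant.abs() ** (1 / k) * sqrt_minmax

def plain_roundtrip(x: torch.Tensor) -> torch.Tensor:
    scale = x.abs().amax() / torch.finfo(DTYPE).max
    return (x / scale).to(DTYPE).float() * scale

# skewed, small-magnitude values, loosely resembling Adam's second moment
x = torch.rand(256) * 1e-3 + 1e-6
print("plain fp8 error:   ", (plain_roundtrip(x) - x).abs().mean().item())
print("expanded fp8 error:", (expand_compress_roundtrip(x) - x).abs().mean().item())
```

A note on the design: `k` is snapped to multiples of `1/EXPAND_MIN` so it stays cheap to store, and the geometric mean must be kept alongside `scale` because the inverse transform needs both, which is exactly why `OptimStateFp8` gains the two optional `k` and `sqrt_minmax_exp` tensors in this PR.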