From 9ec0eddddc10ba08cbb62936c3f9ead392df9f77 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:10:28 -0800 Subject: [PATCH] Fix for weights-only load (#1228) stack-info: PR: https://github.com/pytorch/ao/pull/1228, branch: drisspg/stack/19 --- test/prototype/test_low_bit_optim.py | 5 +++-- torchao/prototype/low_bit_optim/adam.py | 12 ++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 5387b49803..c578a82a8d 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -368,8 +368,9 @@ def test_optim_bf16_stochastic_round_correctness(self): optim2.step() optim2.zero_grad() - torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}") - + torch.testing.assert_close( + loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}" + ) _FSDP_WORLD_SIZE = 2 diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 57606f787d..380d730e1c 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -13,12 +13,12 @@ class _AdamBase(Optimizer): def __init__( - self, - params, - lr, - betas, - eps, - weight_decay, + self, + params, + lr, + betas, + eps, + weight_decay, amsgrad, *, block_size,