rebase complete
asahni04 committed Nov 7, 2024
1 parent bcc8a94 commit a4e320c
Showing 1 changed file with 0 additions and 35 deletions.
35 changes: 0 additions & 35 deletions test/prototype/test_low_bit_optim.py
@@ -367,13 +367,9 @@ def test_optim_exclude_low_bit_params(self, optim_name, dtype, device):
)
model.to(device=device, dtype=dtype)

<<<<<<< HEAD
params_to_exclude = [model[0].weight, model[0].bias]
excluded_params_ids = set(id(p) for p in params_to_exclude)

=======
params_to_exclude = [model[0].weight]
>>>>>>> 9e829b7 (support exlusion of params when using low bit optim)

optim = getattr(low_bit_optim, optim_name)(
model.parameters(),
@@ -390,7 +386,6 @@ def test_optim_exclude_low_bit_params(self, optim_name, dtype, device):
excluded_state = state[excluded_param]
exp_avg = excluded_state['exp_avg']
exp_avg_sq = excluded_state['exp_avg_sq']
<<<<<<< HEAD
# Assert that the state tensors for the excluded parameter are not quantized
self.assertTrue(exp_avg.__class__ == torch.Tensor)
self.assertTrue(exp_avg_sq.__class__ == torch.Tensor)
@@ -401,36 +396,6 @@ def test_optim_exclude_low_bit_params(self, optim_name, dtype, device):
exp_avg_sq = param_state['exp_avg_sq']
self.assertTrue(exp_avg.__class__ != torch.Tensor)
self.assertTrue(exp_avg_sq.__class__ != torch.Tensor)
=======

if optim_name.endswith("8bit"):
quantized_state_types = (low_bit_optim.OptimState8bit,)
elif optim_name.endswith("4bit"):
quantized_state_types = (low_bit_optim.OptimState4bit,)
elif optim_name.endswith("Fp8"):
quantized_state_types = (low_bit_optim.OptimStateFp8,)
else:
quantized_state_types = ()

# Assert that the state tensors for the excluded parameter are not quantized
self.assertNotIsInstance(exp_avg, quantized_state_types)
self.assertNotIsInstance(exp_avg_sq, quantized_state_types)

for param in model.parameters():
if param is not excluded_param:
param_state = state[param]
exp_avg = param_state['exp_avg']
exp_avg_sq = param_state['exp_avg_sq']
self.assertIsInstance(exp_avg, quantized_state_types)
self.assertIsInstance(exp_avg_sq, quantized_state_types)

# Since the excluded parameter is not quantized, its data type should remain the same
self.assertEqual(excluded_param.dtype, dtype)

# Ensure that other parameters are still being updated correctly
for param in model.parameters():
self.assertIsNotNone(param.grad)
>>>>>>> 9e829b7 (support exlusion of params when using low bit optim)

class TestFSDP2(FSDPTest):
@property

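For context, the test touched here (`test_optim_exclude_low_bit_params`) verifies that parameters placed on an exclusion list keep plain `torch.Tensor` optimizer state, while the remaining parameters receive low-bit (quantized) state such as `OptimState8bit`. A minimal sketch of that pattern is below. The diff is truncated before the optimizer's keyword arguments, so the `params_to_exclude=` keyword used here is an assumption, not a confirmed torchao API.

```python
# Sketch only: mirrors the pattern the resolved test exercises, not a verified torchao API.
import torch
from torchao.prototype import low_bit_optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 256)
).to(device)

# Keep the first layer's weight and bias out of the low-bit optimizer-state path.
params_to_exclude = [model[0].weight, model[0].bias]
excluded_params_ids = set(id(p) for p in params_to_exclude)

# Assumed keyword argument for the exclusion list (see note above).
optim = low_bit_optim.AdamW8bit(model.parameters(), params_to_exclude=params_to_exclude)

model(torch.randn(4, 256, device=device)).sum().backward()
optim.step()

# Excluded parameters should keep plain torch.Tensor state; other (large enough)
# parameters should get a quantized state subclass such as OptimState8bit.
for p in model.parameters():
    exp_avg = optim.state[p]["exp_avg"]
    if id(p) in excluded_params_ids:
        assert type(exp_avg) is torch.Tensor
    else:
        print(type(exp_avg))  # expected: a quantized optimizer-state subclass
```

Keeping selected parameters in regular full-precision optimizer state can be useful for small or sensitive tensors where quantized state may hurt convergence, while the bulk of the parameters still benefit from the memory savings.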