Commit 58e42e8
test(cov): lr setter
Derek Pisner authored and Derek Pisner committed Feb 11, 2024
1 parent 8d87ab9 commit 58e42e8
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion gradient_accumulator/accumulators.py
@@ -277,7 +277,7 @@ def __init__(
             aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
         )
         if not hasattr(self, "_weights"):
-            self._weights = []  # noqa
+            self._weights = []  # pragma: no cover
         if not hasattr(self, "_gradients"):
             self._gradients = []
         self._weights.append(self._step)
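The one-line change swaps a lint suppression for a coverage exclusion: `# noqa` only silences flake8, whereas `# pragma: no cover` tells coverage.py to drop the marked line from coverage reporting, which fits the commit's test-coverage focus. A minimal sketch of the same idiom, with hypothetical names not taken from this repository:

    # Hypothetical example (not from gradient_accumulator). coverage.py
    # recognizes "# pragma: no cover" out of the box and excludes the
    # marked line or branch from its report, which is useful for
    # defensive fallbacks the test suite is not expected to reach.
    def get_backend():
        try:
            import tensorflow as tf  # normal path, exercised by tests
            return tf
        except ImportError:  # pragma: no cover - defensive, untested
            raise RuntimeError("TensorFlow is required")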
14 changes: 10 additions & 4 deletions tests/test_optimizer.py
@@ -24,11 +24,17 @@ def test_lr_getter(optimizer):
 
 def test_lr_setter(optimizer):
     optimizer.lr = 0.02
+    updated_lr = optimizer.lr.numpy() if hasattr(optimizer.lr, 'numpy') else optimizer.lr
 
-    assert optimizer.lr == 0.02, "The lr getter did not return the updated learning rate."
+    assert updated_lr == pytest.approx(0.02), "The lr getter did not return the updated learning rate."
 
-    assert optimizer.base_optimizer.learning_rate == 0.02
-    assert optimizer._learning_rate == 0.02
+    base_lr = optimizer.base_optimizer.learning_rate
+    base_lr = base_lr.numpy() if hasattr(base_lr, 'numpy') else base_lr
+    assert base_lr == pytest.approx(0.02), "The base_optimizer's learning rate was not updated."
+
+    internal_lr = optimizer._learning_rate
+    internal_lr = internal_lr.numpy() if hasattr(internal_lr, 'numpy') else internal_lr
+    assert internal_lr == pytest.approx(0.02), "The internal _learning_rate attribute was not updated."
 
 def test__learning_rate(optimizer):
     assert optimizer._learning_rate == 0.01
@@ -45,7 +51,7 @@ def test_step_setter(optimizer):
 def test_iterations_setter(optimizer):
     optimizer.iterations = 1
     assert optimizer.iterations == 1
 
 def test_optimizer_prop(optimizer):
     assert optimizer.optimizer.__class__ == get_opt(opt_name="SGD", tf_version=tf_version).__class__
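The repeated `x.numpy() if hasattr(x, 'numpy') else x` guard in the new test exists because, depending on the TensorFlow version, these learning-rate attributes can be `tf.Variable`/`tf.Tensor` objects rather than plain Python floats, and `pytest.approx` sidesteps exact float comparison. A sketch of how that repetition could be factored into a helper; the `as_float` name and the compact test are hypothetical, not part of tests/test_optimizer.py:

    import pytest

    def as_float(value):
        # Hypothetical helper (not in tests/test_optimizer.py): coerce a
        # tf.Variable / tf.Tensor (anything exposing .numpy()) or a plain
        # Python number to a float for comparison.
        return float(value.numpy()) if hasattr(value, "numpy") else float(value)

    def test_lr_setter_compact(optimizer):
        # Same assertions as test_lr_setter above, expressed via the helper.
        optimizer.lr = 0.02
        assert as_float(optimizer.lr) == pytest.approx(0.02)
        assert as_float(optimizer.base_optimizer.learning_rate) == pytest.approx(0.02)
        assert as_float(optimizer._learning_rate) == pytest.approx(0.02)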
