Skip to content

Commit

Permalink
fix: missing var def
Browse files Browse the repository at this point in the history
  • Loading branch information
Derek Pisner authored and Derek Pisner committed Feb 8, 2024
1 parent 657b4f0 commit e382e16
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ def test_reset(optimizer):
@pytest.mark.parametrize("accum_steps", [1, 2, 3])
@pytest.mark.parametrize("use_agc", [True, False])
def test_parse_grad(optimizer, use_agc, accum_steps):
var = tf.Variable([1.0, 2.0], dtype=tf.float32)
if accum_steps == 1:
expected_grad = tf.zeros_like(var) # gradients should not be applied yet
else:
expected_grad = tf.constant([3.0, 4.0])
var = tf.Variable([1.0, 2.0], dtype=tf.float32)
optimizer.add_slot(var, "ga", initializer=expected_grad)
accum_gradient = optimizer.get_slot(var, "ga")

Expand Down

0 comments on commit e382e16

Please sign in to comment.