enh: Check for Sophia before performing the Hessian update step
This makes swapping optimizers easier, see !4
ItsNiklas committed Jun 21, 2023
1 parent 8ce9763 commit e55e795
Showing 1 changed file with 5 additions and 1 deletion.
classifier.py (5 additions, 1 deletion):
@@ -301,7 +301,11 @@ def train(args):
 
         optimizer.zero_grad(set_to_none=True)
 
-        if iter_num % hess_interval == hess_interval - 1:
+        # Check if we use the Sophia Optimizer
+        if (
+            hasattr(optimizer, "update_hessian")
+            and iter_num % hess_interval == hess_interval - 1
+        ):
             # Update the Hessian EMA
             logits = model(b_ids, b_mask)
             samp_dist = torch.distributions.Categorical(logits=logits)
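The guard added here is plain duck typing: the Hessian-EMA step runs only when the optimizer actually exposes an `update_hessian` method (as Sophia does), so a standard optimizer like AdamW can be swapped in without touching the training loop. A minimal sketch of the pattern, using hypothetical `AdamWLike`/`SophiaLike` stand-ins rather than the real optimizer classes:

```python
class AdamWLike:
    """Stand-in for a standard optimizer with no Hessian step."""


class SophiaLike:
    """Stand-in for a Sophia-style optimizer exposing update_hessian()."""

    def __init__(self):
        self.hessian_updates = 0

    def update_hessian(self):
        # The real Sophia optimizer would refresh its Hessian EMA here.
        self.hessian_updates += 1


def run_training(optimizer, num_iters, hess_interval):
    for iter_num in range(num_iters):
        # Duck-typing guard: perform the Hessian step only if the
        # optimizer implements it, and only every hess_interval-th step.
        if (
            hasattr(optimizer, "update_hessian")
            and iter_num % hess_interval == hess_interval - 1
        ):
            optimizer.update_hessian()
    return getattr(optimizer, "hessian_updates", 0)


print(run_training(AdamWLike(), 30, 10))   # 0: guard skips the step entirely
print(run_training(SophiaLike(), 30, 10))  # 3: fires on iterations 9, 19, 29
```

An alternative would be `isinstance(optimizer, SophiaG)`, but the `hasattr` check avoids importing the Sophia class at all, which keeps the training loop decoupled from any one optimizer implementation.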
