From e55e795687a71b0058ce5a1af148713724666353 Mon Sep 17 00:00:00 2001
From: ItsNiklas
Date: Wed, 21 Jun 2023 09:52:41 +0200
Subject: [PATCH] enh: Check for Sophia before performing the Hessian update step

This makes swapping optimizers easier, see !4
---
 classifier.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/classifier.py b/classifier.py
index bcaf93d..8bcaaa0 100644
--- a/classifier.py
+++ b/classifier.py
@@ -301,7 +301,11 @@ def train(args):
 
         optimizer.zero_grad(set_to_none=True)
 
-        if iter_num % hess_interval == hess_interval - 1:
+        # Check if we use the Sophia Optimizer
+        if (
+            hasattr(optimizer, "update_hessian")
+            and iter_num % hess_interval == hess_interval - 1
+        ):
             # Update the Hessian EMA
             logits = model(b_ids, b_mask)
             samp_dist = torch.distributions.Categorical(logits=logits)
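
Note: the patch duck-types the optimizer rather than checking a flag or class name: only Sophia-style optimizers expose update_hessian(), so AdamW and the other torch.optim optimizers skip the extra forward/backward pass entirely. Below is a minimal sketch of the full Hessian-EMA step that the hunk above truncates, assuming a classifier whose logits have shape (batch, num_classes) and the label-resampling (Gauss-Newton-Bartlett) estimator used in the Sophia reference training loop; the helper name maybe_update_hessian and everything past the samp_dist line are illustrative assumptions, not part of this patch.

    import torch
    import torch.nn.functional as F

    def maybe_update_hessian(optimizer, model, b_ids, b_mask, iter_num, hess_interval):
        # Duck-typed check: only Sophia-style optimizers define update_hessian(),
        # so this block is a no-op for AdamW and friends.
        if (
            hasattr(optimizer, "update_hessian")
            and iter_num % hess_interval == hess_interval - 1
        ):
            # Resample labels from the model's own predictive distribution and
            # backprop the sampled loss (Gauss-Newton-Bartlett estimator).
            logits = model(b_ids, b_mask)
            samp_dist = torch.distributions.Categorical(logits=logits)
            y_sample = samp_dist.sample()
            loss_sampled = F.cross_entropy(logits, y_sample)
            loss_sampled.backward()
            optimizer.update_hessian()  # refresh the Hessian EMA
            optimizer.zero_grad(set_to_none=True)

Because the check lives entirely inside train(), swapping SophiaG for AdamW (or back) requires changing only the optimizer construction, which is what the commit message means by making optimizers easier to swap.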