diff --git a/classifier.py b/classifier.py
index b88e96d..c069a78 100644
--- a/classifier.py
+++ b/classifier.py
@@ -260,7 +260,7 @@ def train(args):
     lr = args.lr
     #optimizer = AdamW(model.parameters(), lr=lr)
-    optimizer = SophiaG(model.parameters(), lr=lr, betas=(0.965, 0.99), rho=0.01, weight_decay=0.01)
+    optimizer = SophiaG(model.parameters(), lr=lr, betas=(0.965, 0.99), rho=0.03, weight_decay=0.0)
     k = 10
     iter_num = 0
@@ -285,17 +285,21 @@ def train(args):
             loss.backward()
             optimizer.step(bs = args.batch_size)
+            optimizer.zero_grad(set_to_none=True)
             train_loss += loss.item()
             num_batches += 1
-            if iter_num % k != k - 1:
+            if iter_num % k == k - 1:
+                # Update the Hessian EMA
                 logits = model(b_ids, torch.ones_like(b_mask))
                 samp_dist = torch.distributions.Categorical(logits=logits)
                 y_sample = samp_dist.sample()
                 loss_sampled = F.cross_entropy(logits, y_sample.view(-1), reduction='sum') / args.batch_size
                 loss_sampled.backward()
                 optimizer.update_hessian()
+                optimizer.zero_grad(set_to_none=True)
+
             iter_num += 1

diff --git a/optimizer.py b/optimizer.py
index e2f9e6e..3180294 100644
--- a/optimizer.py
+++ b/optimizer.py
@@ -149,7 +149,7 @@ def __init__(
     @torch.no_grad()
     def update_hessian(self):
         for group in self.param_groups:
-            beta1, beta2 = group["betas"]
+            _, beta2 = group["betas"]
             for p in group["params"]:
                 if p.grad is None:
                     continue
@@ -186,7 +186,7 @@ def step(self, closure: Callable = None, bs: int = 5120):
                    state["hessian"] = torch.zeros_like(p)

                # Access hyperparameters from the `group` dictionary
-                beta1, beta2 = group["betas"]
+                beta1, _ = group["betas"]
                rho = group["rho"]
                exp_avg = state["exp_avg"]
                lr = group["lr"]
@@ -202,12 +202,10 @@ def step(self, closure: Callable = None, bs: int = 5120):
                state["exp_avg"] = beta1 * exp_avg + (1 - beta1) * grad

                # 3 - Decay the hessian running average coefficient
-                step_size_neg = -lr
                ratio = (exp_avg.abs() / (rho * bs * state["hessian"] + 1e-15)).clamp( None, 1 )
-                p.data = p.data + exp_avg.sign() * ratio * step_size_neg
+                p.data = p.data + exp_avg.sign() * ratio * -lr

        return loss
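
For reference, here is a minimal, self-contained sketch of the two pieces this diff exercises: the per-iteration clipped sign-momentum step and the every-k-iterations Hessian EMA refresh computed from labels sampled from the model's own logits. The helper names, default hyperparameters, and in-place tensor updates below are illustrative assumptions, not the repository's SophiaG implementation.

# Minimal sketch of the Sophia-G update applied above -- assumed helper
# functions, not the repository's SophiaG class. Defaults are illustrative.
import torch


def sophia_g_param_update(p, grad, exp_avg, hessian, lr, beta1=0.965, rho=0.03, bs=32, eps=1e-15):
    """One per-parameter step: momentum EMA, then a clipped sign update."""
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)                    # m_t = b1*m_{t-1} + (1-b1)*g_t
    ratio = (exp_avg.abs() / (rho * bs * hessian + eps)).clamp(max=1)  # clip |m_t| / (rho*bs*h_t) at 1
    p.add_(exp_avg.sign() * ratio, alpha=-lr)                          # p -= lr * sign(m_t) * ratio
    return p


def sophia_g_hessian_ema(hessian, sampled_grad, beta2=0.99):
    """Gauss-Newton-Bartlett estimate: EMA of squared gradients of the sampled-label loss."""
    hessian.mul_(beta2).addcmul_(sampled_grad, sampled_grad, value=1 - beta2)
    return hessian


# Toy usage with a single parameter tensor:
p = torch.zeros(3)
m, h = torch.zeros_like(p), torch.zeros_like(p)
g = torch.randn(3)
sophia_g_hessian_ema(h, g)
sophia_g_param_update(p, g, m, h, lr=2e-5)

In the training loop above, the Hessian refresh only runs on iterations where iter_num % k == k - 1, using the gradient of loss_sampled (labels drawn from Categorical(logits=logits)). This is why the added zero_grad calls matter: without clearing gradients after optimizer.step, the sampled-loss gradient would accumulate on top of the training-loss gradient from the same iteration, and without clearing after update_hessian, the next iteration's training gradient would in turn be contaminated.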