
Commit

fix: Sophia training loop
ItsNiklas committed Jun 19, 2023
1 parent 492f9f3 commit 41436e2
Showing 2 changed files with 9 additions and 7 deletions.
8 changes: 6 additions & 2 deletions classifier.py
@@ -260,7 +260,7 @@ def train(args):

    lr = args.lr
    #optimizer = AdamW(model.parameters(), lr=lr)
    optimizer = SophiaG(model.parameters(), lr=lr, betas=(0.965, 0.99), rho=0.01, weight_decay=0.01)
    optimizer = SophiaG(model.parameters(), lr=lr, betas=(0.965, 0.99), rho=0.03, weight_decay=0.0)
    k = 10
    iter_num = 0

@@ -285,17 +285,21 @@ def train(args):

            loss.backward()
            optimizer.step(bs = args.batch_size)
            optimizer.zero_grad(set_to_none=True)

            train_loss += loss.item()
            num_batches += 1

            if iter_num % k != k - 1:
            if iter_num % k == k - 1:
                # Update the Hessian EMA
                logits = model(b_ids, torch.ones_like(b_mask))
                samp_dist = torch.distributions.Categorical(logits=logits)
                y_sample = samp_dist.sample()
                loss_sampled = F.cross_entropy(logits, y_sample.view(-1), reduction='sum') / args.batch_size
                loss_sampled.backward()
                optimizer.update_hessian()
                optimizer.zero_grad(set_to_none=True)


            iter_num += 1

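The crux of this hunk is the scheduling condition: with != the sampled-loss Hessian pass ran on every iteration except the k-th, i.e. 9 times out of 10, while == triggers it exactly once every k steps as Sophia intends. A minimal sketch of the intended schedule, with purely illustrative names:

k = 10
for iter_num in range(30):
    # The Hessian-EMA refresh fires at iterations 9, 19, 29, ...,
    # i.e. once per k optimizer steps rather than on 9 of every 10.
    if iter_num % k == k - 1:
        print(f"refresh Hessian EMA at iteration {iter_num}")
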
8 changes: 3 additions & 5 deletions optimizer.py
@@ -149,7 +149,7 @@ def __init__(
    @torch.no_grad()
    def update_hessian(self):
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            _, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
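The body of update_hessian is truncated in this view; in Sophia-G it is assumed to maintain, per parameter, an exponential moving average of the squared gradient of the sampled loss as a Gauss-Newton-Bartlett style Hessian estimate. A stand-alone sketch under that assumption (the function name and arguments here are illustrative, not part of the repository):

import torch

@torch.no_grad()
def update_hessian_sketch(param_groups, state):
    # EMA of the elementwise squared gradient, weighted by beta2.
    # p.grad is expected to hold the gradient of the sampled loss
    # computed in the training loop in classifier.py.
    for group in param_groups:
        _, beta2 = group["betas"]
        for p in group["params"]:
            if p.grad is None:
                continue
            st = state.setdefault(p, {})
            if "hessian" not in st:
                st["hessian"] = torch.zeros_like(p)
            st["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
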
@@ -186,7 +186,7 @@ def step(self, closure: Callable = None, bs: int = 5120):
state["hessian"] = torch.zeros_like(p)

# Access hyperparameters from the `group` dictionary
beta1, beta2 = group["betas"]
beta1, _ = group["betas"]
rho = group["rho"]
exp_avg = state["exp_avg"]
lr = group["lr"]
@@ -202,12 +202,10 @@ def step(self, closure: Callable = None, bs: int = 5120):
state["exp_avg"] = beta1 * exp_avg + (1 - beta1) * grad

# 3 - Decay the hessian running average coefficient
step_size_neg = -lr

ratio = (exp_avg.abs() / (rho * bs * state["hessian"] + 1e-15)).clamp(
None, 1
)

p.data = p.data + exp_avg.sign() * ratio * step_size_neg
p.data = p.data + exp_avg.sign() * ratio * -lr

return loss
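Taken together, the update that step applies per parameter is Sophia's clipped sign-of-momentum rule: the momentum exp_avg is decayed by beta1, and the step is sign(exp_avg) scaled elementwise by |exp_avg| / (rho * bs * hessian + eps), clamped to at most 1, times the learning rate. A stand-alone restatement of just that update (function name and defaults are illustrative; eps matches the 1e-15 constant above):

import torch

def sophia_update_sketch(p, grad, exp_avg, hessian, lr, rho, bs, beta1=0.965, eps=1e-15):
    # Momentum update followed by the clipped sign step used in step() above.
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    ratio = (exp_avg.abs() / (rho * bs * hessian + eps)).clamp(None, 1)
    p = p - lr * exp_avg.sign() * ratio
    return p, exp_avg

Where rho * bs * hessian is large relative to |exp_avg|, the ratio shrinks the step toward zero; where the curvature estimate is small, the update saturates at a plain sign-momentum step of size lr.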
