Skip to content

Commit

Permalink
fix: moved hessian update
Browse files — browse the repository at this point in the history
  • Loading branch information
ItsNiklas committed Jun 19, 2023
1 parent 41436e2 commit bd3eef1
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions classifier.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -279,6 +279,16 @@ def train(args):
b_mask = b_mask.to(device)
b_labels = b_labels.to(device)

if iter_num % k == 1:
# Update the Hessian EMA
logits = model(b_ids, torch.ones_like(b_mask))
samp_dist = torch.distributions.Categorical(logits=logits)
y_sample = samp_dist.sample()
loss_sampled = F.cross_entropy(logits, y_sample.view(-1), reduction='sum') / args.batch_size
loss_sampled.backward()
optimizer.update_hessian()
optimizer.zero_grad(set_to_none=True)

optimizer.zero_grad()
logits = model(b_ids, b_mask)
loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size
Expand All @@ -290,17 +300,6 @@ def train(args):
train_loss += loss.item()
num_batches += 1

if iter_num % k == k - 1:
# Update the Hessian EMA
logits = model(b_ids, torch.ones_like(b_mask))
samp_dist = torch.distributions.Categorical(logits=logits)
y_sample = samp_dist.sample()
loss_sampled = F.cross_entropy(logits, y_sample.view(-1), reduction='sum') / args.batch_size
loss_sampled.backward()
optimizer.update_hessian()
optimizer.zero_grad(set_to_none=True)


iter_num += 1

train_loss = train_loss / (num_batches)
Expand Down

0 comments on commit bd3eef1

Please sign in to comment.