✨ (optimizer): Implemented SophiaH.
Not sure if it is actually optimizing yet
ItsNiklas committed Aug 20, 2023
1 parent 55caa1c commit 62bf578
Showing 2 changed files with 31 additions and 60 deletions.
77 changes: 23 additions & 54 deletions multitask_classifier.py
@@ -12,7 +12,7 @@
 
 from bert import BertModel
 from layers.AttentionLayer import AttentionLayer
-from optimizer import AdamW, SophiaG
+from optimizer import AdamW, SophiaH
 from tqdm import tqdm
 
 from datasets import (
@@ -278,8 +278,8 @@ def train_multitask(args):
 
     if args.optimizer == "adamw":
         optimizer = AdamW(model.parameters(), lr=lr)
-    elif args.optimizer == "sophiag":
-        optimizer = SophiaG(
+    elif args.optimizer == "sophiah":
+        optimizer = SophiaH(
             model.parameters(), lr=lr, eps=1e-12, rho=0.03, betas=(0.985, 0.99), weight_decay=2e-1
         )
     else:
@@ -329,25 +329,14 @@ def train_multitask(args):
 
             sts_loss.backward()
 
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            optimizer.step()
-            optimizer.zero_grad()
-
             # Check if we use the Sophia Optimizer
-            if args.optimizer == "sophiag" and num_batches % hess_interval == hess_interval - 1:
+            if args.optimizer == "sophiah" and num_batches % hess_interval == hess_interval - 1:
                 # Update the Hessian EMA
-                with ctx:
-                    logits = model.predict_similarity(b_ids_1, b_mask_1, b_ids_2, b_mask_2)
-                samp_dist = torch.distributions.Categorical(logits=logits)
-                y_sample = samp_dist.sample()
-                # add a dimension, now logits shape is [1, bs] and logits is [1] (Which is wrong TODO)
-                loss_sampled = F.cross_entropy(logits.unsqueeze(0), y_sample.view(-1))
-                loss_sampled.backward()
-
-                # Potentially: Clip gradients using
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-                optimizer.update_hessian(bs=args.batch_size)
-                optimizer.zero_grad(set_to_none=True)
+                optimizer.update_hessian()
+
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
 
             if args.scheduler == "cosine":
                 scheduler.step(epoch + num_batches / total_num_batches)
@@ -377,26 +366,15 @@ def train_multitask(args):
             para_loss = F.mse_loss(logits, b_labels.view(-1))
             para_loss.backward()
 
+            # Check if we use the Sophia Optimizer
+            if args.optimizer == "sophiah" and num_batches % hess_interval == hess_interval - 1:
+                # Update the Hessian EMA
+                optimizer.update_hessian()
+
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
 
-            optimizer.zero_grad()
-            # Check if we use the Sophia Optimizer
-            if args.optimizer == "sophiag" and num_batches % hess_interval == hess_interval - 1:
-                # Update the Hessian EMA
-                with ctx:
-                    logits = model.predict_paraphrase(b_ids_1, b_mask_1, b_ids_2, b_mask_2)
-                samp_dist = torch.distributions.Categorical(logits=logits)
-                y_sample = samp_dist.sample()
-                # add a dimension, now logits shape is [1, bs] and logits is [1] (Which is wrong TODO)
-                # TODO SophiaH
-                loss_sampled = F.cross_entropy(logits.unsqueeze(0), y_sample.view(-1))
-                loss_sampled.backward()
-
-                # Potentially: Clip gradients using
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-                optimizer.update_hessian(bs=args.batch_size)
-                optimizer.zero_grad(set_to_none=True)
 
             if args.scheduler == "cosine":
                 scheduler.step(epoch + num_batches / total_num_batches)
@@ -423,24 +401,15 @@ def train_multitask(args):
             sst_loss = F.cross_entropy(logits, b_labels.view(-1))
             sst_loss.backward()
 
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            optimizer.step()
-
-            optimizer.zero_grad()
             # Check if we use the Sophia Optimizer
-            if args.optimizer == "sophiag" and num_batches % hess_interval == hess_interval - 1:
+            # This is the only task potentially compatible with SophiaG
+            if args.optimizer == "sophiah" and num_batches % hess_interval == hess_interval - 1:
                 # Update the Hessian EMA
-                with ctx:
-                    logits = model.predict_sentiment(b_ids, b_mask)
-                samp_dist = torch.distributions.Categorical(logits=logits)
-                y_sample = samp_dist.sample()
-                loss_sampled = F.cross_entropy(logits, y_sample.view(-1))
-                loss_sampled.backward()
-
-                # Potentially: Clip gradients using
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-                optimizer.update_hessian(bs=args.batch_size)
-                optimizer.zero_grad(set_to_none=True)
+                optimizer.update_hessian()
+
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
 
             if args.scheduler == "cosine":
                 scheduler.step(epoch + num_batches / total_num_batches)
@@ -560,7 +529,7 @@ def get_args():
         "--lr",
         type=float,
         help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
-        ddefault=1e-5 if args.option == "finetune" else 1e-3,
+        default=1e-5 if args.option == "finetune" else 1e-3,
     )
     parser.add_argument("--checkpoint", type=str, default=None)
     parser.add_argument("--local_files_only", action="store_true")
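
After this change all three task loops use the same per-batch pattern: backward on the task loss, an occasional Hessian-EMA refresh, gradient clipping, optimizer step, zero_grad. A minimal self-contained sketch of that pattern on a toy model, assuming the SophiaH class from optimizer.py behaves like a standard torch optimizer; hess_interval = 10 is a placeholder, since the diff does not show its actual value.

import torch
import torch.nn.functional as F

from optimizer import SophiaH

# Toy stand-ins for the multitask model and data
model = torch.nn.Linear(10, 3)
optimizer = SophiaH(
    model.parameters(), lr=1e-5, eps=1e-12, rho=0.03, betas=(0.985, 0.99), weight_decay=2e-1
)
hess_interval = 10  # placeholder; the real value is not shown in the diff

for num_batches in range(100):
    x, y = torch.randn(8, 10), torch.randint(0, 3, (8,))
    loss = F.cross_entropy(model(x), y)
    loss.backward()

    # Every hess_interval-th batch, refresh the Hessian EMA from the current gradients
    if num_batches % hess_interval == hess_interval - 1:
        optimizer.update_hessian()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
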
14 changes: 8 additions & 6 deletions optimizer.py
@@ -262,9 +262,9 @@ def __init__(
             weight_decay=weight_decay,
             eps=eps,
         )
-        super(SophiaG, self).__init__(params, defaults)
+        super(SophiaH, self).__init__(params, defaults)
 
-
+    @torch.no_grad()
     def update_hessian(self):
         for group in self.param_groups:
             _, beta2 = group["betas"]
@@ -274,14 +274,16 @@ def update_hessian(self):
                     continue
                 state = self.state[p]
 
+                gradient = p.grad.clone().detach().requires_grad_(True)
+
                 # draw u from N(0, I)
-                u = torch.randn_like(p.grad)
+                u = torch.randn_like(gradient)
 
                 # Compute < grad, u >
-                gu = torch.sum(p.grad * u)
+                gu = torch.matmul(gradient.view(-1), u.view(-1))
 
                 # Differentiate < grad, u > wrt to the parameters
-                hvp = torch.autograd.grad(gu, p, retain_graph=True)
+                hvp = torch.autograd.grad(gu, gradient, retain_graph=True)[0]
 
                 # u ⊙ hvp
                 state["hessian"].mul_(beta2).addcmul_(u, hvp, value=1 - beta2)
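
update_hessian is built around Hutchinson's estimator: for u drawn from N(0, I), u ⊙ (H u) is an unbiased estimate of diag(H). A standalone sketch of the textbook form, in which < grad, u > is differentiated with respect to the parameters themselves; this needs the gradient graph (create_graph=True), and w and the toy loss below are placeholders.

import torch

w = torch.randn(5, requires_grad=True)
loss = (w ** 4).sum()  # toy objective with a non-constant Hessian

grad = torch.autograd.grad(loss, w, create_graph=True)[0]  # keep the graph for a second derivative
u = torch.randn_like(w)                                    # draw u from N(0, I)
gu = torch.dot(grad, u)                                    # < grad, u >
hvp = torch.autograd.grad(gu, w)[0]                        # H u: differentiate < grad, u > wrt w
diag_estimate = u * hvp                                    # u ⊙ (H u), estimates diag(H)

The committed update_hessian differentiates < grad, u > with respect to a detached clone of p.grad instead, so the training loss does not have to be backpropagated with create_graph=True; the EMA update state["hessian"].mul_(beta2).addcmul_(u, hvp, value=1 - beta2) applies the same u ⊙ hvp product in either case.
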
