feat: Sophia step() alpha version
ItsNiklas committed Jun 16, 2023
1 parent 17bbbc7 commit 492f9f3
Showing 2 changed files with 83 additions and 10 deletions.
20 changes: 17 additions & 3 deletions classifier.py
@@ -11,7 +11,7 @@
# change it with respect to the original model
from tokenizer import BertTokenizer
from bert import BertModel
from optimizer import AdamW
from optimizer import AdamW, SophiaG
from tqdm import tqdm

TQDM_DISABLE = False
@@ -259,7 +259,11 @@ def train(args):
model = model.to(device)

lr = args.lr
optimizer = AdamW(model.parameters(), lr=lr)
#optimizer = AdamW(model.parameters(), lr=lr)
optimizer = SophiaG(model.parameters(), lr=lr, betas=(0.965, 0.99), rho=0.01, weight_decay=0.01)
k = 10
iter_num = 0

best_dev_acc = 0

# Run for the specified number of epochs
@@ -280,11 +284,21 @@ def train(args):
loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size

loss.backward()
optimizer.step()
optimizer.step(bs = args.batch_size)

train_loss += loss.item()
num_batches += 1

if iter_num % k != k - 1:
logits = model(b_ids, torch.ones_like(b_mask))
samp_dist = torch.distributions.Categorical(logits=logits)
y_sample = samp_dist.sample()
loss_sampled = F.cross_entropy(logits, y_sample.view(-1), reduction='sum') / args.batch_size
loss_sampled.backward()
optimizer.update_hessian()

iter_num += 1

train_loss = train_loss / (num_batches)

train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
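The training-loop change above pairs a regular SophiaG parameter update with a periodic Hessian refresh: labels are resampled from the model's own categorical distribution, the loss on those sampled labels is backpropagated, and update_hessian() folds the squared gradients into the curvature estimate (the Gauss-Newton-Bartlett estimator). Below is a minimal, self-contained sketch of that call pattern on a toy linear classifier; the toy model, random data, learning rate, and the zero_grad() placement are illustrative stand-ins, not the repository's actual training code.

import torch
import torch.nn.functional as F
from optimizer import SophiaG

model = torch.nn.Linear(16, 4)  # toy stand-in for the BERT classifier
optimizer = SophiaG(model.parameters(), lr=1e-3,
                    betas=(0.965, 0.99), rho=0.01, weight_decay=0.01)
k, batch_size = 10, 8

for iter_num in range(100):
    x = torch.randn(batch_size, 16)
    y = torch.randint(0, 4, (batch_size,))

    # Regular SophiaG update on the true labels.
    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits, y, reduction='sum') / batch_size
    loss.backward()
    optimizer.step(bs=batch_size)

    # Hessian refresh: sample labels from the model's own predictive
    # distribution and backpropagate the loss on them, so update_hessian()
    # accumulates squared "sampled" gradients. As in the diff above, this
    # runs on every iteration except each k-th one.
    if iter_num % k != k - 1:
        optimizer.zero_grad()
        logits = model(x)
        y_sample = torch.distributions.Categorical(logits=logits).sample()
        loss_sampled = F.cross_entropy(logits, y_sample, reduction='sum') / batch_size
        loss_sampled.backward()
        optimizer.update_hessian()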
73 changes: 66 additions & 7 deletions optimizer.py
@@ -1,3 +1,4 @@
from calendar import c
from typing import Callable, Iterable, Tuple
import math

@@ -56,8 +57,7 @@ def step(self, closure: Callable = None):

# State should be stored in this dictionary
state = self.state[p]

device = grad.device
device = p.device

# Access hyperparameters from the `group` dictionary
alpha = group["lr"]
@@ -67,7 +67,6 @@
correct_bias = group["correct_bias"]

# Init state variables

if "t" not in state:
state["t"] = torch.tensor([0]).to(device)

@@ -91,7 +90,6 @@
# (they are lr, betas, eps, weight_decay, as saved in the constructor).

# 1- Update first and second moments of the gradients

state["m"] = beta_1 * state["m"] + (1 - beta_1) * grad
state["v"] = beta_2 * state["v"] + (1 - beta_2) * torch.square(grad)

@@ -106,12 +104,10 @@
)

# 3- Update parameters (p.data).

p.data = p.data - alpha * state["m"] / (torch.sqrt(state["v"]) + eps)

# 4- After that main gradient-based update, update again using weight decay
# (incorporating the learning rate again).

p.data = p.data - group["lr"] * p.data * weight_decay

return loss
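For reference, the AdamW step traced above follows the standard decoupled-weight-decay recipe; the bias-correction lines (step 2) are collapsed in this view. Here is a standalone sketch of one such update on a single tensor, using the usual Adam formulas and illustrative hyperparameter values, not necessarily the exact hidden lines of this file.

import torch

lr, beta_1, beta_2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-6, 0.01

p = torch.randn(5)      # parameter
grad = torch.randn(5)   # its gradient
m = torch.zeros(5)      # first-moment EMA
v = torch.zeros(5)      # second-moment EMA
t = 0

# One AdamW update.
t += 1
m = beta_1 * m + (1 - beta_1) * grad                       # 1- moment updates
v = beta_2 * v + (1 - beta_2) * grad ** 2
alpha = lr * (1 - beta_2 ** t) ** 0.5 / (1 - beta_1 ** t)  # 2- bias-corrected step size
p = p - alpha * m / (v.sqrt() + eps)                       # 3- parameter update
p = p - lr * weight_decay * p                              # 4- decoupled weight decay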
@@ -150,5 +146,68 @@ def __init__(
)
super(SophiaG, self).__init__(params, defaults)

@torch.no_grad()
def update_hessian(self):
for group in self.param_groups:
beta1, beta2 = group["betas"]
for p in group["params"]:
if p.grad is None:
continue
state = self.state[p]

state["hessian"] = beta2 * state["hessian"] + (
1 - beta2
) * torch.square(p.grad.data)

def step(self, closure: Callable = None, bs: int = 5120):
loss = None

if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue

grad = p.grad.data

if grad.is_sparse:
raise RuntimeError("Sophia does not support sparse gradients")

# State should be stored in this dictionary
state = self.state[p]
device = p.device

# Init state variables
if len(state) == 0:
state["step"] = torch.zeros((1,), dtype=torch.float, device=device)
state["exp_avg"] = torch.zeros_like(p)
state["hessian"] = torch.zeros_like(p)

# Access hyperparameters from the `group` dictionary
beta1, beta2 = group["betas"]
rho = group["rho"]
exp_avg = state["exp_avg"]
lr = group["lr"]
weight_decay = group["weight_decay"]

# Calculation of new weights
state["step"] += 1

# 1 - Perform stepweight decay
p.data = p.data - p.data * lr * weight_decay

# 2 - Decay the first and second moment running average coefficient
state["exp_avg"] = beta1 * exp_avg + (1 - beta1) * grad

# 3 - Decay the hessian running average coefficient
step_size_neg = -lr

ratio = (exp_avg.abs() / (rho * bs * state["hessian"] + 1e-15)).clamp(
None, 1
)

p.data = p.data + exp_avg.sign() * ratio * step_size_neg

return loss
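The heart of step() above is the element-wise clipped ratio between the gradient EMA (exp_avg) and the scaled Hessian EMA: where the curvature estimate is small relative to the momentum, the ratio saturates at 1 and the update degrades to a plain sign-descent step of size lr. A standalone numerical sketch of that update on a single tensor, with illustrative values:

import torch

lr, rho, bs, eps = 1e-3, 0.01, 8, 1e-15

p = torch.randn(5)                # parameter
exp_avg = 0.01 * torch.randn(5)   # EMA of gradients ("exp_avg" in step())
hessian = 1e-3 * torch.rand(5)    # EMA of squared sampled gradients ("hessian")

# Element-wise ratio, clipped from above at 1.
ratio = (exp_avg.abs() / (rho * bs * hessian + eps)).clamp(None, 1)

# Step against the sign of the momentum, scaled by the clipped ratio.
p = p - lr * exp_avg.sign() * ratio
print(ratio)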
