
Commit

enh: Sophia minor refactoring
ItsNiklas committed Jun 21, 2023
1 parent de72912 commit 1087a67
Showing 2 changed files with 13 additions and 8 deletions.
classifier.py: 2 changes, 1 addition & 1 deletion
@@ -273,7 +273,7 @@ def train(args):
 
     lr = args.lr
     # optimizer = AdamW(model.parameters(), lr=lr)
-    optimizer = SophiaG(model.parameters(), lr=lr, betas=(0.965, 0.99), rho=0.03, weight_decay=0.0)
+    optimizer = SophiaG(model.parameters(), lr=lr, eps=1e-12, rho=0.03, weight_decay=0.0)
     hess_interval = 10
     iter_num = 0
 
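For orientation, a minimal sketch of how the refactored SophiaG might be driven from the training loop in classifier.py. Only the constructor arguments above come from this commit; the model, data, and the Hessian-refresh step are placeholder assumptions (the Hessian update logic is not part of this diff).

import torch
from optimizer import SophiaG

# toy stand-ins for the real model and data
model = torch.nn.Linear(8, 2)
data = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(20)]

optimizer = SophiaG(model.parameters(), lr=1e-3, eps=1e-12, rho=0.03, weight_decay=0.0)
hess_interval = 10

for iter_num, (x, y) in enumerate(data):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    if iter_num % hess_interval == hess_interval - 1:
        pass  # here the loop would refresh the Hessian EMA (not shown in this diff)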
optimizer.py: 19 changes, 12 additions & 7 deletions
@@ -124,6 +124,7 @@ def __init__(
         betas: Tuple[float, float] = (0.965, 0.99),
         rho: float = 0.04,
         weight_decay: float = 0.1,
+        eps: float = 1e-15,
     ):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
@@ -141,11 +142,14 @@ def __init__(
             raise ValueError(
                 "Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay)
             )
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
         defaults = dict(
             lr=lr,
             betas=betas,
             rho=rho,
             weight_decay=weight_decay,
+            eps=eps,
         )
         super(SophiaG, self).__init__(params, defaults)
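A quick sanity check of the new hyperparameter handling, assuming optimizer.py is importable as in this repository: eps picks up its 1e-15 default, and the guard added in this commit rejects negative values.

import torch
from optimizer import SophiaG

params = [torch.nn.Parameter(torch.zeros(3))]

# uses the defaults shown above: betas=(0.965, 0.99), rho=0.04, weight_decay=0.1, eps=1e-15
opt = SophiaG(params, lr=2e-4)
print(opt.defaults["eps"])  # 1e-15

# the validation added in this commit
try:
    SophiaG(params, lr=2e-4, eps=-1e-8)
except ValueError as err:
    print(err)  # Invalid epsilon value: -1e-08 - should be >= 0.0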

@@ -173,28 +177,29 @@ def step(self, closure: Callable = None):
 
         for group in self.param_groups:
             for p in group["params"]:
+                if p.grad is None:
+                    continue
+
                 grad = p.grad
 
-                if grad is None:
-                    continue
-
                 if grad.is_sparse:
                     raise RuntimeError("Sophia does not support sparse gradients")
 
                 # State should be stored in this dictionary
                 state = self.state[p]
                 device = p.device
 
                 # Init state variables
                 if len(state) == 0:
-                    state["step"] = torch.zeros((1,), dtype=torch.float, device=device)
+                    state["step"] = torch.zeros((1,), dtype=torch.float, device=p.device)
                     state["exp_avg"] = torch.zeros_like(p)
                     state["hessian"] = torch.zeros_like(p)
 
                 # Access hyperparameters from the `group` dictionary
                 beta1, _ = group["betas"]
                 rho = group["rho"]
                 lr = group["lr"]
+                eps = group["eps"]
                 weight_decay = group["weight_decay"]
                 exp_avg = state["exp_avg"]
                 hess = state["hessian"]
@@ -206,11 +211,11 @@
                 p.data.mul_(1 - lr * weight_decay)
 
                 # 2 - Decay the first and second moment running average coefficient
-                state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
 
                 # 3 - Decay the hessian running average coefficient
                 # Clipping the hessian.
-                ratio = (state["exp_avg"] / (rho * hess + 1e-15)).clamp(-1, 1)
-                p.data.add_(- lr * ratio)
+                ratio = (exp_avg / (rho * hess + eps)).clamp(-1, 1)
+                p.data.add_(-lr * ratio)
 
         return loss
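To make the role of eps concrete, here is a standalone sketch that replays the update performed in step() on toy tensors. The variable names mirror the code above and the hyperparameter values are the defaults; the tensors themselves are arbitrary. With hess still zero, eps keeps the division finite and the clamp bounds every coordinate's step to at most lr.

import torch

lr, beta1, rho, weight_decay, eps = 2e-4, 0.965, 0.04, 0.1, 1e-15

p = torch.tensor([1.0, -2.0, 0.5])
grad = torch.tensor([0.2, -0.1, 0.0])
exp_avg = torch.zeros_like(p)   # state["exp_avg"]
hess = torch.zeros_like(p)      # state["hessian"], still zero before any Hessian update

# 1 - decoupled weight decay
p.mul_(1 - lr * weight_decay)

# 2 - EMA of the gradient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

# 3 - clipped, Hessian-preconditioned step
ratio = (exp_avg / (rho * hess + eps)).clamp(-1, 1)
p.add_(-lr * ratio)

print(p)  # each coordinate moves by at most lr because of the clamp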
