Skip to content

Commit

Permalink
fix: small Sophia fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
ItsNiklas committed Jun 20, 2023
1 parent f3b21aa commit 8ce9763
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions optimizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from calendar import c
from hmac import new
from turtle import st
from typing import Callable, Iterable, Tuple
import math

Expand Down
Expand Up
@@ -159,23 +160,23 @@ def update_hessian(self, bs: int):
state = self.state[p]

# B · ^g ⊙ ^g
new_hess = bs * torch.square(p.grad)

# Update the hessian estimate (moving average)
state["hessian"].mul_(beta2).add_(new_hess, alpha=1 - beta2)
state["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=bs - bs * beta2)

@torch.no_grad()
def step(self, closure: Callable = None):
loss = None

if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue

grad = p.grad.data
grad = p.grad

if grad.is_sparse:
raise RuntimeError("Sophia does not support sparse gradients")
Expand Down

0 comments on commit 8ce9763

Please sign in to comment.