diff --git a/optimizer.py b/optimizer.py
index 01c780d..1fedc39 100644
--- a/optimizer.py
+++ b/optimizer.py
@@ -1,5 +1,6 @@
 from calendar import c
 from hmac import new
+from turtle import st
 from typing import Callable, Iterable, Tuple
 import math
@@ -159,23 +160,23 @@ def update_hessian(self, bs: int):
                 state = self.state[p]

                 # B · ĝ ⊙ ĝ
-                new_hess = bs * torch.square(p.grad)
-                # Update the hessian estimate (moving average)
-                state["hessian"].mul_(beta2).add_(new_hess, alpha=1 - beta2)
+                state["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=bs - bs * beta2)
+
     @torch.no_grad()
     def step(self, closure: Callable = None):
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

         for group in self.param_groups:
             for p in group["params"]:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError("Sophia does not support sparse gradients")
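The fused `addcmul_` update above is algebraically identical to the two-step version it replaces, since `bs - bs * beta2 == bs * (1 - beta2)`; it just avoids materializing the intermediate `new_hess` tensor. A minimal sketch checking the equivalence (the `beta2`, `bs`, and tensor shapes below are made-up stand-ins, not values taken from this repo):

```python
import torch

torch.manual_seed(0)
beta2, bs = 0.99, 480        # hypothetical EMA coefficient and batch size
grad = torch.randn(4)        # stand-in for p.grad
hess = torch.rand(4)         # stand-in for state["hessian"]

# Two-step form removed by the diff: new_hess = B * g^2, then the EMA update.
new_hess = bs * torch.square(grad)
expected = hess * beta2 + (1 - beta2) * new_hess

# Fused form added by the diff: hess = beta2 * hess + bs * (1 - beta2) * g * g.
fused = hess.clone().mul_(beta2).addcmul_(grad, grad, value=bs - bs * beta2)

assert torch.allclose(expected, fused)
```

The `with torch.enable_grad():` change follows the same pattern as PyTorch's built-in optimizers: `step` runs under `@torch.no_grad()`, so a closure that re-evaluates the loss (and may recompute gradients) has to re-enable autograd explicitly.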