diff --git a/classifier.py b/classifier.py
index 8bcaaa0..7b52ad7 100644
--- a/classifier.py
+++ b/classifier.py
@@ -273,7 +273,7 @@ def train(args):
     lr = args.lr
 
     # optimizer = AdamW(model.parameters(), lr=lr)
-    optimizer = SophiaG(model.parameters(), lr=lr, betas=(0.965, 0.99), rho=0.03, weight_decay=0.0)
+    optimizer = SophiaG(model.parameters(), lr=lr, eps=1e-12, rho=0.03, weight_decay=0.0)
     hess_interval = 10
     iter_num = 0
 
diff --git a/optimizer.py b/optimizer.py
index 4476c1c..a8ba9ae 100644
--- a/optimizer.py
+++ b/optimizer.py
@@ -124,6 +124,7 @@ def __init__(
         betas: Tuple[float, float] = (0.965, 0.99),
         rho: float = 0.04,
         weight_decay: float = 0.1,
+        eps: float = 1e-15,
     ):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
@@ -141,11 +142,14 @@ def __init__(
             raise ValueError(
                 "Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay)
             )
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
         defaults = dict(
             lr=lr,
             betas=betas,
             rho=rho,
             weight_decay=weight_decay,
+            eps=eps,
         )
 
         super(SophiaG, self).__init__(params, defaults)
@@ -173,21 +177,21 @@ def step(self, closure: Callable = None):
         loss = None
         if closure is not None:
             loss = closure()
 
         for group in self.param_groups:
             for p in group["params"]:
-                if p.grad is None:
-                    continue
                 grad = p.grad
+                if grad is None:
+                    continue
+
                 if grad.is_sparse:
                     raise RuntimeError("Sophia does not support sparse gradients")
 
                 # State should be stored in this dictionary
                 state = self.state[p]
-                device = p.device
 
                 # Init state variables
                 if len(state) == 0:
-                    state["step"] = torch.zeros((1,), dtype=torch.float, device=device)
+                    state["step"] = torch.zeros((1,), dtype=torch.float, device=p.device)
                     state["exp_avg"] = torch.zeros_like(p)
                     state["hessian"] = torch.zeros_like(p)
@@ -195,6 +199,7 @@ def step(self, closure: Callable = None):
                 beta1, _ = group["betas"]
                 rho = group["rho"]
                 lr = group["lr"]
+                eps = group["eps"]
                 weight_decay = group["weight_decay"]
                 exp_avg = state["exp_avg"]
                 hess = state["hessian"]
@@ -206,11 +211,11 @@ def step(self, closure: Callable = None):
                     p.data.mul_(1 - lr * weight_decay)
 
-                # 2 - Decay the first and second moment running average coefficient
-                state["exp_avg"].mul_(beta1).add_(grad, alpha=1 - beta1)
+                # 2 - Decay the first moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
 
-                # 3 - Decay the hessian running average coefficient
-                # Clipping the hessian.
-                ratio = (state["exp_avg"] / (rho * hess + 1e-15)).clamp(-1, 1)
-                p.data.add_(- lr * ratio)
+                # 3 - Update parameters with the ratio of the first moment to
+                # the scaled Hessian estimate, clipped to [-1, 1]
+                ratio = (exp_avg / (rho * hess + eps)).clamp(-1, 1)
+                p.data.add_(-lr * ratio)
 
         return loss
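
For reviewers, a minimal sketch of how the new `eps` knob is threaded through, mirroring the classifier.py change above. It assumes `SophiaG` is importable from optimizer.py as in this repo; the tiny linear model, the batch, and `lr=3e-4` are made-up stand-ins (the real `train(args)` loop pulls `lr` from `args.lr`).

```python
import torch
import torch.nn.functional as F

from optimizer import SophiaG  # the class modified in this diff

# Stand-in model and batch; the real training loop in classifier.py's
# train(args) is more involved. lr=3e-4 is a placeholder for args.lr.
model = torch.nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# Mirrors the classifier.py change: betas stay at their (0.965, 0.99)
# default, and eps is loosened from the 1e-15 default to 1e-12.
optimizer = SophiaG(model.parameters(), lr=3e-4, eps=1e-12, rho=0.03, weight_decay=0.0)

loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```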
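
And a self-contained numerical sketch of the update in the last hunk, showing that `eps` only matters where `rho * hess` is (near) zero. The tensors and hyperparameters below are made up for illustration, not taken from the repo's config.

```python
import torch

# Made-up values; not from the repo's config.
exp_avg = torch.tensor([0.02, -0.50, 0.013])  # EMA of gradients (m)
hess = torch.tensor([0.00, 0.10, 2.00])       # EMA of diagonal Hessian estimate (h)
lr, rho, eps = 1e-4, 0.03, 1e-12

# The update from the diff: eps guards the division where rho * h is zero
# (coordinate 0 here), and the clamp caps each coordinate's ratio at +/-1,
# so no coordinate ever moves by more than lr in a single step.
ratio = (exp_avg / (rho * hess + eps)).clamp(-1, 1)
update = -lr * ratio
print(update)  # tensor([-1.0000e-04,  1.0000e-04, -2.1667e-05])
```

Since `eps` now flows through `group["eps"]`, behavior is unchanged at the 1e-15 default, while callers like classifier.py can opt into a larger floor such as 1e-12.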