From 63ab03f13a1826b658fa81835c43ef38209a7f1e Mon Sep 17 00:00:00 2001
From: dxqbYD
Date: Mon, 16 Dec 2024 20:13:04 +0100
Subject: [PATCH 1/2] initial

---
 prodigyopt/prodigy.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/prodigyopt/prodigy.py b/prodigyopt/prodigy.py
index 04e7f24..c082dc3 100644
--- a/prodigyopt/prodigy.py
+++ b/prodigyopt/prodigy.py
@@ -176,12 +176,13 @@ def step(self, closure=None):
                         state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)
 
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data).detach()
+                    if beta1 > 0:
+                        state['exp_avg'] = torch.zeros_like(p.data).detach()
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data).detach()
 
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-
+                exp_avg_sq = state['exp_avg_sq']
+
                 s = state['s']
                 p0 = state['p0']
 
@@ -191,7 +192,9 @@ def step(self, closure=None):
                     d_numerator += (d / d0) * dlr * torch.dot(sliced_grad, p0.data - p.data.flatten()[::slice_p]).item()
 
                     # Adam EMA updates
-                    exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1))
+                    if beta1 > 0:
+                        exp_avg = state['exp_avg']
+                        exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1))
                     exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1-beta2))
 
                     if safeguard_warmup:
@@ -245,7 +248,7 @@ def step(self, closure=None):
 
                 state = self.state[p]
 
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg_sq = state['exp_avg_sq']
 
                 state['step'] += 1
 
@@ -255,9 +258,13 @@ def step(self, closure=None):
                 if decay != 0 and decouple:
                     p.data.add_(p.data, alpha=-decay * dlr)
 
                 ### Take step
-                p.data.addcdiv_(exp_avg, denom, value=-dlr)
-
+                if beta1 > 0:
+                    exp_avg = state['exp_avg']
+                    p.data.addcdiv_(exp_avg,denom, value=-dlr)
+                else:
+                    p.data.addcdiv_(grad,denom, value=-dlr * d)
+
             group['k'] = k + 1

From a416d0eb2d2845d58a08ac1d89331ccc1789a73e Mon Sep 17 00:00:00 2001
From: Konstantin Mishchenko
Date: Wed, 18 Dec 2024 13:55:34 +0100
Subject: [PATCH 2/2] fix formatting

---
 prodigyopt/prodigy.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/prodigyopt/prodigy.py b/prodigyopt/prodigy.py
index 9158098..3adcc93 100644
--- a/prodigyopt/prodigy.py
+++ b/prodigyopt/prodigy.py
@@ -262,10 +262,9 @@ def step(self, closure=None):
                 ### Take step
                 if beta1 > 0:
                     exp_avg = state['exp_avg']
-                    p.data.addcdiv_(exp_avg,denom, value=-dlr)
+                    p.data.addcdiv_(exp_avg, denom, value=-dlr)
                 else:
-                    p.data.addcdiv_(grad,denom, value=-dlr * d)
-
+                    p.data.addcdiv_(grad, denom, value=-dlr * d)
 
             group['k'] = k + 1
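
Illustration, not part of the patch series: a minimal sketch of how the new beta1 == 0 path would be exercised through the existing Prodigy constructor. With betas=(0.0, 0.999) the optimizer skips allocating the parameter-sized 'exp_avg' buffer and takes a second-moment-only step, p.data.addcdiv_(grad, denom, value=-dlr * d), where the extra factor of d stands in for the d that is otherwise folded into exp_avg by the EMA update. The toy model, tensor shapes, and loss below are illustrative assumptions, not taken from the patch.

    import torch
    from prodigyopt import Prodigy

    model = torch.nn.Linear(10, 1)

    # beta1 == 0: after this patch, no 'exp_avg' state tensor is created for any
    # parameter, saving one full parameter-sized buffer in optimizer state.
    opt = Prodigy(model.parameters(), lr=1.0, betas=(0.0, 0.999))

    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()       # takes the else branch: addcdiv_(grad, denom, value=-dlr * d)
    opt.zero_grad()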