Merge pull request #32 from dxqbYD/beta0
No EMA buffer at beta1==0
konstmish authored Dec 18, 2024
2 parents 34fe93e + a416d0e commit bb88fb8
Showing 1 changed file with 13 additions and 7 deletions.
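
For context, this change matters when the optimizer is constructed with the first beta set to zero. A minimal usage sketch, assuming the prodigyopt package as documented in its README (the snippet below is illustrative and not part of this commit):

import torch
from prodigyopt import Prodigy

model = torch.nn.Linear(10, 10)

# betas=(0.0, 0.999) disables the first-moment EMA; after this commit the
# optimizer then skips allocating the per-parameter exp_avg buffer entirely.
opt = Prodigy(model.parameters(), lr=1.0, betas=(0.0, 0.999))
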
prodigyopt/prodigy.py: 13 additions & 7 deletions
@@ -177,12 +177,13 @@ def step(self, closure=None):
                         state['p0'] = torch.tensor(0, device=p.device, dtype=p.dtype)

                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data).detach()
+                    if beta1 > 0:
+                        state['exp_avg'] = torch.zeros_like(p.data).detach()
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data).detach()

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg_sq = state['exp_avg_sq']

                 s = state['s']
                 p0 = state['p0']
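
The hunk above only allocates exp_avg when beta1 > 0. Since exp_avg is created with torch.zeros_like(p.data), skipping it saves one full copy of every parameter. A rough way to estimate the savings (an illustrative helper, not part of the commit):

def exp_avg_bytes(model):
    # exp_avg matches each parameter's shape and dtype, so the memory saved
    # is simply the total byte size of all parameters.
    return sum(p.numel() * p.element_size() for p in model.parameters())
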

@@ -192,7 +193,9 @@ def step(self, closure=None):
                     d_numerator += (d / d0) * dlr * torch.dot(sliced_grad, p0.data - p.data.flatten()[::slice_p]).item()

                     # Adam EMA updates
-                    exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1))
+                    if beta1 > 0:
+                        exp_avg = state['exp_avg']
+                        exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1))
                     exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1-beta2))

                     if safeguard_warmup:
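
In Adam-style notation, the two EMA updates above compute (with Prodigy's distance estimate d folded into the gradient scale):

    exp_avg    <- beta1 * exp_avg    + d * (1 - beta1) * grad
    exp_avg_sq <- beta2 * exp_avg_sq + d^2 * (1 - beta2) * grad^2

With beta1 == 0 the first buffer carries no history (it would simply equal d * grad on every step), which is why it can be dropped and only exp_avg_sq needs to be stored.
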
@@ -246,7 +249,7 @@ def step(self, closure=None):

                 state = self.state[p]

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                exp_avg_sq = state['exp_avg_sq']

                 state['step'] += 1

@@ -256,9 +259,12 @@ def step(self, closure=None):
                 if decay != 0 and decouple:
                     p.data.add_(p.data, alpha=-decay * dlr)

-
                 ### Take step
-                p.data.addcdiv_(exp_avg, denom, value=-dlr)
+                if beta1 > 0:
+                    exp_avg = state['exp_avg']
+                    p.data.addcdiv_(exp_avg, denom, value=-dlr)
+                else:
+                    p.data.addcdiv_(grad, denom, value=-dlr * d)

             group['k'] = k + 1
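
The final hunk makes the parameter update explicit about the two cases. When beta1 == 0, the old exp_avg would have held exactly d * grad, so stepping with grad and a step size of dlr * d is mathematically identical. A small numerical check of that equivalence (illustrative only, not from the repository):

import torch

torch.manual_seed(0)
d, dlr = 0.05, 0.01
grad = torch.randn(4)
denom = torch.rand(4) + 1e-8
p = torch.randn(4)

# Old path with beta1 == 0: exp_avg collapses to d * grad.
exp_avg = torch.zeros(4).mul_(0.0).add_(grad, alpha=d * (1 - 0.0))
p_old_path = p.clone().addcdiv_(exp_avg, denom, value=-dlr)

# New path: use grad directly and fold d into the step size.
p_new_path = p.clone().addcdiv_(grad, denom, value=-dlr * d)

assert torch.allclose(p_old_path, p_new_path)
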
