Commit

Add factored second moment (defaults to True)
* Adapted from this pull request: konstmish/prodigy#25
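
For context, a minimal standalone sketch of the Adafactor-style factorization this change adopts: for a 2D-or-higher parameter, keep an EMA of the row-wise and column-wise means of grad**2 instead of a full-sized second-moment tensor, and rebuild the denominator as a rank-1 outer product. The helper names below are illustrative and do not mirror the optimizer's internals (in particular, Prodigy's d scaling is omitted).

    import torch

    def update_factored_second_moment(row_ema, col_ema, grad, beta2=0.999, eps=1e-30):
        # Keep only per-row and per-column EMAs of grad**2 (O(n + m) memory instead of O(n * m)).
        grad_sq = grad.square().add_(eps)  # small constant keeps the factored stats strictly positive
        row_ema.mul_(beta2).add_(grad_sq.mean(dim=-1), alpha=1 - beta2)
        col_ema.mul_(beta2).add_(grad_sq.mean(dim=-2), alpha=1 - beta2)

    def factored_denominator(row_ema, col_ema):
        # Rank-1 reconstruction: dividing the row factor by its mean keeps the outer product on the
        # same overall scale as the full second moment; sqrt turns it into an Adam-style denominator.
        r = (row_ema / row_ema.mean(dim=-1, keepdim=True)).sqrt().unsqueeze(-1)
        c = col_ema.sqrt().unsqueeze(-2)
        return r * c

    # Toy usage on a 4x3 "parameter":
    grad = torch.randn(4, 3)
    row_ema, col_ema = torch.zeros(4), torch.zeros(3)
    update_factored_second_moment(row_ema, col_ema, grad)
    denom = factored_denominator(row_ema, col_ema)  # shape (4, 3), same as the gradient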
LoganBooker committed Nov 4, 2024
1 parent 0a52967 commit baebb53
Showing 1 changed file with 42 additions and 15 deletions.
57 changes: 42 additions & 15 deletions prodigyplus/prodigy_plus_schedulefree.py
@@ -96,6 +96,9 @@ class ProdigyPlusScheduleFree(torch.optim.Optimizer):
adam_atan2 (boolean):
Use atan2 rather than epsilon and division for parameter updates (https://arxiv.org/abs/2407.05872).
Not compatible with StableAdamW. (default True)
factored (boolean):
Use a factored approximation of the second moment, similar to Adafactor. Reduces memory usage.
(default True)
"""
def __init__(self, params, lr=1.0,
use_schedulefree=True,
@@ -109,7 +112,8 @@ def __init__(self, params, lr=1.0,
split_groups=True,
slice_p=10,
bf16_state=False,
adam_atan2=True):
adam_atan2=True,
factored=True):

if not 0.0 < d0:
raise ValueError("Invalid d0 value: {}".format(d0))
@@ -142,7 +146,8 @@ def __init__(self, params, lr=1.0,
use_bias_correction=use_bias_correction,
d_numerator=0.0,
bf16_state=bf16_state,
adam_atan2=adam_atan2)
adam_atan2=adam_atan2,
factored=factored)

self.d0 = d0
self.split_groups = split_groups
@@ -188,6 +193,11 @@ def supports_memory_efficient_fp16(self):
def supports_flat_params(self):
return True

def approx_sqrt(self, row, col):
r_factor = (row / row.mean(dim=-1, keepdim=True)).sqrt_().unsqueeze(-1)
c_factor = col.unsqueeze(-2).sqrt()
return torch.mul(r_factor, c_factor)

def get_sliced_tensor(self, tensor, slice_p):
# Downsample the tensor by using only a portion of parameters.
flat_tensor = tensor.ravel()
@@ -205,18 +215,24 @@ def get_sliced_tensor(self, tensor, slice_p):
# sliced_tensor = torch.as_strided(flattened_tensor, size=(numel,), stride=stride)
# return sliced_tensor

def initialise_state(self, p, state, slice_p, bf16_state):
if len(state) > 0:
def initialise_state(self, p, state, slice_p, bf16_state, factored):
if p.grad is None or len(state) != 0:
return

grad = p.grad.data
sliced_data = self.get_sliced_tensor(p.data, slice_p)

# z is exp_avg when schedule-free is disabled.
if self.use_schedulefree:
state['z'] = p.data.clone().detach()
else:
state['z'] = torch.zeros_like(p.data).detach()
state['exp_avg_sq'] = torch.zeros_like(p.data).detach()

if factored and grad.dim() > 1:
state["exp_avg_sq_row"] = grad.new_zeros(grad.shape[:-1]).detach()
state["exp_avg_sq_col"] = grad.new_zeros(grad.shape[:-2] + grad.shape[-1:]).detach()
else:
state['exp_avg_sq'] = torch.zeros_like(p.data).detach()

# If the initial weights are zero, don't bother storing them.
if p.data.count_nonzero() > 0:
@@ -278,6 +294,7 @@ def step(self, closure=None):
slice_p = group['slice_p']
bf16_state = group['bf16_state']
adam_atan2 = group['adam_atan2']
factored = group['factored']

if beta3 is None:
beta3 = beta2 ** 0.5
@@ -322,7 +339,13 @@ def step(self, closure=None):
# Adam EMA updates
if not self.use_schedulefree:
state['z'].mul_(beta1).add_(grad, alpha=d * (1 - beta1))
state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=d * d * (1 - beta2))

if factored and grad.dim() > 1:
grad_sq = grad.square().add_(1e-30)
state["exp_avg_sq_row"].mul_(beta2).add_(grad_sq.mean(dim=-1), alpha=d * d * (1 - beta2))
state["exp_avg_sq_col"].mul_(beta2).add_(grad_sq.mean(dim=-2), alpha=d * d * (1 - beta2))
else:
state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=d * d * (1 - beta2))

d_numerator_accum.add_(torch.dot(sliced_grad, state['p0'] - sliced_data), alpha=d_update)

@@ -353,14 +376,16 @@ def step(self, closure=None):
state = self.state[p]

exp_avg = state['z']
exp_avg_sq = state['exp_avg_sq']

if factored and len(grad.shape) >= 2:
denom = self.approx_sqrt(state["exp_avg_sq_row"], state["exp_avg_sq_col"])
else:
denom = state['exp_avg_sq'].sqrt()

if adam_atan2:
denom = exp_avg_sq.sqrt()
update = exp_avg.atan2(denom).mul(one_over_pi)
update = exp_avg.atan2(denom).mul_(one_over_pi)
else:
denom = exp_avg_sq.sqrt().add_(d * eps)
update = exp_avg.div(denom)
update = exp_avg.div(denom.add_(d * eps))

# StableAdamW
rms = grad.pow(2).div_(denom).mean().sqrt()
@@ -382,14 +407,16 @@ def step(self, closure=None):
state = self.state[p]

z = state['z']
exp_avg_sq = state['exp_avg_sq']

if factored and grad.dim() > 1:
denom = self.approx_sqrt(state["exp_avg_sq_row"], state["exp_avg_sq_col"])
else:
denom = state['exp_avg_sq'].sqrt()

if adam_atan2:
denom = exp_avg_sq.sqrt()
update = grad.atan2(denom).mul_(one_over_pi)
else:
denom = exp_avg_sq.sqrt().add_(d * eps)
update = grad.div(denom).mul_(d)
update = grad.div(denom.add_(d * eps)).mul_(d)

# StableAdamW.
rms = update.pow(2).mean().sqrt()
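For completeness, a usage sketch of the new option. The module path is inferred from the changed file above; the model and training loop are purely illustrative, and any schedule-free-specific train/eval handling the optimizer may require is omitted.

    import torch
    from prodigyplus.prodigy_plus_schedulefree import ProdigyPlusScheduleFree

    model = torch.nn.Linear(512, 512)

    # factored=True is the new default: 2D+ parameters store row/column EMAs of grad**2
    # rather than a full exp_avg_sq tensor. Pass factored=False to keep the previous behaviour.
    optimizer = ProdigyPlusScheduleFree(model.parameters(), lr=1.0, factored=True)

    for _ in range(10):
        optimizer.zero_grad()
        loss = model(torch.randn(32, 512)).pow(2).mean()
        loss.backward()
        optimizer.step()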
