From baebb53cdcb9307787630b6b92f0a66f5bc3f17b Mon Sep 17 00:00:00 2001
From: Logan
Date: Mon, 4 Nov 2024 14:36:44 +1100
Subject: [PATCH] Add factored second moment (defaults to True)

* Adapted from this pull request: https://github.com/konstmish/prodigy/pull/25
---
 prodigyplus/prodigy_plus_schedulefree.py | 57 +++++++++++++++++-------
 1 file changed, 42 insertions(+), 15 deletions(-)

diff --git a/prodigyplus/prodigy_plus_schedulefree.py b/prodigyplus/prodigy_plus_schedulefree.py
index 4c92874..fda767c 100644
--- a/prodigyplus/prodigy_plus_schedulefree.py
+++ b/prodigyplus/prodigy_plus_schedulefree.py
@@ -96,6 +96,9 @@ class ProdigyPlusScheduleFree(torch.optim.Optimizer):
         adam_atan2 (boolean):
             Use atan2 rather than epsilon and division for parameter updates (https://arxiv.org/abs/2407.05872).
             Not compatible with StableAdamW. (default True)
+        factored (boolean):
+            Use factored approximation of the second moment, similar to Adafactor. Reduces memory usage.
+            (default True)
     """
     def __init__(self, params, lr=1.0,
                  use_schedulefree=True,
@@ -109,7 +112,8 @@ def __init__(self, params, lr=1.0,
                  split_groups=True,
                  slice_p=10,
                  bf16_state=False,
-                 adam_atan2=True):
+                 adam_atan2=True,
+                 factored=True):
 
         if not 0.0 < d0:
             raise ValueError("Invalid d0 value: {}".format(d0))
@@ -142,7 +146,8 @@ def __init__(self, params, lr=1.0,
                         use_bias_correction=use_bias_correction,
                         d_numerator=0.0,
                         bf16_state=bf16_state,
-                        adam_atan2=adam_atan2)
+                        adam_atan2=adam_atan2,
+                        factored=factored)
 
         self.d0 = d0
         self.split_groups = split_groups
@@ -188,6 +193,11 @@ def supports_memory_efficient_fp16(self):
     def supports_flat_params(self):
         return True
 
+    def approx_sqrt(self, row, col):
+        r_factor = (row / row.mean(dim=-1, keepdim=True)).sqrt_().unsqueeze(-1)
+        c_factor = col.unsqueeze(-2).sqrt()
+        return torch.mul(r_factor, c_factor)
+
     def get_sliced_tensor(self, tensor, slice_p):
         # Downsample the tensor by using only a portion of parameters.
         flat_tensor = tensor.ravel()
@@ -205,10 +215,11 @@ def get_sliced_tensor(self, tensor, slice_p):
         # sliced_tensor = torch.as_strided(flattened_tensor, size=(numel,), stride=stride)
         # return sliced_tensor
 
-    def initialise_state(self, p, state, slice_p, bf16_state):
-        if len(state) > 0:
+    def initialise_state(self, p, state, slice_p, bf16_state, factored):
+        if p.grad is None or len(state) != 0:
             return
 
+        grad = p.grad.data
         sliced_data = self.get_sliced_tensor(p.data, slice_p)
 
         # z is exp_avg when schedule-free is disabled.
@@ -216,7 +227,12 @@ def initialise_state(self, p, state, slice_p, bf16_state):
             state['z'] = p.data.clone().detach()
         else:
             state['z'] = torch.zeros_like(p.data).detach()
-        state['exp_avg_sq'] = torch.zeros_like(p.data).detach()
+
+        if factored and grad.dim() > 1:
+            state["exp_avg_sq_row"] = grad.new_zeros(grad.shape[:-1]).detach()
+            state["exp_avg_sq_col"] = grad.new_zeros(grad.shape[:-2] + grad.shape[-1:]).detach()
+        else:
+            state['exp_avg_sq'] = torch.zeros_like(p.data).detach()
 
         # If the initial weights are zero, don't bother storing them.
         if p.data.count_nonzero() > 0:
@@ -278,6 +294,7 @@ def step(self, closure=None):
             slice_p = group['slice_p']
             bf16_state = group['bf16_state']
             adam_atan2 = group['adam_atan2']
+            factored = group['factored']
 
             if beta3 is None:
                 beta3 = beta2 ** 0.5
@@ -322,7 +339,13 @@ def step(self, closure=None):
                 # Adam EMA updates
                 if not self.use_schedulefree:
                     state['z'].mul_(beta1).add_(grad, alpha=d * (1 - beta1))
-                state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=d * d * (1 - beta2))
+
+                if factored and grad.dim() > 1:
+                    grad_sq = grad.square().add_(1e-30)
+                    state["exp_avg_sq_row"].mul_(beta2).add_(grad_sq.mean(dim=-1), alpha=d * d * (1 - beta2))
+                    state["exp_avg_sq_col"].mul_(beta2).add_(grad_sq.mean(dim=-2), alpha=d * d * (1 - beta2))
+                else:
+                    state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=d * d * (1 - beta2))
 
                 d_numerator_accum.add_(torch.dot(sliced_grad, state['p0'] - sliced_data), alpha=d_update)
 
@@ -353,14 +376,16 @@ def step(self, closure=None):
                 state = self.state[p]
 
                 exp_avg = state['z']
-                exp_avg_sq = state['exp_avg_sq']
+
+                if factored and len(grad.shape) >= 2:
+                    denom = self.approx_sqrt(state["exp_avg_sq_row"], state["exp_avg_sq_col"])
+                else:
+                    denom = state['exp_avg_sq'].sqrt()
 
                 if adam_atan2:
-                    denom = exp_avg_sq.sqrt()
-                    update = exp_avg.atan2(denom).mul(one_over_pi)
+                    update = exp_avg.atan2(denom).mul_(one_over_pi)
                 else:
-                    denom = exp_avg_sq.sqrt().add_(d * eps)
-                    update = exp_avg.div(denom)
+                    update = exp_avg.div(denom.add_(d * eps))
 
                 # StableAdamW
                 rms = grad.pow(2).div_(denom).mean().sqrt()
@@ -382,14 +407,16 @@ def step(self, closure=None):
                 state = self.state[p]
 
                 z = state['z']
-                exp_avg_sq = state['exp_avg_sq']
+
+                if factored and grad.dim() > 1:
+                    denom = self.approx_sqrt(state["exp_avg_sq_row"], state["exp_avg_sq_col"])
+                else:
+                    denom = state['exp_avg_sq'].sqrt()
 
                 if adam_atan2:
-                    denom = exp_avg_sq.sqrt()
                     update = grad.atan2(denom).mul_(one_over_pi)
                 else:
-                    denom = exp_avg_sq.sqrt().add_(d * eps)
-                    update = grad.div(denom).mul_(d)
+                    update = grad.div(denom.add_(d * eps)).mul_(d)
 
                 # StableAdamW.
                 rms = update.pow(2).mean().sqrt()
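
Note (not part of the patch): a minimal standalone PyTorch sketch of the Adafactor-style
factorization that the new approx_sqrt() helper relies on. The tensor shapes, variable
names, and the toy gradient below are illustrative only. Instead of the full exp_avg_sq
buffer, the factored path keeps a row mean and a column mean of the squared gradient and
rebuilds an approximate second-moment denominator from them, which is what happens when
factored=True and the gradient has two or more dimensions:

    import torch

    torch.manual_seed(0)
    grad = torch.randn(128, 256)            # gradient of a single 2D parameter
    grad_sq = grad.square().add_(1e-30)     # same tiny constant the patch adds

    # Factored statistics: O(m + n) memory instead of the O(m * n) full buffer.
    row = grad_sq.mean(dim=-1)              # shape (128,)
    col = grad_sq.mean(dim=-2)              # shape (256,)

    # Rebuild the denominator the same way approx_sqrt() does:
    # sqrt(row x col / mean(row)), formed as separate row and column factors.
    r_factor = (row / row.mean(dim=-1, keepdim=True)).sqrt().unsqueeze(-1)   # (128, 1)
    c_factor = col.sqrt().unsqueeze(-2)                                      # (1, 256)
    denom = r_factor * c_factor             # rank-1 approximation of grad_sq.sqrt()

    print(denom.shape)                                   # torch.Size([128, 256])
    print((denom - grad_sq.sqrt()).abs().mean().item())  # approximation error

Assuming the constructor shown in the diff, caller code is unchanged apart from the new
keyword: something like ProdigyPlusScheduleFree(model.parameters(), lr=1.0) picks up
factored=True by default, while passing factored=False should restore the full
exp_avg_sq state for gradients of any shape.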