diff --git a/README.md b/README.md
index 3d42161..5adf9e9 100644
--- a/README.md
+++ b/README.md
@@ -98,3 +98,10 @@ opt = Lion(
     year = {2019}
 }
 ```
+
+```bibtex
+@misc{Schaipp2024,
+    author = {Fabian Schaipp},
+    url = {https://fabian-sp.github.io/posts/2024/02/decoupling/}
+}
+```
diff --git a/lion_pytorch/foreach.py b/lion_pytorch/foreach.py
index cdea027..9c9faca 100644
--- a/lion_pytorch/foreach.py
+++ b/lion_pytorch/foreach.py
@@ -17,12 +17,16 @@ def __init__(
         params,
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
-        weight_decay: float = 0.0
+        weight_decay: float = 0.0,
+        decoupled_weight_decay: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert all([hasattr(torch, attr) for attr in ('_foreach_mul_', '_foreach_add_', '_foreach_sign_', '_foreach_lerp_')]), 'this version of torch does not have the prerequisite foreach functions'
 
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_weight_decay
+
         defaults = dict(
             lr = lr,
             betas = betas,
@@ -44,7 +48,14 @@ def step(
 
         for group in self.param_groups:
 
-            lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']
+            lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr
+
+            # maybe decoupled weight decay
+
+            if decoupled_wd:
+                wd /= init_lr
+
+            # accumulate List[Tensor] for foreach inplace updates
 
             params = []
             grads = []
diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py
index 9813e65..4abc0b8 100644
--- a/lion_pytorch/lion_pytorch.py
+++ b/lion_pytorch/lion_pytorch.py
@@ -14,16 +14,16 @@ def exists(val):
 def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
     # stepweight decay
 
-    p.data.mul_(1 - lr * wd)
+    p.data.mul_(1. - lr * wd)
 
     # weight update
 
-    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
+    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()
 
     p.add_(update, alpha = -lr)
 
     # decay the momentum running average coefficient
 
-    exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)
+    exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)
 
 # class
@@ -34,11 +34,15 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False
+        use_triton: bool = False,
+        decoupled_weight_decay: bool = False,
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
 
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_weight_decay
+
         defaults = dict(
             lr = lr,
             betas = betas,
@@ -67,7 +71,12 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
+                grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr
+
+                # maybe decoupled weight decay
+
+                if decoupled_wd:
+                    wd /= init_lr
 
                 # init state - exponential moving average of gradient values
 
diff --git a/setup.py b/setup.py
index a34cf05..9074649 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
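
In both `lion_pytorch/lion_pytorch.py` and `lion_pytorch/foreach.py`, the new `decoupled_weight_decay` flag divides `weight_decay` by the initial learning rate, so the per-step shrink factor becomes `1 - (lr / init_lr) * wd` rather than `1 - lr * wd`, following the Schaipp post cited in the README hunk above. A minimal usage sketch of the flag (the toy model and hyperparameter values are illustrative, not from the repo):

```python
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(10, 2)   # toy model, purely illustrative

# with decoupled_weight_decay = True, weight_decay is divided by the initial lr,
# so the decay applied each step scales with lr / init_lr rather than with lr alone
opt = Lion(
    model.parameters(),
    lr = 1e-4,
    weight_decay = 1e-2,
    decoupled_weight_decay = True
)

loss = model(torch.randn(8, 10)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```

The intent, per the cited post, is to decouple the two hyperparameters: the chosen `weight_decay` keeps roughly the same effective strength when `lr` is retuned, since its scale is normalized by the initial learning rate.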