more VRAM savings: no first moment and factored second moment #25

Open
wants to merge 8 commits into base: main

Conversation

@dxqbYD (Contributor) commented Oct 31, 2024

This is another PR with the intention of making Prodigy usable on low-VRAM systems and large models.

Combined with #22, the VRAM overhead of Prodigy can be reduced by roughly 95%.

Total VRAM usage by OneTrainer, again as examples:

| Model | Prodigy | Prodigy sliced & factored | AdamW |
| --- | --- | --- | --- |
| Flux Dev LoRA rank 128 | 12.1 GB | 9.0 GB | 10.5 GB |
| SDXL full unet finetune | 31.9 GB | 12.6 GB | 21.7 GB |

This is achieved using some of the concepts proposed in this paper: https://arxiv.org/pdf/1804.04235. This PR therefore goes a bit beyond the scope of your paper.

I'd still propose merging these PRs, because your repo is in wide production use but limited by its large VRAM footprint.
You could always keep a fork of the code for people who are just interested in the code accompanying the paper.

This PR does 3 things:

The paper above showed that the first-moment EMA can be disabled and only the second moment used, while still achieving good training results. Prodigy can already do that by setting beta1 to 0, but with this PR it also no longer allocates any VRAM for the then-unused exp_avg state.

This saves 25% of VRAM.
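Roughly, the allocation change amounts to something like this (an illustrative sketch with made-up helper names, not the exact code in this PR):

```python
import torch

def init_state(p: torch.Tensor, beta1: float) -> dict:
    # Only allocate the first-moment buffer when it will actually be used.
    state = {"step": 0}
    if beta1 > 0:
        state["exp_avg"] = torch.zeros_like(p)  # skipped entirely when beta1 == 0
    state["exp_avg_sq"] = torch.zeros_like(p)
    return state
```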

If factored is set to True, it no longer stores a full tensor for the second EMA, but a factored approximation. Details can be found in the paper above.

Depending on the shape of the model's parameters, this saves another 24-25% of VRAM.
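For reference, the factored bookkeeping looks roughly like this, in the style of the transformers Adafactor code (names are illustrative and Prodigy's d scaling is omitted; see the PR diff for the real code):

```python
import torch

def factored_second_moment(grad, exp_avg_sq_row, exp_avg_sq_col, beta2, eps2=1e-30):
    # Keep only row-wise and column-wise EMAs of the squared gradient
    # instead of a full tensor of the same shape as the parameter.
    grad_sq = grad.square() + eps2
    exp_avg_sq_row.mul_(beta2).add_(grad_sq.mean(dim=-1), alpha=1 - beta2)  # shape (n,)
    exp_avg_sq_col.mul_(beta2).add_(grad_sq.mean(dim=-2), alpha=1 - beta2)  # shape (m,)
    # Rank-1 reconstruction of rsqrt(v): rows normalized by their mean, times columns.
    r = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c = exp_avg_sq_col.rsqrt().unsqueeze(-2)
    return r * c  # elementwise factor standing in for 1 / sqrt(exp_avg_sq)
```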

The approximation of the second moment above can introduce unwanted spikes in parameter updates that are not caught by regular gradient clipping. If update_clip is set to True, another round of clipping is applied after the EMA has been applied to the gradients.
This makes a clearly noticeable difference when using the above, but I suspect it might even be beneficial for Prodigy independently of these VRAM optimizations:
People have reported problems with a high beta2 setting and Prodigy here:
#8

This could be caused by Prodigy's rapid changes in learning rate, which increase the variance of the gradients while the second-moment EMA is out of date and still indicates low variance. With a high beta2 this can be the case for a long time.
Large parameter updates are the result, as shown here: [image]

Update clipping scales down the parameter updates themselves, not just the gradients before the EMA.
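In the sense of the Adafactor paper's update clipping (later also used by StableAdamW), the idea is roughly the following sketch, where update_clip plays the role of the RMS threshold (illustrative, not the exact PR code):

```python
import torch

def clip_update(update: torch.Tensor, clip_threshold: float = 1.0) -> torch.Tensor:
    # Scale the update itself down if its RMS exceeds the threshold;
    # gradient clipping alone does not bound this quantity.
    rms = update.pow(2).mean().sqrt()
    return update / (rms / clip_threshold).clamp(min=1.0)
```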

The remaining 45% of the 95% VRAM savings mentioned above come from #22.

Note again: with default settings, this code has no effect. Set the values mentioned above if you want to test this PR.
This PR contains only the part described above. I have a merge of all three PRs in my GitHub fork if you want to test everything together.

@Arcitec (Contributor) commented Oct 31, 2024

Wow, yet again INCREDIBLY good work! I agree that these improvements are important and that the paper's original code can be moved to a paper_code branch or similar. People want to actually use Prodigy and these improvements finally make it usable for regular consumers.

@adefazio (Collaborator) commented Nov 1, 2024

Thank you for the pull request! Konstantin and I are looking to review this and other pull requests we have received very soon, hopefully in the next week or two. I think that an adafactor version is definitely very useful.

@dxqbYD dxqbYD mentioned this pull request Nov 3, 2024
LoganBooker added a commit to LoganBooker/prodigy-plus-schedule-free that referenced this pull request Nov 4, 2024
@dxqbYD (Contributor, Author) commented Nov 4, 2024

Two issues have been reported with this PR that need further analysis:

  1. On Stable Cascade, there are NaNs if factored==True.
  2. In a test on SD3.5 Medium, training the text encoder only, Prodigy never finds a learning rate and remains on d0 forever if update_clip==1.0. This might be because the clipped gradients are added to (p - p0), while the updates to s, which is the denominator in the calculation of d_hat, are unclipped. On other models this problem does not occur, though.

@adefazio (Collaborator) commented Nov 5, 2024

Adafactor is used fairly often in practice; do the standard code bases for it give any indication of how to resolve the issues you're reporting? I will hold off merging until we have a better idea of what's going on. I'm a little suspicious of clipping as a solution, as it introduces additional parameters.

@betterftr

In my test I trained Stable Cascade (using OneTrainer), both unet and text encoder, in full bf16, and I managed to avoid NaNs when I increased eps2 from 1e-30 to 1e-25, or by turning off factored.

@dxqbYD (Contributor, Author) commented Nov 6, 2024

> In my test I trained Stable Cascade (using OneTrainer), both unet and text encoder, in full bf16, and I managed to avoid NaNs when I increased eps2 from 1e-30 to 1e-25, or by turning off factored.

Did you run the same test on the original Adafactor? Did that work without changing eps2?
I think in the original Adafactor the same parameter is just called eps.

If it works in original Adafactor but not with this PR, I should reproduce and analyze it.

@betterftr commented Nov 6, 2024

Yes it worked:

Og Adafactor (Cascade):
image

(Cascade) your prodi with eps2 e30:
image

(Cascade) your prodi with eps2 e25:
image

and this was SD3.5L text encoder 1 (full bf16) with update_clip=1.0, dynamic lr didn't kick in:
image

and with update_clip=None
image

@dxqbYD (Contributor, Author) commented Nov 6, 2024

As a side note, because I was just thinking about what could be different between original Adafactor and this PR:

This PR does updates to exp_avg as AdamW does and as Prodigy does. This is different from how the original Adafactor code updates exp_avg, but I think this is an oversight in the implementation of Adafactor that the authors of the paper did not intend. You can find more details here: huggingface/transformers#34506

I'm quite sure that is not the reason for the NaNs though, even though you did use beta1 > 0 in your tests, so this code path is used.
I'll have to reproduce this. Thanks for providing this information.

@betterftr commented Nov 6, 2024

Sorry, I didn't provide screenshots of the OneTrainer config for the Cascade tests. I set beta1 to None in both; only in the SD3.5L tenc1 test did I leave it at 0.9, because at that time I was testing the schedule-free edition.

@konstmish (Owner)

> In a test on SD3.5 Medium, training the text encoder only, Prodigy never finds a learning rate and remains on d0 forever if update_clip==1.0. This might be because the clipped gradients are added to (p - p0), while the updates to s, which is the denominator in the calculation of d_hat, are unclipped. On other models this problem does not occur, though.

I quickly looked at the code and I'm not sure what update_clip is doing. If I understood correctly, you clip the norm of the overall update rather than of the gradient, which is not something I've seen commonly used. Is there a good justification for that, like did it improve the performance/stability on some benchmarks? Standard optimization theory suggests that we should rather clip the gradients, and we should clip them in all expressions, which is ideally done by doing clipping outside of the optimizer and passing the clipped gradient to it.

@dxqbYD (Contributor, Author) commented Nov 13, 2024

> I quickly looked at the code and I'm not sure what update_clip is doing. If I understood correctly, you clip the norm of the overall update rather than of the gradient, which is not something I've seen commonly used. Is there a good justification for that, like did it improve the performance/stability on some benchmarks? Standard optimization theory suggests that we should rather clip the gradients, and we should clip them in all expressions, which is ideally done by doing clipping outside of the optimizer and passing the clipped gradient to it.

The Adafactor authors first introduced it in section 6 of their paper: https://arxiv.org/pdf/1804.04235
They explain their reasoning in section 5 as a solution to an out-of-date exp_avg_sq, "out-of-date" meaning that the denominator of the parameter updates lags behind with a high beta2 value when the model evolves fast (or when the LR is changed fast by Prodigy? see below).
It was apparently later picked up by other authors and optimizers, who coined it "StableAdamW": https://arxiv.org/pdf/2304.13013

Anecdotally, I observed that training was worse with this implementation than with the original Adafactor before I implemented update clipping, even when I set d to a fixed learning rate as an experiment to have a direct comparison. Samples during training of the Flux diffusion model seemed much less stable, but I don't have an objective benchmark. It looks like there might be a more objective benchmark in the second paper above, but I haven't looked at it in detail.
My test was with OneTrainer, which does gradient clipping to 1.0 by default, but update clipping still clearly had an (additional) effect.

This might also be why people have reported having issues with Prodigy and high values of beta2 here:
#8

This is just a guess, but it would make sense: when Prodigy scales up d, the numerator rises fast at beta1=0.9, but the denominator lags behind for a long time at beta2=0.999, because the Prodigy EMA is scaled by d - resulting in large parameter updates even with clipped gradients.
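As a rough back-of-the-envelope: an EMA with decay $\beta$ averages over roughly $1/(1-\beta)$ steps, so at beta1=0.9 the numerator reflects only the last ~10 steps, while at beta2=0.999 the denominator still reflects the last ~1000 steps, i.e. gradients from before d was scaled up.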

exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]

grad_sq=grad.square()+eps2
@konstmish (Owner)

I think it's not necessary to create this tensor as we're only interested in the sums of squares over rows and columns. Perhaps we could compute them directly using, for instance, torch.norm by specifying dim and squaring the result?

@dxqbYD (Contributor, Author)

Interesting, but isn't this a bet that x*y square()s are slower than (x+y) sqrt()s plus (x+y) square()s?
Do we know that this is the case?
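For concreteness, the two variants being compared, ignoring eps2 and the d² scaling (just an illustrative sketch, not the PR code):

```python
import torch

grad = torch.randn(512, 512)

# Materialize a full grad-sized temporary, then reduce (as the PR currently does):
row_mean_sq = grad.square().mean(dim=-1)

# Reduce directly via the norm, without the full-sized temporary (the suggestion above):
row_mean_sq_alt = torch.norm(grad, dim=-1).square() / grad.shape[-1]

print(torch.allclose(row_mean_sq, row_mean_sq_alt, rtol=1e-4))  # same values up to rounding
```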

@konstmish (Owner)

Very interesting! Thank you for explaining the history behind this trick and sharing your experience with it.

The code looks mostly good to me, I think we can merge it after making a few fixes and testing the method. We should decide on the default value of eps2 and I don't think we can do that without running the optimizer on a few benchmarks.

@dxqbYD (Contributor, Author) commented Nov 15, 2024

> Very interesting! Thank you for explaining the history behind this trick and sharing your experience with it.

Just a question if you have the time: why did you choose to scale the EMAs by d?
I understand you derived this theoretically in your paper (section 5), but when looking only at the result:

  • d can be considered constant over the vast majority of the training time - then scaling the EMAs by it doesn't have an effect
  • d only changes during ramp-up - and this is where it might cause problems, see above

> The code looks mostly good to me, I think we can merge it after making a few fixes and testing the method. We should decide on the default value of eps2 and I don't think we can do that without running the optimizer on a few benchmarks.

The current default is taken from the Adafactor code, but I agree that further analysis is required, because someone above reported that they had to change eps2 to make it work on Stable Cascade (1e-25 instead of 1e-30).

Might take me a while to reproduce that because of a broken GPU :/

@konstmish (Owner)

Hey @dxqbYD, since this PR is on hold at least for now, could you make another pull request with just the change to remove exp_avg when beta1=0? I don't know if people use it that way, but it does make sense to support this option for memory efficiency. If not, just let me know and I'll create one myself.

> Just a question if you have the time: why did you choose to scale the EMAs by d?

It's really just because of the theory, that's how we derived the method. Even though this multiplication shouldn't be important in the long run once d stabilizes, it still impacts the estimated value of d. The first stage of training when d is not yet stable seems to have a big impact on the final results of our method.

@dxqbYD dxqbYD mentioned this pull request Dec 16, 2024
@dxqbYD (Contributor, Author) commented Dec 16, 2024

> Hey @dxqbYD, since this PR is on hold at least for now, could you make another pull request with just the change to remove exp_avg when beta1=0? I don't know if people use it that way, but it does make sense to support this option for memory efficiency. If not, just let me know and I'll create one myself.

#32

I'm not sure either whether people use it that way for training, but at the very least you can use beta1==0 for estimating the LR. In the SD community, Prodigy is often used to estimate a rough LR, and then Adam is used for training and fine-tuning the LR.

@dxqbYD (Contributor, Author) commented Dec 16, 2024

> Yes it worked:
>
> Og Adafactor (Cascade): image
>
> (Cascade) your prodi with eps2 e30: image

I was able to reproduce this now.
The issue is that eps2 is scaled by $d^2$:

grad_sq=grad.square()+eps2
exp_avg_sq_row.mul_(beta2).add_(grad_sq.mean(dim=-1), alpha=d * d * (1-beta2))

...while in the original Adafactor code, exp_avg_sq is not scaled by $d^2$, and therefore neither is eps2.

I'll change the code above to

grad_sq=grad.square()
exp_avg_sq_row.mul_(beta2).add_(grad_sq.mean(dim=-1), alpha=d * d * (1-beta2)).add_(eps2)

Technically, add_(eps2 * (1-beta2)) would be equivalent to the original Adafactor code, but I think this is close enough at 1e-30.

On the question of why this only happened with Cascade: I guess that zero gradients are quite rare.
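A rough illustration of the underflow (the value of d here is hypothetical, on the order of Prodigy's default d0):

```python
import torch

eps2 = torch.tensor(1e-30, dtype=torch.bfloat16)  # representable: bf16 has fp32's exponent range
d = 1e-6                                           # hypothetical small d early in training
print(eps2 * (d * d) * (1 - 0.999))                # tensor(0., dtype=torch.bfloat16): eps2 no longer guards the division
```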

@konstmish (Owner)

I want to have the adafactor version in the repo in the near future, so we can hopefully merge a variant of this PR soon. A few things that I need to understand:

  1. Scaling in the denominator can indeed cause trouble. When preparing the initial version of Prodigy, we simply added eps to the square root of exp_avg_sq when computing denom, which didn't work on some problems. The fix was to use d * eps, the reason it helps is simple: the update becomes proportional to $d_t^2 \frac{m_t}{\sqrt{v_t} + d_t\epsilon} = d_t\frac{m_t/d_t}{\sqrt{v_t/d_t^2} + \epsilon}$, which makes it more similar to the standard Adam update, eps is on the same scale as the gradients. Your last suggestion seems to go the other way by removing scaling eps2 by $d^2$, is that right?
  2. The Adafactor paper uses eps1=1e-3 as the default value, while our eps is set to 1e-8 by default, could that be the reason behind the instabilities?
  3. I tried comparing your implementation to the one in pytorch and there seem to be some differences in how epsilons are used. For instance, in pytorch, eps2 seems to be used to set alpha, which appears to be in line with the arxiv pseudocode, while in this PR eps2 is added to the gradients. Or is it just a different notation (so their eps1 is like your eps2)?
  4. It also appears to me that when we add eps2 to all squared gradients, summing the gradient components over all rows or columns, eps2 gets multiplied by their number. This seems in line with the adafactor paper, while the pytorch implementation uses clamping instead. I don't know if we want to have the implicit scaling by the dimension, nothing like that is used in Adam.
  5. I'm not sure at the moment if we should add it as a separate optimizer or keep adding things to Prodigy itself, do you have an opinion? There is a PR for adding the original implementation of Prodigy to the repo, so it's fine to have a complicated main implementation with lots of options. On the other hand, it might be more recognizable by others if we explicitly call it an adafactor version, and we might want to have different default values for some hyperparameters such as eps, so it'd be great to have some opinions.

@konstmish (Owner) commented Dec 18, 2024

@adefazio
I think the way we compute $d_t$ should be changed when using Adafactor because it is a bit more similar to AdaGrad-Norm. Consider for instance the case where matrix $G_t$ is actually a column of length $p$, then Adafactor would update $V_t$ as $V_{t,ij} = \beta_2 V_{t-1,ij} + (1-\beta_2)|G_{t,ij}|\sqrt{\sum_{l=1}^p G_{t,il}^2}$, which is roughly the square root of the stepsizes of AdaGrad-Norm and Adam. So at least in this case, when computing the denominator in $d_t$, I think we should use $\sqrt{\Vert s_t\Vert_2\Vert s_t\Vert_1}$ instead of $\Vert s_t\Vert_1$. More generally, I think we want to use the $\sqrt{\Vert s_t \Vert_{\ell_2,\mathrm{rows}} \Vert s_t \Vert_{\ell_2,\mathrm{columns}}}$, that seems to give the right expression in the special cases.

@dxqbYD (Contributor, Author) commented Dec 18, 2024

My implementation is strictly based on the transformers implementation of Adafactor:
https://github.com/huggingface/transformers/blob/9613933b022ddbf085e2c593ed4ceea4c734179a/src/transformers/optimization.py#L672

I have no opinion on which implementation is better but assumed that transformers is the reference implementation and tried to change as little as possible.

About the epsilons, I might have added to an already confusing naming convention of multiple epsilons that aren't necessarily related:

  • Prodigy.eps, defaulting to 1e-8, does not exist in transformers.Adafactor.

  • transformers.Adafactor.eps[1] or $\epsilon_2$ in the paper, default 1e-3, is used in an alternative proposal to determine a LR (section 8 of the Adafactor paper). I have not implemented this, and don't think it's worth doing so.

  • transformers.Adafactor.eps[0], or $\epsilon_1$ in the paper, default 1e-30, is a very small value added to the factorized exp_avg_sq to avoid division by 0. This is what I have called Prodigy.eps2.
    It was an issue above because, when scaled by $d^2$, it went below the range of bf16 (which ends around 1e-38), causing division by 0 again. I don't think the exact value matters much, not even if it gets multiplied by the dimension of the rows/cols (re point 4), as long as it's small enough but not 0. I solved it by not scaling it by $d^2$, so we don't need a different default value than transformers.Adafactor, which might cause additional confusion. Does PyTorch have a better way to avoid the hyperparameter completely?

I don't understand your point 1, but that might be because of the mixup of epsilons.

> I'm not sure at the moment if we should add it as a separate optimizer or keep adding things to Prodigy itself, do you have an opinion? There is a PR for adding the original implementation of Prodigy to the repo, so it's fine to have a complicated main implementation with lots of options. On the other hand, it might be more recognizable by others if we explicitly call it an adafactor version, and we might want to have different default values for some hyperparameters such as eps, so it'd be great to have some opinions.

I don't have a strong opinion either way. Generally, it feels like we have 73 different optimizers with their own name each, but many of them could just be another hyperparameter to Adam. But if you want to go that route, testers on the OneTrainer discord have called it Prodifactor 🙂

@konstmish (Owner) commented Dec 19, 2024

> I have no opinion on which implementation is better but assumed that transformers is the reference implementation and tried to change as little as possible.

Transformers' implementation is fine, I don't think there is any difference as long as it works and it's readable. I'd only try to keep it consistent with the notation in the Adafactor paper to the extent it's possible.

> confusing naming convention of multiple epsilons that aren't necessarily related

Thanks, I got confused indeed. Ignore my point about the default value of eps.

> Generally, it feels like we have 73 different optimizers with their own name each, but many of them could just be another hyperparameter to Adam.

True. I'm still somewhat inclined to use a different optimizer for Adafactor though as it requires a different way of estimating $d_t$.

> testers on the OneTrainer discord have called it Prodifactor

Haha, that's a cool name :)

> It was an issue above because, when scaled by $d^2$, it went below the range of bf16 (which ends around 1e-38), causing division by 0 again.

Makes sense. In the current implementation of Prodigy, we tried to avoid having numbers that are too small by using the ratio of d / d0, but it's not perfect either because it can get too large. I think the time has come to revisit the implementation of update scaling. An equivalent way of implementing Prodigy is to use the updates $m_{t+1} = \beta_1 \frac{d_{t}}{d_{t+1}}m_t + (1-\beta_1)g_t$ and $v_{t+1} = \beta_2 \frac{d_{t}^2}{d_{t+1}^2}v_t + (1-\beta_2)g_t^2$, which has the same scale as the exponential moving averages in Adam. I think we should make a new variant of Prodigy with this and test it, then we can do the same for Adafactor.
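A minimal sketch of that rescaled-EMA step (variable names are illustrative, not taken from the repo):

```python
import torch

def rescaled_ema_step(m, v, grad, d_prev, d_new, beta1=0.9, beta2=0.999):
    # m_{t+1} = beta1 * (d_t / d_{t+1}) * m_t + (1 - beta1) * g_t
    m.mul_(beta1 * d_prev / d_new).add_(grad, alpha=1 - beta1)
    # v_{t+1} = beta2 * (d_t / d_{t+1})^2 * v_t + (1 - beta2) * g_t^2
    v.mul_(beta2 * (d_prev / d_new) ** 2).addcmul_(grad, grad, value=1 - beta2)
    return m, v
```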

@dxqbYD (Contributor, Author) commented Dec 19, 2024

> Transformers' implementation is fine, I don't think there is any difference as long as it works and it's readable. I'd only try to keep it consistent with the notation in the Adafactor paper to the extent it's possible.

The only difference between the paper and the transformers implementation I've stumbled across is here:
huggingface/transformers#34506

But I wouldn't copy this to the Prodigy implementation; there is no theory behind it, other than it probably being an oversight by the original authors.

> Makes sense. In the current implementation of Prodigy, we tried to avoid having numbers that are too small by using the ratio of d / d0, but it's not perfect either because it can get too large. I think the time has come to revisit the implementation of update scaling. An equivalent way of implementing Prodigy is to use the updates $m_{t+1} = \beta_1 \frac{d_{t}}{d_{t+1}}m_t + (1-\beta_1)g_t$ and $v_{t+1} = \beta_2 \frac{d_{t}^2}{d_{t+1}^2}v_t + (1-\beta_2)g_t^2$, which has the same scale as the exponential moving averages in Adam. I think we should make a new variant of Prodigy with this and test it, then we can do the same for Adafactor.

If you do go ahead and re-test Prodigy, it would be great if you could do so with only one main loop. The current implementation does the EMA updates in one loop, then the d updates, then another loop for the parameter updates.
I've turned it into just one main loop here: #28. It is technically not exactly the same, so, as @adefazio has proposed there, the prudent thing is to test it again.
