[mps] training / inference dtype issues #7563
If we can rely on the bf16-fixed AdamW optimiser, we can skip keeping a separate fp32 copy of the weights and save storage. Overall, training becomes more efficient and reliable. Thoughts?
I can try to upload a PyPI package of the fixed optimiser later today. It is located here: https://github.com/AmericanPresidentJimmyCarter/test-torch-bfloat16-vit-training/blob/main/adam_bfloat16/__init__.py#L22
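For reference, a minimal usage sketch of such a pure-bf16 optimizer. The import path and class name below are assumptions for illustration; check the package above for the real names.

```python
import torch
from torch import nn

# Assumed import: module and class names are guesses, not confirmed API.
from adam_bfloat16 import AdamWBF16

model = nn.Linear(16, 1).to(dtype=torch.bfloat16)   # pure bf16 weights, no fp32 master copy
optimizer = AdamWBF16(model.parameters(), lr=1e-4)  # hypothetical class name

x = torch.randn(8, 16, dtype=torch.bfloat16)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```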
Would it apply to float16 too?
Cc: @patil-suraj too.
I can try, but since the SD 2.1 model needs the attention up block's precision upcast to at least bf16, I didn't think it would be useful to test. SD 2.1 in particular produces black outputs when xFormers is not in use, which @patrickvonplaten investigated here.
The other annoying thing is that without autocast we have to use the same dtype for the VAE and the UNet. This is probably fine because the SD 1.x/2.x VAE handles fp16 like a champ, but the UNet requires at least bf16. Manually modifying the UNet dtype seems to break the mitigations put in place on the up-block attention.
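For illustration, a sketch of loading everything in a single dtype (bf16, which the UNet requires and the VAE tolerates). The model id and device are assumptions.

```python
import torch
from diffusers import StableDiffusionPipeline

# Without autocast, the VAE and UNet share one dtype; bf16 satisfies the
# UNet's precision requirement and the VAE handles it as well.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",   # assumed model id
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("mps")  # or "cuda" / "cpu"
```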
fp16 is usually degraded, and when training with it you need tricks like gradient scaling for it to work at all. For training scripts, I am not sure we should recommend anything other than fp32 weights with fp16 autocast. The situation with bfloat16 is different: with the correct optimiser it seems to perform nearly on par with float32. The downside is that older devices may not support it.
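For context, a minimal sketch of the fp32-weights-plus-fp16-autocast recipe with gradient scaling described above. The model and data are toy placeholders, and a CUDA device is assumed.

```python
import torch
from torch import nn

model = nn.Linear(16, 1).cuda()                      # weights stay in fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()                 # loss/grad scaling for fp16

for _ in range(10):
    x = torch.randn(8, 16, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()                # forward pass runs in fp16
    scaler.scale(loss).backward()                    # scale loss to avoid gradient underflow
    scaler.step(optimizer)                           # unscales grads, skips step on inf/NaN
    scaler.update()                                  # adapts the scale factor
```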
Older CUDA devices emulate bf16, e.g. the T4 on Colab. Apple MPS supports it with PyTorch 2.3. AMD ROCm, I think, is the outlier, but it also seems to have emulation. Pinging @Beinsezii for clarification.
bf16 runs on ROCm with a 5-10% perf hit over fp16. Not sure of the implementation details, but it's worked for everything I've needed it to. I only have an RDNA 3 card, so it may not work as cleanly on older GPUs without WMMAs.
Can't do fp16 on SD 2.1, as it goes to NaN at the 0th output.
PyPI is here. Repo is here.
Unfortunately I hit a NaN at the 628th step of training, approximately the same place as before.
Ufff. Why that damn step? 😭
Looks like it could be pytorch/pytorch#118115, as both optimizers that fail in this way use addcdiv.
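For reference, a generic sketch of the Adam-style parameter update that relies on `addcdiv_`, the op implicated in that issue. The tensors here are toy stand-ins, not the actual optimizer internals.

```python
import torch

param = torch.zeros(4, dtype=torch.bfloat16)
exp_avg = torch.randn(4, dtype=torch.bfloat16)        # first-moment estimate
denom = torch.rand(4, dtype=torch.bfloat16) + 1e-8    # sqrt(second moment) + eps
step_size = 1e-3

# p <- p - step_size * exp_avg / denom, done in one fused in-place op
param.addcdiv_(exp_avg, denom, value=-step_size)
```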
I will look into a fix for this too, if it's just
It crashed after 628 steps, and then on resume it crashed after 300 steps, on the 901st. It also seems to get a lot slower than it should sometimes, but it was mentioned that could be heat-related. I doubt it's entirely a thermal problem, but it seems odd.
@sayakpaul you know what it ended up being? A cached latent with NaN values. I ran the SDXL VAE in fp16 mode since I was using PyTorch 2.2 a few days ago, and that didn't support bf16. It worked pretty well, but I guess one or two of the files had corrupt outputs. So there's no inherent issue with the backward pass on torch MPS causing NaN in 2.3. The one bug that I was observing in PyTorch with a small reproducer has now been backported to 2.3 as well. As of yesterday morning, torch.compile is also more stable on 2.4, so aot_eager makes a few of the performance annoyances go away. MPS support is shaping up to a fair level at this juncture, and I'll be able to work on that this weekend.
You were caching latents? Even so, how were there NaNs? A VAE issue or something else?
Using madebyollin's SDXL VAE fp16-fix model, it occasionally NaNs, but not often enough to spot the issue right away.
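One way to catch corrupt cache files like this is to scan the cached latents for non-finite values before training. A sketch, assuming the latents were cached as individual `.pt` files in a local directory; the layout is an assumption.

```python
import torch
from pathlib import Path

cache_dir = Path("latent_cache")                 # assumed cache location
for path in sorted(cache_dir.glob("*.pt")):
    latent = torch.load(path, map_location="cpu")
    if not torch.isfinite(latent).all():         # flags NaN and Inf entries
        print(f"corrupt latent: {path}")
```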
So many sudden glitches.
On a new platform, the workarounds that are required for all platforms might not have been added yet. E.g. CUDA handles type casting automatically, but MPS requires strict types, and the CUDA workarounds for issues people saw more than a year ago are now forgotten. We essentially have to rediscover how CUDA needed to work and apply a lot of the same changes to MPS. I am removing fp16 training from my own trainer. fp32 is there, but I don't know why anyone would use it. Pure bf16 with the fixed optimizer is the new hero here.
Thanks for investigating. I guess it's about time fp16 support got fixes. If people are aware of these findings, I think it should still be okay. But as for fp16 inference, I don't think we can throw that one out yet.
fp16 inference was thrown out long ago.
I will respectfully disagree. Not all cards support bfloat16 equally well.
The ones that don't will be upcasting about half of the information to fp32; e.g. the GTX 1060 also doesn't support fp16, which NVIDIA used to lock behind a Quadro purchase. In any case, I don't think the cards that fail to support bf16 are useful for training. The Google Colab T4 emulates bf16 behind the scenes. Which others are there?
@bghira now that you fixed the latents issue, is bf16 training going well with my optim?
More than 790 steps without issue.
So I've been doing more extensive experimentation with MPS, and on any version of PyTorch the experience is currently just very lacking.
Despite that, something is still degrading the results of calculations on MPS. You can see this by running inference on Stable Diffusion v1.5 or DeepFloyd in particular using the MPS device.
I was able to train a v-prediction/min-SNR DreamBooth for Terminus using the MPS device in pure bf16, but it never looked quite accurate. Still, the results weren't completely awful:
When training with Diffusers without attention slicing, we see:
But with attention slicing, this error disappears.
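For reference, a sketch of how attention slicing is enabled; the model id is illustrative and the method names assume a recent diffusers release.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",   # assumed model id
    torch_dtype=torch.bfloat16,
).to("mps")

pipe.enable_attention_slicing()           # pipeline-level toggle for inference
pipe.unet.set_attention_slice("auto")     # model-level equivalent, as used in training code
```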
However, once this issue is resolved, there is a new problem:
This is caused by the following logic:
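(The snippet itself wasn't preserved here; the sketch below approximates the low-precision guard found in the diffusers training scripts, which refuses to start training unless the trainable weights are in float32.)

```python
import torch
from diffusers import UNet2DConditionModel

# Approximation only; the exact check and message may differ in the scripts.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet"   # assumed model id
)

low_precision_error_string = (
    "Please make sure to always have all model weights in full float32 precision "
    "when starting training - even if doing mixed precision training."
)

if unet.dtype != torch.float32:
    raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}")
```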
This check exists because the stock AdamW optimiser doesn't work with bf16 weights. However, thanks to @AmericanPresidentJimmyCarter, we can now use the adamw_bfloat16 package and benefit from an optimizer that handles pure bf16 training.
With that in place, we can comment out the low-precision warning code and:
Load the UNet directly in the target precision.
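A sketch of what that looks like, assuming bf16 as the target precision and an illustrative model id:

```python
import torch
from diffusers import UNet2DConditionModel

# Load the UNet directly in bf16 rather than fp32-then-cast.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1",   # assumed model id
    subfolder="unet",
    torch_dtype=torch.bfloat16,
)
```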
Originally posted by @bghira in #7530 (comment)