7529 do not disable autocast for cuda devices #7530
Conversation
…tocast implementation makes it a non-issue
Thanks for taking care of it!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
i'm going to switch from this method to contextlib's nullcontext, as local testing with nightly pytorch indicates that'll avoid some new errors that will come down the pipe. specifically, the null context will allow us to entirely bypass all issues with autocast - sometimes the platform is disabled incorrectly, or has only partial support. the nullcontext will allow us to fully decide when to commit to providing autocast on a new platform
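as a rough sketch of that selection pattern (the helper name and the flag below are illustrative, not the PR's actual code):

```python
import contextlib
import torch


def autocast_or_nullcontext(device_type: str, use_autocast: bool):
    # Illustrative helper: return a real autocast context only where we opt
    # in; otherwise a no-op context so the calling code stays identical.
    if use_autocast:
        return torch.autocast(device_type=device_type)
    return contextlib.nullcontext()


# opt in only on platforms with complete autocast support, e.g. cuda
with autocast_or_nullcontext("cuda", use_autocast=torch.cuda.is_available()):
    pass  # forward pass / inference goes here
```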
@sayakpaul ready :-)
@pcuenca i tried to unify the behaviour of the context selection to use contextlib
Nice 👌
Can you also remove the error test in the instruct pix2pix xl pipeline 🙏 : Lines 927 to 930 in 5266ab7
@simbrams done
@sayakpaul @DN6 @yiyixuxu i've updated every pipeline that uses 🤗 Accelerate to disable AMP so that
Works for me now. I mainly checked the pipeline and the train script for sdxl lora.
@bghira thanks for the changes. A couple of comments:
So, IMO, we should make sure:
I didn't see it in this PR though
agree with this
@yiyixuxu see here: https://github.com/huggingface/diffusers/pull/7530/files#diff-942ed0e4f11fae99750e9b042c4950baa2078bf7f75b1a450a7fa0339f5f034bR551
that was already using a context manager, i simply updated it to be consistent
Okay good to know. Thanks! What are our blockers now? I am happy to cook up a context manager in
also, the instruct pix2pix script is where a lot of the original confusion came from! the one that you wrote :D in there, the autocast is toggled based on fp16 and the device type str is chopped out of the device
i've been trying to chase down that training issue. i want to try running SD 1.5 and see if the problem for legacy models is limited to 2.x, in which case it'll be best to limit training to SD 1.x. will update shortly
I see it's enabled when mixed precision is FP16:
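roughly, the gating being described is (a sketch; the variable names and the `.replace(":0", "")` string handling are assumptions, not a quote from the script):

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mixed_precision = "fp16"  # illustrative; the training script reads this from accelerate

# Chop the index off the device string ("cuda:0" -> "cuda") so it can be
# passed to torch.autocast, and only enable autocast for fp16.
device_type = str(device).replace(":0", "")
with torch.autocast(device_type, enabled=(mixed_precision == "fp16")):
    x = torch.randn(2, 4, device=device)
    w = torch.randn(4, 4, device=device)
    y = x @ w.t()
```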
But anyway, as mentioned, I am down for writing a context manager to help ease things a bit. Maybe a write-up of everything you have discovered so far regarding training on Apple Silicon could be a very nice resource.
on mps,
i've gone through them a lot today, and i'm really not sure why it would cause the MPS crash; it's incredibly unfortunate. the SDXL error does seem like it can be overcome, i just need to put a bit more time into hunting it down. so far, this patch works as advertised, at least to resolve the cuda-specific regression, and it should probably be merged
That's quite informative indeed.
Curious, what do you do on
I will wait for @pcuenca to give it a shot too, before merging.
in simpletuner i have a lot more brute-force general handling of dtypes and leave less responsibility for that on other parts of the stack.
examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
good morning. after some more checking, fp32 training works fine for MPS on SD 1.5 and 2.1. but this feels disappointing, as dtype issues just shouldn't be the only thing holding back memory-efficient training. but i guess it's all we can do for now?
pytorch 2.3 supports MPS bf16:

```
Valid Types: [torch.float32, torch.float32, torch.float16, torch.float16, torch.bfloat16, torch.complex64, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int32, torch.int64, torch.int64, torch.bool]
Invalid Types: [torch.float64, torch.float64, torch.complex128, torch.complex128, torch.quint8, torch.qint8, torch.quint4x2]
```

pytorch 2.2.1 does not:

```
Valid Types: [torch.float32, torch.float32, torch.float16, torch.float16, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int32, torch.int64, torch.int64, torch.bool]
Invalid Types: [torch.float64, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.complex128, torch.quint8, torch.qint8, torch.quint4x2]
```

---

```
[bfloat16] BFloat16 is not supported on MPS
```
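these lists look like the output of probing each dtype with a tiny allocation; a sketch of that kind of check (the loop and its output format are assumed, not the exact script used; the duplicated entries above are presumably dtype aliases such as torch.float/torch.float32 printing identically):

```python
import torch

# Sketch: try a tiny allocation for every dtype on the target device and
# bucket the results into supported / unsupported lists.
device = "mps" if torch.backends.mps.is_available() else "cpu"
dtypes = [
    torch.float32, torch.float64, torch.float16, torch.bfloat16,
    torch.complex64, torch.complex128, torch.uint8, torch.int8,
    torch.int16, torch.int32, torch.int64, torch.bool,
    torch.quint8, torch.qint8, torch.quint4x2,
]
valid, invalid = [], []
for dtype in dtypes:
    try:
        torch.ones(1, dtype=dtype, device=device)
        valid.append(dtype)
    except Exception:  # unsupported dtypes raise here (exact error type varies)
        invalid.append(dtype)
print("Valid Types:", valid)
print("Invalid Types:", invalid)
```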
@bghira I am going to merge the PR in a while. But tomorrow, I will take a closer look into autocast related things that we're using in our training scripts and see if we can get rid of them. I can do these tests fast in a CUDA environment. Will look into an autocast context manager (to add to

But just so I understand the issues for M3 training correctly, w.r.t #7530 (comment), were you able to pinpoint a bug when training with FP16? Does it fail during intermediate inference?
```diff
@@ -752,6 +752,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
```
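a minimal sketch of the guard that this comment annotates, following the same pattern quoted elsewhere in this thread (exact placement varies per script):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Disable AMP for MPS: Accelerate's native autocast wrapping is turned off on
# Apple Silicon, where autocast support is incomplete.
if torch.backends.mps.is_available():
    accelerator.native_amp = False
```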
@bghira possible to add a more descriptive comment here?
Also, I see this is not added to some scripts such as the advanced diffusion or consistency distillation scripts.
the fp16 mps bug happens on the third call to nn.Linear: the Parameter object is created with fp32 dtype for weight and bias, but the query is fp16
it's an actual training failure on step 1 when t=999. i saw that nn.Linear has a dtype parameter i tried setting, but the Parameter class it gets passed into doesn't actually make use of the dtype parameter, which seems like a torch bug
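a minimal check (not from the PR) to see whether the dtype argument is honored on a given backend:

```python
import torch

# Build an fp16 nn.Linear on the target backend and compare its parameter
# dtypes against an fp16 query tensor; a mismatch points at the bug above.
device = "mps" if torch.backends.mps.is_available() else "cpu"
linear = torch.nn.Linear(8, 8, device=device, dtype=torch.float16)
query = torch.randn(1, 8, device=device, dtype=torch.float16)
print(linear.weight.dtype, linear.bias.dtype, query.dtype)

try:
    out = linear(query)
    print("forward ok:", out.dtype)
except RuntimeError as err:  # a weight/query dtype mismatch would surface here
    print("forward failed:", err)
```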
Thanks very much folks for the insightful discussions on autocast and bfloat16. I am under the impression that PyTorch, indeed, has some weird voodoo going on under the hood for these. Meanwhile, I am going to merge this PR to unblock our users. So, thanks to @bghira for the prompt action here. This thread still remains open for discussions around autocast and bfloat16, which I believe will be valuable for the community.
without attention slicing, we see:
but with attention slicing, this error disappears. the resolution seems to be:

```python
# Base components to prepare
if torch.backends.mps.is_available():
    accelerator.native_amp = False
results = accelerator.prepare(unet, lr_scheduler, optimizer, *train_dataloaders)
unet = results[0]
if torch.backends.mps.is_available():
    unet.set_attention_slice()
```

however, the casting type issue is now there for bf16 and fp16, but bf16 no longer crashes. still digging into it.
The issues seem to be scoped to bfloat16, which is unfortunate because most hardware has moved to it over the past several years.
The second issue is from Jan 2022, so we may still be waiting for a while for a fix. It is clear that autocast was originally designed with fp16 and the fp16 grad scaler in mind and has issues with other types.
so disabling autocast on bf16 everywhere does seem like the way to go. honestly, the new optimiser removes any need to use fp16: pure bf16 is lighter weight and simpler code
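a minimal sketch of what "pure bf16" means here (assuming the optimiser tolerates bf16 parameters; the specific optimiser referenced above is not named): cast the weights once and skip both autocast and the fp16 GradScaler.

```python
import torch

# Pure bf16: parameters, activations and gradients all live in bfloat16,
# so there is no autocast context and no GradScaler involved.
model = torch.nn.Linear(16, 16).to(dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(4, 16, dtype=torch.bfloat16)
loss = model(x).float().mean()  # reduce in fp32 for a slightly more stable loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```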
* 7529 do not disable autocast for cuda devices
* Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue
* add autocast fix to other training examples
* disable native_amp for dreambooth (sdxl)
* disable native_amp for pix2pix (sdxl)
* remove tests from remaining files
* disable native_amp on huggingface accelerator for every training example that uses it
* convert more usages of autocast to nullcontext, make style fixes
* make style fixes
* style.
* Empty-Commit

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
What does this PR do?
Fixes #7529
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.