apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) #7447
Conversation
The changes are clean and minimal. I absolutely don't have any problems whatsoever in supporting this.
@pcuenca could you give this a look as well?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
(force-pushed from a1a777f to 2ed61d5)
@sayakpaul i've updated the rest of the SDXL training scripts:
I wanted to make them all more consistent, but also wanted to limit the scope of the changes.
Left some comments, thanks for this.
(force-pushed from 1c7f7a9 to 709f8ac)
@sayakpaul okay,
The changes are nice and clean. I am okay with them. @pcuenca to give this a look!
Looking good so far!
-    with inference_ctx:
+    with torch.autocast(
+        str(accelerator.device).replace(":0", ""),
+        enabled=enable_autocast,
Just curious: why do we disable autocast in this script, but use contextlib.nullcontext in others?
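For context, the two patterns being compared look roughly like this. This is a simplified sketch, and device_type / enable_autocast are stand-ins for the values each script derives, not names taken from this PR:

import contextlib

import torch

device_type = "cpu"     # stand-in; the scripts derive this from accelerator.device ("cuda", "mps", "cpu", ...)
enable_autocast = True  # stand-in; the scripts compute this from the device and dtype

# Pattern in this script: always enter torch.autocast and toggle it via enabled=.
with torch.autocast(device_type, enabled=enable_autocast):
    pass  # validation / inference code goes here

# Pattern in some other scripts: pick the context up front, with a no-op fallback.
inference_ctx = torch.autocast(device_type) if enable_autocast else contextlib.nullcontext()
with inference_ctx:
    pass  # validation / inference code goes here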
This is on me.
if you want to open this can of worms further, we also need to use

if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

for deepspeed support.
it's starting to feel like we need an autocast manager wrapper in train utils.
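A minimal sketch of what such a wrapper could look like, assuming a hypothetical helper named get_autocast_context living in the training utilities; neither the name nor the location is part of this PR, and the DeepSpeed GatheredParameters case quoted above would still need separate handling:

import contextlib

import torch


def get_autocast_context(accelerator, enable_autocast: bool = True):
    # torch.autocast wants a bare device type, so strip a trailing ":0"
    # (e.g. "mps:0" -> "mps", "cuda:0" -> "cuda").
    device_type = str(accelerator.device).replace(":0", "")
    if not enable_autocast:
        return contextlib.nullcontext()
    return torch.autocast(device_type)

Each script could then write with get_autocast_context(accelerator, enable_autocast): instead of repeating the device-string and enabled/nullcontext logic.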
(force-pushed from c9d483c to c9f3815)
(force-pushed from c9f3815 to 52e46f1)
thanks for the help and feedback, this feels ready now.
I am good to proceed with the merge. @pcuenca could you also give this a final check?
i tested kohya's sd-scripts on apple mps and it only works with fp32, due to the same dtype bugs this PR addresses remaining unfixed there. so this script ends up being a fairly solid implementation to reference if they wish to fix the same issues.
Slightly confused here. Are you saying you are not getting expected results on MPS with these changes, while Kohya's scripts yield expected results despite potentially unfixed issues?
no - fp16 and fp32 training behave the same way, but MPS outputs differ from CPU outputs. in general there are subtle correctness issues in pytorch's mps support that could add up to a different trained result. they aren't worse - just different
I'm sorry, this is all beyond my expertise, so apologies if this is an unnecessary comment, but I feel I must comment because you mentioned my bug report (#122233). Above, bghira says, "in general there are subtle correctness issues in pytorch's mps support that could add up to a different trained result. they aren't worse - just different". If your changes do not fix the issues in my bug report, then I profoundly disagree with this statement. The bug I reported shows a clear example where MPS and CPU diverge: indexing with the torch.where function goes wrong, which completely ruins all downstream calculations. If these "subtle correctness issues in pytorch's mps" are leaking into fundamental functions such as torch.where, this appears to be quite a critical bug that cannot be waved away as "subtle differences" leading to "different but not worse" results.
I reported one of the PyTorch MPS correctness bugs. That was well over a year ago, but IIRC the problem was in the MPS implementation of certain transcendental functions. From my cursory examination, it appears this patch addresses these bugs by disabling MPS in certain areas. In that sense, this looks like less of a patch to support MPS and more like a patch to disable it. But then this statement is confusing: "in general there are subtle correctness issues in pytorch's mps support that could add up to a different trained result. they aren't worse - just different".
It's true that MPS computations can lead to different trained results. Does this patch address that issue? That is, if I pass in MPS-stored tensors, should I expect the same results as CPU-stored tensors? I'd caution that statements like "they aren't worse - just different" are difficult to quantify and sound purely subjective. The correctness issues are bugs. They produce incorrect results. Recalling the principle of "garbage in, garbage out", it's hard to see how using a runtime with flawed calculations can produce a scientifically/mathematically sound result. But again, maybe I misunderstand this patch. Does it address the correctness issues? How does it do that? By disabling MPS?
Sorry it slipped through the cracks. Gonna merge after CI is green.
I was looking at things like the code that throws errors for bf16 when MPS is enabled. You're right, this doesn't disable MPS; it throws errors where it is unsupported. At this point, my remaining concern is around the difference in output from CPU versus MPS. Should end users know to expect this? And what should they know? Should the message be that "MPS output is different but valid", or something else like "MPS output is different and invalid"? What are the consequences of the imprecision of the "fast math" functions in MPS used here?
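For readers following along, the kind of guard being referred to is roughly of this shape; this is a paraphrase rather than the exact code from the scripts, and mixed_precision stands in for the script's argument:

import torch

mixed_precision = "bf16"  # stand-in for the script's --mixed_precision argument

# Roughly the shape of the guard: bf16 is rejected up front on MPS instead of
# failing later or silently producing wrong results.
if torch.backends.mps.is_available() and mixed_precision == "bf16":
    raise ValueError("bf16 mixed precision is not supported on MPS; use fp16 or fp32 instead.")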
I'm not sure what you mean by "the downstream math becoming meaningless could be simply a human construct". The issue I've highlighted is not one of deterministic vs non-deterministic; rather, I have presented a clear issue where MPS is wrong. This is not a human construct: the MPS-based torch.where function literally gives an incorrect output. If your patch is not trying to rectify this issue, then fine, but you did specifically cite my bug report (#122233). However, what I would say is that if something as simple as torch.where on MPS is giving incorrect results, I don't see how you can confidently claim that the results on MPS are "different but not wrong" without comprehensive evidence to back up that claim. On what basis do you claim that the differences observed are down to small errors as opposed to a critical error, as in the torch.where example I present?
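Illustrative only: a parity check of the kind that surfaces such discrepancies. It makes no claim about the linked report; the shapes and values here are arbitrary, and whether any mismatch appears depends on the PyTorch version.

import torch

if torch.backends.mps.is_available():
    x_cpu = torch.randn(1024)
    out_cpu = torch.where(x_cpu > 0, x_cpu, torch.zeros_like(x_cpu))

    x_mps = x_cpu.to("mps")
    out_mps = torch.where(x_mps > 0, x_mps, torch.zeros_like(x_mps))

    # Raises if the MPS result diverges from the CPU result.
    torch.testing.assert_close(out_cpu, out_mps.cpu())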
i've removed the links to pytorch bugs to reduce the occurrence of comments about upstream issues we're not able to resolve. this isn't the place for discussing it, as frustrating as it may be that upstream does not address them. |
apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) (#7447)

* apple mps: training support for SDXL LoRA
* sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet on apple mps

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
What does this PR do?
Adds MPS bugfixes to the SDXL example training scripts.
I haven't really completed this yet; I'm looking for comments on the approach, or on whether it should be done at all.
On an M3 Max, I get reasonable speeds for training.