
apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) #7447

Merged: 6 commits into huggingface:main from the bugfix/training-apple-mps branch on Mar 28, 2024

Conversation

@bghira (Contributor) commented Mar 23, 2024

What does this PR do?

Adds MPS bugfixes to the SDXL example training scripts.

I haven't fully completed this yet; I'm looking for comments on the approach, or on whether it should be done at all.

On an M3 Max, I get reasonable speeds for training.

@sayakpaul (Member) left a comment

The changes are clean and minimal. I absolutely don't have any problems whatsoever in supporting this.

@sayakpaul sayakpaul requested a review from pcuenca March 25, 2024 11:00
@sayakpaul (Member) commented

@pcuenca could you give this a look as well?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bghira force-pushed the bugfix/training-apple-mps branch 2 times, most recently from a1a777f to 2ed61d5, on March 25, 2024 at 13:16
@bghira (Contributor, Author) commented Mar 25, 2024

@sayakpaul I've updated the rest of the SDXL training scripts:

  • we override the --mixed_precision value before creating the Accelerator (this and the next point are sketched at the end of this comment)
  • we disable autocast when fp16 or bf16 is in use, or when MPS is enabled; this brings the scripts in line with each other, as some already did this and others did not
  • the scripts were sometimes slightly different: some use log_validations as a method and others generate the validation images inline

I wanted to make them all more consistent, but I also wanted to limit the scope of the changes.
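
As a rough illustration of the first two points, a minimal sketch follows. It is not the exact code in this PR; the --mixed_precision flag mirrors the example scripts, and the variable names are assumptions.

    import argparse

    import torch
    from accelerate import Accelerator

    # Sketch only: parse the same flag the example training scripts expose.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mixed_precision", default="no", choices=["no", "fp16", "bf16"])
    args = parser.parse_args()

    # Override the requested precision before the Accelerator is created:
    # bf16 is not usable on Apple MPS, so fall back to full precision there.
    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        args.mixed_precision = "no"

    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    # Disable autocast when fp16/bf16 is active or when MPS is the backend,
    # matching the behaviour some of the scripts already had.
    enable_autocast = not (
        accelerator.mixed_precision in ("fp16", "bf16")
        or torch.backends.mps.is_available()
    )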

@sayakpaul (Member) left a comment

Left some comments, thanks for this.

@bghira force-pushed the bugfix/training-apple-mps branch 2 times, most recently from 1c7f7a9 to 709f8ac, on March 25, 2024 at 15:04
@bghira (Contributor, Author) commented Mar 25, 2024

@sayakpaul okay:

  • we'll raise an error when the user requests bf16 on MPS (a rough sketch of this guard follows below)
  • I've left it as torch.autocast, since relying on the CUDA-specific codepath feels incorrect while broadening support
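
A rough sketch of such a guard (hedged; the exact wording and placement in the scripts may differ):

    import torch

    def check_mixed_precision(mixed_precision: str) -> None:
        # Hypothetical guard: reject bf16 mixed precision on Apple MPS.
        if mixed_precision == "bf16" and torch.backends.mps.is_available():
            raise ValueError(
                "bf16 mixed precision is not supported on Apple MPS. "
                "Use fp16 or full precision instead."
            )

    check_mixed_precision("fp16")  # passes; "bf16" would raise on an MPS machine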

@bghira changed the title from "apple mps: training support for SDXL LoRA" to "apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I)" on Mar 25, 2024
@sayakpaul (Member) commented

The changes are nice and clean. I am okay with them. Over to @pcuenca to give this a look!

@pcuenca (Member) left a comment

Looking good so far!

The following inline review thread is attached to this hunk of the diff:

    with inference_ctx:
    with torch.autocast(
        str(accelerator.device).replace(":0", ""),
        enabled=enable_autocast,
Member

Just curious: why do we disable autocast in this script but use contextlib.nullcontext in others?
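
For reference, the two patterns being compared are roughly equivalent ways of skipping autocast; the names below are illustrative rather than taken from any one script:

    import contextlib

    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # Pattern A: some scripts build a no-op context manager up front.
    inference_ctx = contextlib.nullcontext()

    # Pattern B: this script keeps the autocast call and turns it off via `enabled`.
    inference_ctx = torch.autocast(device_type, enabled=False)

    with inference_ctx:
        pass  # validation / inference code would run here either way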

Member

This is on me.

Contributor Author

If you want to open this can of worms further, we also need to use

            if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

for DeepSpeed support.

It's starting to feel like we need an autocast manager wrapper in the training utils.
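
Such a wrapper did not exist in the training utilities at the time of this discussion; a sketch of what it might look like, reusing the device-string pattern from the diff above and leaving DeepSpeed out for brevity:

    import contextlib

    import torch

    def autocast_ctx(accelerator, enabled: bool):
        """Hypothetical helper: return an autocast context for the current
        backend, or a no-op context when autocast should be skipped."""
        if not enabled:
            return contextlib.nullcontext()
        # Strip a trailing ":0" so e.g. "cuda:0" becomes "cuda".
        device_type = str(accelerator.device).replace(":0", "")
        return torch.autocast(device_type)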

@bghira force-pushed the bugfix/training-apple-mps branch 3 times, most recently from c9d483c to c9f3815, on March 25, 2024 at 16:32
@bghira force-pushed the bugfix/training-apple-mps branch from c9f3815 to 52e46f1 on March 25, 2024 at 16:33
@bghira (Contributor, Author) commented Mar 25, 2024

Thanks for the help and feedback; this feels ready now.

@sayakpaul (Member) commented

I am good to proceed with the merge. @pcuenca could you also give this a final check?

@bghira (Contributor, Author) commented Mar 25, 2024

I tested Kohya's sd-scripts on Apple MPS, and it only works with fp32 because the same dtype bugs this PR addresses remain unfixed there. So these scripts end up being a fairly solid reference implementation if they wish to fix the same issues.

@sayakpaul (Member) commented

Slightly confused here. Are you saying you are not getting expected results on MPS with these changes, while the Kohya scripts yield expected results despite their potentially unfixed issues?

@bghira (Contributor, Author) commented Mar 25, 2024

No. fp16 and fp32 training behave the same here, but MPS outputs differ from CPU outputs.

In general there are subtle correctness issues in PyTorch's MPS support that could add up to a different trained result. They aren't worse, just different.

@aradley commented Mar 27, 2024

I'm sorry, this is all beyond my expertise, so apologies if this is an unnecessary comment, but I feel I must comment because you mentioned my bug report (#122233).

Above, bghira says, "In general there are subtle correctness issues in PyTorch's MPS support that could add up to a different trained result. They aren't worse, just different." If your changes do not fix the issues in my bug report, then I profoundly disagree with this statement.

The bug that I have reported shows a clear example where MPS and CPU diverge in indexing with the torch.where function, and the indexing error completely ruins all downstream calculations. If these "subtle correctness issues in pytorch's mps" are leaking into fundamental functions such as torch.where, this appears to be quite a critical bug that cannot be waved away as "subtle differences" leading to "different but not worse" results.
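
For context, a generic way to compare torch.where across backends looks like the following; this is only an illustrative check, not the reproduction from #122233:

    import torch

    if torch.backends.mps.is_available():
        x = torch.arange(10, dtype=torch.float32)
        cpu_idx = torch.where(x > 4.5)[0]
        mps_idx = torch.where(x.to("mps") > 4.5)[0].cpu()
        # Index selection is exact, not a precision question, so any mismatch
        # between backends is a genuine bug rather than expected drift.
        print(torch.equal(cpu_idx, mps_idx))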

@mallman commented Mar 27, 2024

I reported one of the PyTorch MPS correctness bugs. That was well over a year ago, but IIRC the problem is that the MPS implementations of certain transcendental functions, like exp and log, are not as precise as the CPU implementations.
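
A hedged illustration of that kind of drift (not the original bug report's reproduction) is to run the same transcendental function on both backends and compare:

    import torch

    if torch.backends.mps.is_available():
        x = torch.linspace(0.1, 20.0, steps=1000)
        cpu_out = torch.exp(x)
        mps_out = torch.exp(x.to("mps")).cpu()
        # Any nonzero difference here is precision drift between the backends.
        print((cpu_out - mps_out).abs().max())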

From my cursory examination, it appears this patch addresses these bugs by disabling MPS in certain areas. In that sense, this looks less like a patch to support MPS and more like a patch to disable it. But then this statement is confusing:

In general there are subtle correctness issues in PyTorch's MPS support that could add up to a different trained result. They aren't worse, just different.

It's true that MPS computations can lead to different trained results. Does this patch address that issue? That is, if I pass in MPS-stored tensors should I expect the same results as CPU-stored tensors?

I'd caution that statements like

they aren't worse, just different

are difficult to quantify and sound purely subjective. The correctness issues are bugs. They produce incorrect results. Recalling the principle of "garbage in garbage out", it's hard to see how using a runtime with flawed calculations can produce a scientifically/mathematically sound result.

But again, maybe I misunderstand this patch. Does it address the correctness issues? How does it do that? By disabling MPS?

@sayakpaul (Member) commented

Sorry it slipped through the cracks. Gonna merge after CI is green.

@sayakpaul merged commit d78acde into huggingface:main on Mar 28, 2024
8 checks passed
@bghira

This comment was marked as off-topic.

@bghira deleted the bugfix/training-apple-mps branch on March 28, 2024 at 13:16
@mallman commented Mar 28, 2024

> > I reported one of the PyTorch MPS correctness bugs. That was well over a year ago, but IIRC the problem is that the MPS implementations of certain transcendental functions, like exp and log, are not as precise as the CPU implementations.
> > From my cursory examination, it appears this patch addresses these bugs by disabling MPS in certain areas. In that sense, this looks less like a patch to support MPS and more like a patch to disable it.
>
> Sorry, I might have missed something: how is this patch disabling MPS? It is ensuring all dtypes are the same, as MPS requires.

I was looking at things like the code that throws errors for bf16 when MPS is enabled. You're right, this doesn't disable MPS; it throws errors where it is unsupported.

At this point, my remaining concern is the difference in output from CPU versus MPS. Should end users know to expect this? And what should they know? Should the message be that "MPS output is different but valid", or something else like "MPS output is different and invalid"? What are the consequences of the imprecision of the "fast math" functions used by MPS here?

@aradley commented Mar 28, 2024

I'm not sure what you mean by "the downstream math becoming meaningless could be simply a human construct". The issue I've highlighted is not one of deterministic vs. non-deterministic behaviour; rather, I have presented a clear case where MPS is wrong. This is not a human construct: the MPS-based torch.where function literally gives an incorrect output.

If your patch is not trying to rectify this issue, then fine, but you did specifically cite my bug report (#122233).

However, what I would say is that if something as simple as torch.where on MPS is giving incorrect results, I don't see how you can confidently claim that the results on MPS are "different but not wrong" without some comprehensive evidence to back up that claim. On what basis do you claim that the differences observed are down to small errors, as opposed to a critical error like the torch.where example I presented?

@bghira (Contributor, Author) commented Mar 28, 2024

I've removed the links to the PyTorch bugs to reduce the number of comments about upstream issues we're not able to resolve. This isn't the place to discuss them, as frustrating as it may be that upstream does not address them.

noskill pushed a commit to noskill/diffusers that referenced this pull request on Apr 5, 2024:
apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) (huggingface#7447)

* apple mps: training support for SDXL LoRA

* sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet on apple mps

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
sayakpaul added a commit that referenced this pull request on Dec 23, 2024:
apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) (#7447)

* apple mps: training support for SDXL LoRA

* sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet on apple mps

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>