[mps] training / inference dtype issues #7563

Open
bghira opened this issue Apr 2, 2024 · 33 comments
Labels
stale Issues that haven't received updates

Comments


bghira commented Apr 2, 2024

when training with Diffusers on MPS without attention slicing, we see:

/AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:788: failed assertion `[MPSNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'

but with attention slicing, this error disappears.

    # Base components to prepare
    if torch.backends.mps.is_available():
        # autocast is not supported on MPS, so disable Accelerate's native AMP
        accelerator.native_amp = False
    results = accelerator.prepare(unet, lr_scheduler, optimizer, *train_dataloaders)
    unet = results[0]
    if torch.backends.mps.is_available():
        # slice attention so no single MPS allocation exceeds 2**32 bytes
        unet.set_attention_slice("auto")
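
for inference-only pipelines, the equivalent knob is the pipeline-level call (model id below is just an example); sliced attention keeps each MPS allocation under the 2**32-byte NDArray limit:

    import torch
    from diffusers import StableDiffusionPipeline

    # sliced attention keeps each individual MPS tensor allocation small enough
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float32
    ).to("mps")
    pipe.enable_attention_slicing()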

however, once this issue is resolved, there is a new problem:

    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

this is caused by the following logic:

    # Check that all trainable models are in full precision
    low_precision_error_string = (
        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training. copy of the weights should still be float32."
    )

    if accelerator.unwrap_model(unet).dtype != torch.float32:
        raise ValueError(
            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )

    if (
        args.train_text_encoder
        and accelerator.unwrap_model(text_encoder).dtype != torch.float32
    ):
        raise ValueError(
            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
            f" {low_precision_error_string}"
        )

which is done because the stock AdamW optimiser doesn't cope with bf16 weights. however, thanks to @AmericanPresidentJimmyCarter, we are now able to use the adamw_bfloat16 package, which provides an optimizer that can handle pure-bf16 training.
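
for context, here is a minimal, self-contained illustration (plain PyTorch, not the package's actual implementation) of why stock AdamW stalls with pure-bf16 weights: near 1.0, bf16 cannot represent increments much smaller than its epsilon, so small optimizer steps round away entirely, which is what a bf16-aware optimizer has to work around (e.g. via compensated or stochastic rounding):

    import torch

    w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
    step = torch.tensor(1e-4, dtype=torch.bfloat16)

    print(torch.finfo(torch.bfloat16).eps)   # 0.0078125 -- bf16 keeps ~8 mantissa bits
    print((w_bf16 + step).item())            # 1.0 -- a 1e-4 update rounds away

    # repeated tiny updates therefore never move the weight at all in pure bf16
    for _ in range(1000):
        w_bf16 = w_bf16 + step
    print(w_bf16.item())                     # still 1.0

    # the same loop in fp32 accumulates as expected
    w_fp32 = torch.tensor(1.0)
    for _ in range(1000):
        w_fp32 = w_fp32 + step.float()
    print(w_fp32.item())                     # ~1.1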

once that is in place, we can comment out the low-precision check above, and:

    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    ).to(weight_dtype)

load the unet directly in the target precision.

Originally posted by @bghira in #7530 (comment)


bghira commented Apr 2, 2024

@sayakpaul @pcuenca @DN6


bghira commented Apr 2, 2024

if we can rely on the bf16-fixed AdamW optimiser, we no longer need to keep fp32 copies of the weights, which saves memory and storage. overall, training becomes more efficient and reliable. thoughts?

@AmericanPresidentJimmyCarter (Contributor)

I can try to upload a pypi package of the fixed optimiser later today. It is located here: https://github.com/AmericanPresidentJimmyCarter/test-torch-bfloat16-vit-training/blob/main/adam_bfloat16/__init__.py#L22

@sayakpaul (Member)

    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    ).to(weight_dtype)

Would it apply to float16 too?

@sayakpaul (Member)

Cc: @patil-suraj too.


bghira commented Apr 2, 2024

    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    ).to(weight_dtype)

Would it apply to float16 too?

i can try, but since the SD 2.1 model needs the up-block attention upcast to at least bf16, i didn't think it was useful to test. SD 2.1 in particular produces black outputs when xFormers isn't in use, which @patrickvonplaten investigated here
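
for reference, the SD 2.1 checkpoints set `upcast_attention` in the UNet config precisely so that the attention math runs in fp32 even when the weights are half precision; a rough sketch of forcing the flag explicitly (the kwarg is passed through as a config override, and the dtype here is just an example):

    import torch
    from diffusers import UNet2DConditionModel

    # upcast_attention keeps the attention computation in fp32, which is the
    # mitigation for SD 2.1's black-image problem in half precision
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        subfolder="unet",
        upcast_attention=True,
        torch_dtype=torch.bfloat16,
    )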


bghira commented Apr 2, 2024

the other crappy thing is that, without autocast, we have to use the same dtype for the VAE and the unet. this is probably fine because the SD 1.x/2.x VAE handles fp16 like a champ, but the unet requires at least bf16. manually changing the unet dtype seems to break the mitigations put in place for the up-block attention.
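
a minimal sketch of the usual workaround (model id and input sizes are just examples): keep the VAE in fp32 and cast its latents to the UNet's dtype by hand, so the conv input and bias dtypes agree without autocast:

    import torch
    from diffusers import AutoencoderKL, UNet2DConditionModel

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-2-1", subfolder="vae"
    ).to(device, dtype=torch.float32)          # fp32 VAE for numerical safety
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1", subfolder="unet"
    ).to(device, dtype=torch.bfloat16)

    pixel_values = torch.rand(1, 3, 256, 256, device=device) * 2 - 1  # dummy batch

    with torch.no_grad():
        latents = vae.encode(pixel_values.to(torch.float32)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    # without autocast the cast must be explicit, otherwise conv2d raises the
    # "Input type ... and bias type ... should be the same" error from above
    latents = latents.to(unet.dtype)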

@AmericanPresidentJimmyCarter (Contributor)

fp16 is usually degraded, and when training with it you need tricks like gradient scaling for it to work at all. For training scripts, I am not sure we should recommend anything other than fp32 weights with fp16 autocast.

The situation with bfloat16 is different: with the correct optimiser it seems to perform nearly as well as float32. The downside is that older devices may not support it.
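
For completeness, the fp32-weights-plus-fp16-autocast recipe mentioned above relies on gradient scaling; a rough sketch (`model`, `compute_loss`, `optimizer`, and `dataloader` are placeholders, and GradScaler is a CUDA facility, which is part of why this path is awkward on MPS):

    import torch

    scaler = torch.cuda.amp.GradScaler()               # loss scaling avoids fp16 underflow
    for batch in dataloader:
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = compute_loss(model, batch)          # forward in fp16, weights stay fp32
        scaler.scale(loss).backward()                  # backward on the scaled loss
        scaler.step(optimizer)                         # unscales grads, skips step on inf/NaN
        scaler.update()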


bghira commented Apr 2, 2024

older CUDA devices emulate bf16, e.g. the T4 on Colab.

Apple MPS supports it with PyTorch 2.3.

AMD ROCm is, i think, the outlier, but it also seems to have emulation. pinging @Beinsezii for clarification


Beinsezii commented Apr 2, 2024

bf16 runs on ROCm with a 5-10% perf hit over fp16. Not sure of the implementation details but it's worked for everything I've needed it to.

I only have an RDNA 3 card, so it may not work as cleanly on the older GPUs without WMMAs


bghira commented Apr 2, 2024

RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.

can't do fp16 on SD 2.1: the MSE loss backward goes straight to NaN.
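
errors of this shape come from autograd's anomaly detection; a quick way to trace which op first emits NaN (the tensors here are just placeholders for the model prediction and target):

    import torch

    # the "returned nan values in its 0th output" RuntimeError is raised by
    # anomaly detection, which also prints the forward-pass traceback of the op
    torch.autograd.set_detect_anomaly(True)

    model_pred = torch.randn(2, 4, 64, 64, requires_grad=True)
    target = torch.randn(2, 4, 64, 64)
    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float())
    loss.backward()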


bghira commented Apr 3, 2024

something is still not quite right on mps; maybe the lack of 8-bit optimisers hurts more than i'd think, haha.

we see sampling-speed improvements up to bsz=8, and then it hits swap space on a 128 GB machine

image


bghira commented Apr 3, 2024

image



bghira commented Apr 4, 2024

tested the above training implementation on (so far) 300 steps of ptx0/photo-concept-bucket at a decent learning rate and a batch size of 4 on an Apple M3 Max

it's definitely learning.
image

compared to what the unmodified optimizer does:
image

i mean, it added human eyes to the tiger, which might not be unexpected when zero images of tigers exist in the dataset, but it's certainly not what happens now with the fixed bf16 AdamW implementation.

edit: importantly, the fixed optimizer has not (so far) run into NaNs, unlike the currently available selection of optimizers


bghira commented Apr 5, 2024

unfortunately i hit a NaN at the 628th step of training, approximately the same place as before

@sayakpaul (Member)

Ufff. Why that damn step? 😭


bghira commented Apr 5, 2024

looks like it could be pytorch/pytorch#118115, as both optimizers that fail in this way use addcdiv

@AmericanPresidentJimmyCarter (Contributor)

looks like it could be pytorch/pytorch#118115, as both optimizers that fail in this way use addcdiv

I will look into a fix for this too; if it's just a .contiguous() issue it should be easy.
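
Purely as an illustration of the `.contiguous` idea (this assumes the trigger in pytorch/pytorch#118115 is non-contiguous inputs to the in-place `addcdiv_`, which is not confirmed here):

    import torch

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    p = torch.randn(4, 4, device=device)
    exp_avg = torch.randn(4, 4, device=device).t()   # a non-contiguous view
    denom = torch.rand(4, 4, device=device).add_(1e-8)

    # forcing contiguous inputs before the fused in-place op is the workaround idea
    p.addcdiv_(exp_avg.contiguous(), denom.contiguous(), value=-1e-3)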


bghira commented Apr 5, 2024

it crashed after 628 steps, and then on resume it crashed again after about 300 steps, on the 901st.

it also seems to get a lot slower than it should sometimes, but it was mentioned that this could be heat-related. i doubt it's purely a thermal problem, but it seems odd


bghira commented Apr 5, 2024

@sayakpaul you know what, it ended up being a cached latent with NaN values. i ran the SDXL VAE in fp16 mode since i was on pytorch 2.2 a few days ago, which didn't support bf16. it worked pretty well, but i guess one or two of the files had corrupt outputs. so there's no inherent issue with the backward pass on torch mps causing NaN in 2.3. the one bug i was observing in pytorch with a small reproducer has now been backported to 2.3 as well, as of yesterday morning.

torch.compile is also more stable on 2.4, so aot_eager makes a few performance annoyances go away.

so mps support is shaping up to a fair level at this juncture, and i'll be able to work on it this weekend
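
a simple sweep like the following (the cache path and file format are assumptions about the trainer's layout) would have caught the corrupt latents before training:

    import glob
    import torch

    # scan the latent cache for NaN/Inf before training starts
    for path in glob.glob("cache/vae-latents/*.pt"):
        latent = torch.load(path, map_location="cpu")
        if torch.isnan(latent).any() or torch.isinf(latent).any():
            print(f"corrupt cached latent, re-encode it: {path}")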

@sayakpaul (Member)

you know what, it ended up being a cached latent with NaN values.

You were caching latents? Even with that how were there NaNs? VAE issue or something else?


bghira commented Apr 6, 2024

using madebyollin's sdxl-vae-fp16-fix model, it occasionally NaNs, but not often enough to pin down the issue right away

@sayakpaul (Member)

So many sudden glitches.


bghira commented Apr 6, 2024

on a new platform, the workarounds that are required on the other platforms might not have been added yet.

e.g. cuda handles type casting automatically, but mps requires strict dtypes; the cuda workarounds for issues people saw more than a year ago are now forgotten. we essentially have to rediscover how cuda needed to work and apply a lot of the same changes for MPS.

i am removing fp16 training from my own trainer. fp32 is still there, but i don't know why anyone would use it.

pure bf16 with the fixed optimizer is the new hero here

@sayakpaul (Member)

Thanks for investigating. I guess it's about time fp16 support gets fixed. If people are aware of these findings I think it should still be okay. But fp16 inference is one I don't think we can throw out yet.


bghira commented Apr 6, 2024

fp16 inference was thrown out long ago:

  • sdxl's vae doesn't work with it
  • sd 2.1's unet doesn't work with it

@sayakpaul (Member)

I will respectfully disagree. Not all cards support bfloat16 equally well.


bghira commented Apr 6, 2024

the ones that don't are going to be upcasting about half of the information to fp32 anyway; e.g. the GTX 1060 also doesn't support fp16 (NVIDIA used to lock that behind a Quadro card purchase).

in any case, i don't think the cards that fail to support bf16 are useful for training.

the Google Colab T4 emulates bf16 behind the scenes. which others are there?
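
a quick capability check, for what it's worth:

    import torch

    if torch.cuda.is_available():
        # reports whether the current CUDA device can run bf16 (natively or otherwise)
        print("bf16 supported:", torch.cuda.is_bf16_supported())
    elif torch.backends.mps.is_available():
        print("MPS available; bf16 needs PyTorch >= 2.3")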

@AmericanPresidentJimmyCarter (Contributor)

@bghira now that you fixed the latents issue, is bf16 training going well with my optim?


bghira commented Apr 6, 2024

more than 790 steps without issue


bghira commented Apr 27, 2024

so i've been doing more extensive experimentation with mps, and on every version of pytorch the experience is currently lacking:

  • 2.2 has no bf16 support and has other issues on macOS 14.x
  • 2.3 has bf16 support, plus one backported fix for the torch.where() bug that produces NaNs
  • 2.4 has bf16 and all the latest fixes

but despite that, something is still ruining the results of calculations on MPS.

you can see this by trying to run inference, especially for Stable Diffusion v1.5 or DeepFloyd, on the MPS device (a minimal repro sketch follows the list):

  • CPU works and produces viable results
  • MPS only works in float32 (it crashes in float16 or bfloat16) and the results look terrible
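
a minimal comparison along those lines (model id, prompt, and step count are just examples):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32
    )
    prompt = "a photo of a tiger in a jungle"

    cpu_image = pipe.to("cpu")(prompt, num_inference_steps=25).images[0]
    mps_image = pipe.to("mps")(prompt, num_inference_steps=25).images[0]

    cpu_image.save("sd15_cpu.png")   # viable result
    mps_image.save("sd15_mps.png")   # degraded result on MPS, per the above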

i was able to train a v-prediction/min-SNR Dreambooth for Terminus on the MPS device in pure bf16, but it never looked quite accurate. still, the results weren't completely awful:

image

github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the stale label on Sep 14, 2024