From addac817be9bc36f05fffc675201aa58628fee6a Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Tue, 2 Apr 2024 08:45:06 -0600 Subject: [PATCH] 7529 do not disable autocast for cuda devices (#7530) * 7529 do not disable autocast for cuda devices * Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue * add autocast fix to other training examples * disable native_amp for dreambooth (sdxl) * disable native_amp for pix2pix (sdxl) * remove tests from remaining files * disable native_amp on huggingface accelerator for every training example that uses it * convert more usages of autocast to nullcontext, make style fixes * make style fixes * style. * Empty-Commit --------- Co-authored-by: bghira Co-authored-by: Sayak Paul --- .../train_dreambooth_lora_sd15_advanced.py | 8 +++++- .../train_dreambooth_lora_sdxl_advanced.py | 13 +++++----- examples/amused/train_amused.py | 3 +++ .../train_lcm_distill_lora_sd_wds.py | 16 +++++++++--- .../train_lcm_distill_lora_sdxl.py | 8 +++++- .../train_lcm_distill_lora_sdxl_wds.py | 21 +++++++++++++--- .../train_lcm_distill_sd_wds.py | 22 +++++++++++++--- .../train_lcm_distill_sdxl_wds.py | 22 +++++++++++++--- examples/controlnet/train_controlnet.py | 4 +++ examples/controlnet/train_controlnet_sdxl.py | 17 +++++++------ .../train_custom_diffusion.py | 4 +++ examples/dreambooth/train_dreambooth.py | 4 +++ examples/dreambooth/train_dreambooth_lora.py | 4 +++ .../dreambooth/train_dreambooth_lora_sdxl.py | 21 ++++++++-------- .../train_instruct_pix2pix.py | 14 ++++++++--- .../train_instruct_pix2pix_sdxl.py | 25 +++++++++---------- .../train_text_to_image_decoder.py | 4 +++ .../train_text_to_image_lora_decoder.py | 5 ++++ .../train_text_to_image_lora_prior.py | 5 ++++ .../train_text_to_image_prior.py | 4 +++ .../controlnet/train_controlnet_webdataset.py | 4 +++ .../diffusion_dpo/train_diffusion_dpo.py | 4 +++ .../diffusion_dpo/train_diffusion_dpo_sdxl.py | 4 +++ .../train_diffusion_orpo_sdxl_lora.py | 4 +++ .../train_diffusion_orpo_sdxl_lora_wds.py | 4 +++ .../train_instruct_pix2pix_lora.py | 14 ++++++++--- .../textual_inversion_bf16.py | 4 +++ .../lora/train_text_to_image_lora.py | 5 ++++ .../train_multi_subject_dreambooth.py | 4 +++ .../textual_inversion.py | 4 +++ .../text_to_image/train_text_to_image.py | 4 +++ .../textual_inversion/textual_inversion.py | 4 +++ .../train_unconditional.py | 4 +++ .../t2i_adapter/train_t2i_adapter_sdxl.py | 4 +++ examples/text_to_image/train_text_to_image.py | 12 ++++++++- .../text_to_image/train_text_to_image_lora.py | 20 +++++++++++++-- .../train_text_to_image_lora_sdxl.py | 17 +++++-------- .../text_to_image/train_text_to_image_sdxl.py | 23 +++++++++-------- .../textual_inversion/textual_inversion.py | 12 ++++++++- .../textual_inversion_sdxl.py | 4 +++ .../train_text_to_image_lora_prior.py | 4 +++ .../train_text_to_image_prior.py | 4 +++ .../pipeline_stable_diffusion_depth2img.py | 11 ++++++-- .../pipeline_stable_diffusion_xl.py | 8 ------ .../pipeline_stable_diffusion_xl_img2img.py | 8 ------ .../pipeline_stable_diffusion_xl_inpaint.py | 8 ------ ...ne_stable_diffusion_xl_instruct_pix2pix.py | 8 ------ 47 files changed, 312 insertions(+), 118 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 5ce94680aeb24..6cdf2e7b21ab7 100644 --- 
a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -23,6 +23,7 @@ import re import shutil import warnings +from contextlib import nullcontext from pathlib import Path from typing import List, Optional @@ -1844,7 +1845,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - with torch.cuda.amp.autocast(): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index ff272e3b902e9..21a84b77245a3 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and import argparse -import contextlib import gc import hashlib import itertools @@ -26,6 +25,7 @@ import re import shutil import warnings +from contextlib import nullcontext from pathlib import Path from typing import List, Optional @@ -2192,13 +2192,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - inference_ctx = ( - contextlib.nullcontext() - if "playground" in args.pretrained_model_name_or_path - else torch.cuda.amp.autocast() - ) + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) - with inference_ctx: + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py index 33673b3f7eb77..3ec0503dfdfef 100644 --- a/examples/amused/train_amused.py +++ b/examples/amused/train_amused.py @@ -430,6 +430,9 @@ def main(args): log_with=args.report_to, project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False if accelerator.is_main_process: os.makedirs(args.output_dir, exist_ok=True) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 46470be865cd3..1e88cb67ee713 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -23,6 +23,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path from typing import List, Union @@ -238,6 +239,10 @@ def train_dataloader(self): def log_validation(vae, unet, args, accelerator, weight_dtype, step): logger.info("Running validation... 
") + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) unet = accelerator.unwrap_model(unet) pipeline = StableDiffusionPipeline.from_pretrained( @@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda", dtype=weight_dtype): + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ).input_ids.to(accelerator.device) uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + with autocast_ctx: # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): + with autocast_ctx: target_noise_pred = unet( x_prev.float(), timesteps, diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index a4052324c128f..9405c238f9375 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -22,6 +22,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path import accelerate @@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda", dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index fc4da48fbc4cf..08d6b23d6deb4 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -24,6 +24,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path from typing import List, Union @@ -256,6 +257,10 @@ def train_dataloader(self): def log_validation(vae, unet, args, accelerator, weight_dtype, step): logger.info("Running validation... 
") + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) unet = accelerator.unwrap_model(unet) pipeline = StableDiffusionXLPipeline.from_pretrained( @@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda", dtype=weight_dtype): + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1353,7 +1358,12 @@ def compute_embeddings( # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1416,7 +1426,12 @@ def compute_embeddings( # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): - with torch.autocast("cuda", enabled=True, dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: target_noise_pred = unet( x_prev.float(), timesteps, diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 8908593b16d31..d873cb8deb58a 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -23,6 +23,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path from typing import List, Union @@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 9. 
Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: target_noise_pred = target_unet( x_prev.float(), timesteps, diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 74d1c007f7f31..862777411ccc8 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -24,6 +24,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path from typing import List, Union @@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1355,7 +1361,12 @@ def compute_embeddings( # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1417,7 +1428,12 @@ def compute_embeddings( # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: target_noise_pred = target_unet( x_prev.float(), timesteps, diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index a56e92de26610..3daca0e3f56bd 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -752,6 +752,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. 
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index b60280523589c..62192521a3234 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and import argparse -import contextlib import functools import gc import logging @@ -22,6 +21,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path import accelerate @@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, ) image_logs = [] - inference_ctx = ( - contextlib.nullcontext() - if (is_final_validation or torch.backends.mps.is_available()) - else torch.autocast("cuda") - ) + if is_final_validation or torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") @@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, images = [] for _ in range(args.num_validation_images): - with inference_ctx: + with autocast_ctx: image = pipeline( prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator ).images[0] @@ -811,6 +810,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 35d3a59c72313..6858fed8b9946 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -676,6 +676,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c40758eebfe9c..a18c443e7d4de 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -821,6 +821,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9b43b30e0fe18..0d33a05589890 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -749,6 +749,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 1da83ff731adf..f3e347cd6ac99 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -23,6 +23,7 @@ import random import shutil import warnings +from contextlib import nullcontext from pathlib import Path import numpy as np @@ -207,18 +208,12 @@ def log_validation( generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 - enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): - enable_autocast = False - if "playground" in args.pretrained_model_name_or_path: - enable_autocast = False + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) - with torch.autocast( - accelerator.device.type, - enabled=enable_autocast, - ): + with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: @@ -992,6 +987,10 @@ def main(args): kwargs_handlers=[kwargs], ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index b40f22df59df7..9c7cf0847746f 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -21,6 +21,7 @@ import math import os import shutil +from contextlib import nullcontext from pathlib import Path import accelerate @@ -404,6 +405,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.report_to == "wandb": @@ -943,9 +948,12 @@ def collate_fn(examples): # run inference original_image = download_image(args.val_image_url) edited_images = [] - with torch.autocast( - str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" - ): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: for _ in range(args.num_validation_images): edited_images.append( pipeline( diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index aff279963a99d..1c0cdf04b2d2b 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -20,6 +20,7 @@ import os import shutil import warnings +from contextlib import nullcontext from pathlib import Path from urllib.parse import urlparse @@ -70,9 +71,7 @@ TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} -def log_validation( - pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True -): +def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." @@ -91,7 +90,12 @@ def log_validation( else Image.open(image_url_or_path).convert("RGB") )(args.val_image_url_or_path) - with torch.autocast(accelerator.device.type, enabled=enable_autocast): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: edited_images = [] # Run inference for val_img_idx in range(args.num_validation_images): @@ -507,6 +511,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) # Make one log on every process with the configuration for debugging. @@ -983,13 +991,6 @@ def collate_fn(examples): if accelerator.is_main_process: accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args)) - # Some configurations require autocast to be disabled. - enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): - enable_autocast = False - # Train! 
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1202,7 +1203,6 @@ def collate_fn(examples): generator, global_step, is_final_validation=False, - enable_autocast=enable_autocast, ) if args.use_ema: @@ -1252,7 +1252,6 @@ def collate_fn(examples): generator, global_step, is_final_validation=True, - enable_autocast=enable_autocast, ) accelerator.end_training() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index a2a13398124ab..78f9b7f18b87d 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -458,6 +458,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index d6fce4937413f..eb8ae8cca060e 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -343,6 +343,11 @@ def main(): log_with=args.report_to, project_config=accelerator_project_config, ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index b19af1f3e341f..e169cf92beb92 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -356,6 +356,11 @@ def main(): log_with=args.report_to, project_config=accelerator_project_config, ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index bbc0960e0f485..bd95aed2939ce 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -459,6 +459,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. 
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py index 2b397d27d6a2c..615eb834ac24d 100644 --- a/examples/research_projects/controlnet/train_controlnet_webdataset.py +++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py @@ -916,6 +916,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py index 4bb6b894476a3..3cec037e2544f 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -484,6 +484,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 24d51658e3b0f..0297a06f5b2ce 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -526,6 +526,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py index 2fcddad8b63fa..cdc096190f087 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py @@ -516,6 +516,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py index fc4e458f90eaf..cd1ef265d23eb 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py @@ -623,6 +623,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. 
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py index 436706c8512d2..997d448fa281b 100644 --- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py +++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py @@ -21,6 +21,7 @@ import math import os import shutil +from contextlib import nullcontext from pathlib import Path import accelerate @@ -410,6 +411,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.report_to == "wandb": @@ -967,9 +972,12 @@ def collate_fn(examples): # run inference original_image = download_image(args.val_image_url) edited_images = [] - with torch.autocast( - str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16" - ): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: for _ in range(args.num_validation_images): edited_images.append( pipeline( diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index 3cfd728214901..ea4a0d255b680 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -378,6 +378,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index 462c3bbd44cfc..cf00bf270057d 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -411,6 +411,11 @@ def main(): log_with=args.report_to, project_config=accelerator_project_config, ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 53b70abd0115c..0f507b26d6a81 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -698,6 +698,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py index 5fab1b6e9cbce..57ad77477b0d4 100644 --- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py +++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py @@ -566,6 +566,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 2045ef4197c18..ee61f033d34db 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -439,6 +439,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 5d774d591d9a5..e10564fa59efb 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -581,6 +581,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index cba2c5117bfee..9a00f7cc4a9ad 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -295,6 +295,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.logger == "tensorboard": if not is_tensorboard_available(): raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.") diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index 973915c96da4c..50735ef044a6f 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -799,6 +799,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 4138d1b46329f..84f4c6514cfd6 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -20,6 +20,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path import accelerate @@ -164,7 +165,12 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight images = [] for i in range(len(args.validation_prompts)): - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] images.append(image) @@ -523,6 +529,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index d8204360a0e7b..76bece8df056c 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -21,6 +21,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path import datasets @@ -426,6 +427,11 @@ def main(): log_with=args.report_to, project_config=accelerator_project_config, ) + + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") @@ -928,7 +934,12 @@ def collate_fn(examples): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - with torch.cuda.amp.autocast(): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: for _ in range(args.num_validation_images): images.append( pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] @@ -998,7 +1009,12 @@ def collate_fn(examples): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - with torch.cuda.amp.autocast(): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: for _ in range(args.num_validation_images): images.append( pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 26e05f66358b0..b592b3f6811be 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -21,6 +21,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path import datasets @@ -1014,13 +1015,6 @@ def collate_fn(examples): if accelerator.is_main_process: accelerator.init_trackers("text2image-fine-tune", config=vars(args)) - # Some configurations require autocast to be disabled. - enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): - enable_autocast = False - # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1246,11 +1240,12 @@ def compute_time_ids(original_size, crops_coords_top_left): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) - with torch.autocast( - accelerator.device.type, - enabled=enable_autocast, - ): + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index c141f5bdd7063..a341db3aa3d93 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -23,6 +23,7 @@ import os import random import shutil +from contextlib import nullcontext from pathlib import Path import accelerate @@ -603,6 +604,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") @@ -986,12 +991,10 @@ def unwrap_model(model): model = model._orig_mod if is_compiled_module(model) else model return model - # Some configurations require autocast to be disabled. - enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): - enable_autocast = False + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1226,10 +1229,7 @@ def compute_time_ids(original_size, crops_coords_top_left): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - with torch.autocast( - accelerator.device.type, - enabled=enable_autocast, - ): + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] for _ in range(args.num_validation_images) @@ -1284,7 +1284,8 @@ def compute_time_ids(original_size, crops_coords_top_left): if args.validation_prompt and args.num_validation_images > 0: pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - with torch.autocast(accelerator.device.type, enabled=enable_autocast): + + with autocast_ctx: images = [ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] for _ in range(args.num_validation_images) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 0f4bb7604f3c4..4922789862b54 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -20,6 +20,7 @@ import random import shutil import warnings +from contextlib import nullcontext from pathlib import Path import numpy as np @@ -143,7 +144,12 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] images.append(image) @@ -600,6 +606,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 460acf9f80096..c24a4c4f4855e 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -605,6 +605,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index e7d5898e11185..76eaf6423960b 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -460,6 +460,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 7aaebed3b0852..49cc5d26072d1 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -458,6 +458,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 72b438cd33257..1f822971568fe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -548,8 +548,15 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui pixel_values = pixel_values.to(device=device) # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16. # So we use `torch.autocast` here for half precision inference. - context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext() - with context_manger: + if torch.backends.mps.is_available(): + autocast_ctx = contextlib.nullcontext() + logger.warning( + "The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS." 
+ ) + else: + autocast_ctx = torch.autocast(device.type, dtype=dtype) + + with autocast_ctx: depth_map = self.depth_estimator(pixel_values).predicted_depth else: depth_map = depth_map.to(device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 8d1646e4d8874..efad9cf6cc1bd 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1199,10 +1199,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) if callback_on_step_end is not None: callback_kwargs = {} @@ -1241,10 +1237,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index af9da5073e067..9f6227bc914a2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1376,10 +1376,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) if callback_on_step_end is not None: callback_kwargs = {} @@ -1418,10 +1414,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." 
- ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index c9a72ccda9855..378f53ab08444 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1726,10 +1726,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) if num_channels_unet == 4: init_latents_proper = image_latents @@ -1785,10 +1781,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 9aedb8667587d..31dc5acc89954 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -924,10 +924,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -950,10 +946,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None
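
Note on the pattern: every hunk above makes the same two-part change, sketched below as a standalone snippet. This is a minimal illustration, not code from the diff — the helper name `autocast_ctx_for` is hypothetical (the training examples inline the equivalent if/else at each call site), and it uses only `torch.autocast` and `contextlib.nullcontext`, exactly as the hunks do.

    from contextlib import nullcontext
    from typing import Optional

    import torch


    def autocast_ctx_for(device_type: str, dtype: Optional[torch.dtype] = None):
        # Part 1 of the pattern: never hard-code torch.autocast("cuda"),
        # which breaks on CPU/MPS runners; derive the context from the
        # device actually in use, and fall back to a no-op context on MPS,
        # where autocast is not supported.
        if torch.backends.mps.is_available():
            return nullcontext()
        if dtype is not None:
            return torch.autocast(device_type, dtype=dtype)
        return torch.autocast(device_type)


    if __name__ == "__main__":
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        model = torch.nn.Linear(8, 8).to(device_type)
        x = torch.randn(2, 8, device=device_type)
        with autocast_ctx_for(device_type):
            # Reduced precision under autocast; float32 when the context is a no-op.
            print(model(x).dtype)

Part 2 of the pattern is the repeated `accelerator.native_amp = False` guard: it keeps Accelerate from wrapping training steps in its own autocast on MPS, which is why the guard appears immediately after every `Accelerator(...)` construction touched by this patch.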