From 688051c4039163765e57cfbdbe1ec851582d6852 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 30 Mar 2024 10:50:55 -0600 Subject: [PATCH 01/11] 7529 do not disable autocast for cuda devices --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 6 +----- examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 4 +--- examples/text_to_image/train_text_to_image_lora_sdxl.py | 4 +--- examples/text_to_image/train_text_to_image_sdxl.py | 4 +--- 4 files changed, 4 insertions(+), 14 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 1da83ff731ad..14de66f9e8c8 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -208,11 +208,7 @@ def log_validation( # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): - enable_autocast = False - if "playground" in args.pretrained_model_name_or_path: + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: enable_autocast = False with torch.autocast( diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index aff279963a99..21726e8961e4 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -985,9 +985,7 @@ def collate_fn(examples): # Some configurations require autocast to be disabled. enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): + if torch.backends.mps.is_available(): enable_autocast = False # Train! diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index c9860b744c03..2ddd1824bd72 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -981,9 +981,7 @@ def collate_fn(examples): # Some configurations require autocast to be disabled. enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): + if torch.backends.mps.is_available(): enable_autocast = False # Train! diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index c141f5bdd706..851503f07429 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -988,9 +988,7 @@ def unwrap_model(model): # Some configurations require autocast to be disabled. enable_autocast = True - if torch.backends.mps.is_available() or ( - accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16" - ): + if torch.backends.mps.is_available(): enable_autocast = False # Train! 
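The patches that follow all converge on one pattern: rather than toggling torch.autocast(..., enabled=...), select a context manager up front and fall back to a no-op on platforms where autocast misbehaves. A minimal sketch of that pattern, assuming an Accelerator object from huggingface/accelerate; the helper name get_autocast_ctx is illustrative and does not appear anywhere in the series:

    from contextlib import nullcontext

    import torch

    def get_autocast_ctx(accelerator, pretrained_model_name_or_path=""):
        # MPS misbehaves under autocast, and these patches special-case
        # "playground" checkpoints the same way, so both get a no-op context;
        # every other device autocasts on its own device type.
        if torch.backends.mps.is_available() or "playground" in pretrained_model_name_or_path:
            return nullcontext()
        return torch.autocast(accelerator.device.type)

    # Usage inside a validation loop:
    # with get_autocast_ctx(accelerator, args.pretrained_model_name_or_path):
    #     images = [pipeline(**pipeline_args).images[0] for _ in range(n)]

Patches 04-07 pair this with disabling Accelerate's native AMP on MPS immediately after the Accelerator is constructed:

    if torch.backends.mps.is_available():
        accelerator.native_amp = False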
From 20a7c8fc3bc1dc4c93a47042e2eda61e9140bf28 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 31 Mar 2024 08:27:40 -0600 Subject: [PATCH 02/11] Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 4 ---- .../pipeline_stable_diffusion_xl_img2img.py | 4 ---- .../pipeline_stable_diffusion_xl_inpaint.py | 4 ---- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 4 ---- 4 files changed, 16 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 8d1646e4d887..a5f085d13327 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1241,10 +1241,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index af9da5073e06..78f54ee87247 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1418,10 +1418,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index c9a72ccda985..57e8ef2663e8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1785,10 +1785,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." 
- ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 9aedb8667587..0f48b1084395 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -950,10 +950,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) # unscale/denormalize the latents # denormalize with the mean and std if available and not None From ad3eb800385d0247faf0c59d9aed731f780a4318 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 31 Mar 2024 13:14:43 -0600 Subject: [PATCH 03/11] add autocast fix to other training examples --- .../train_dreambooth_lora_sd15_advanced.py | 9 ++++++-- .../train_dreambooth_lora_sdxl_advanced.py | 12 +++++----- .../train_lcm_distill_lora_sd_wds.py | 16 +++++++++++--- .../train_lcm_distill_lora_sdxl.py | 8 ++++++- .../train_lcm_distill_lora_sdxl_wds.py | 21 +++++++++++++++--- .../train_lcm_distill_sd_wds.py | 22 ++++++++++++++++--- .../train_lcm_distill_sdxl_wds.py | 22 ++++++++++++++++--- examples/controlnet/train_controlnet_sdxl.py | 13 +++++------ .../dreambooth/train_dreambooth_lora_sdxl.py | 11 +++++----- .../train_instruct_pix2pix_sdxl.py | 17 +++++++------- .../train_text_to_image_lora_sdxl.py | 15 +++++-------- .../text_to_image/train_text_to_image_sdxl.py | 17 +++++++------- 12 files changed, 122 insertions(+), 61 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 5ce94680aeb2..200edff94df8 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -67,7 +67,7 @@ is_wandb_available, ) from diffusers.utils.import_utils import is_xformers_available - +from contextlib import nullcontext # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
check_min_version("0.28.0.dev0") @@ -1844,7 +1844,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - with torch.cuda.amp.autocast(): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index ff272e3b902e..d3147f1989e3 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -28,6 +28,7 @@ import warnings from pathlib import Path from typing import List, Optional +from contextlib import nullcontext import numpy as np import torch @@ -2192,13 +2193,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - inference_ctx = ( - contextlib.nullcontext() - if "playground" in args.pretrained_model_name_or_path - else torch.cuda.amp.autocast() - ) + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) - with inference_ctx: + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 46470be865cd..a650ecbab4bd 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -25,6 +25,7 @@ import shutil from pathlib import Path from typing import List, Union +from contextlib import nullcontext import accelerate import numpy as np @@ -238,6 +239,10 @@ def train_dataloader(self): def log_validation(vae, unet, args, accelerator, weight_dtype, step): logger.info("Running validation... ") + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) unet = accelerator.unwrap_model(unet) pipeline = StableDiffusionPipeline.from_pretrained( @@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda", dtype=weight_dtype): + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ).input_ids.to(accelerator.device) uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + # 16. Train! 
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + with autocast_ctx: # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): + with autocast_ctx: target_noise_pred = unet( x_prev.float(), timesteps, diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index a4052324c128..4c4191b91cb4 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -23,6 +23,7 @@ import random import shutil from pathlib import Path +from contextlib import nullcontext import accelerate import numpy as np @@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda", dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index fc4da48fbc4c..21750969142b 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -26,6 +26,7 @@ import shutil from pathlib import Path from typing import List, Union +from contextlib import nullcontext import accelerate import numpy as np @@ -256,6 +257,10 @@ def train_dataloader(self): def log_validation(vae, unet, args, accelerator, weight_dtype, step): logger.info("Running validation... ") + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) unet = accelerator.unwrap_model(unet) pipeline = StableDiffusionXLPipeline.from_pretrained( @@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step): for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda", dtype=weight_dtype): + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1353,7 +1358,12 @@ def compute_embeddings( # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: # 1. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1416,7 +1426,12 @@ def compute_embeddings( # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): - with torch.autocast("cuda", enabled=True, dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: target_noise_pred = unet( x_prev.float(), timesteps, diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 8908593b16d3..00a410806163 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -25,6 +25,7 @@ import shutil from pathlib import Path from typing import List, Union +from contextlib import nullcontext import accelerate import numpy as np @@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 9. 
Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: target_noise_pred = target_unet( x_prev.float(), timesteps, diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 74d1c007f7f3..0530bac179fc 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -26,6 +26,7 @@ import shutil from pathlib import Path from typing import List, Union +from contextlib import nullcontext import accelerate import numpy as np @@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe for _, prompt in enumerate(validation_prompts): images = [] - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: images = pipeline( prompt=prompt, num_inference_steps=4, @@ -1355,7 +1361,12 @@ def compute_embeddings( # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. with torch.no_grad(): - with torch.autocast("cuda"): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), @@ -1417,7 +1428,12 @@ def compute_embeddings( # 9. 
Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): - with torch.autocast("cuda", dtype=weight_dtype): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype) + + with autocast_ctx: target_noise_pred = target_unet( x_prev.float(), timesteps, diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index b60280523589..4e8f94a0f689 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and import argparse -import contextlib +from contextlib import nullcontext import functools import gc import logging @@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, ) image_logs = [] - inference_ctx = ( - contextlib.nullcontext() - if (is_final_validation or torch.backends.mps.is_available()) - else torch.autocast("cuda") - ) + if is_final_validation or torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") @@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, images = [] for _ in range(args.num_validation_images): - with inference_ctx: + with autocast_ctx: image = pipeline( prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator ).images[0] diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 14de66f9e8c8..615c81437cd2 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -46,6 +46,7 @@ from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig +from contextlib import nullcontext import diffusers from diffusers import ( @@ -207,14 +208,12 @@ def log_validation( generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better # way to condition it. 
Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 - enable_autocast = True if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: - enable_autocast = False + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) - with torch.autocast( - accelerator.device.type, - enabled=enable_autocast, - ): + with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 21726e8961e4..816404af4b28 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -42,6 +42,7 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig +from contextlib import nullcontext import diffusers from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel @@ -71,7 +72,7 @@ def log_validation( - pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True + pipeline, args, accelerator, generator, global_step, is_final_validation=False ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" @@ -91,7 +92,12 @@ def log_validation( else Image.open(image_url_or_path).convert("RGB") )(args.val_image_url_or_path) - with torch.autocast(accelerator.device.type, enabled=enable_autocast): + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: edited_images = [] # Run inference for val_img_idx in range(args.num_validation_images): @@ -983,11 +989,6 @@ def collate_fn(examples): if accelerator.is_main_process: accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args)) - # Some configurations require autocast to be disabled. - enable_autocast = True - if torch.backends.mps.is_available(): - enable_autocast = False - # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1200,7 +1201,6 @@ def collate_fn(examples): generator, global_step, is_final_validation=False, - enable_autocast=enable_autocast, ) if args.use_ema: @@ -1250,7 +1250,6 @@ def collate_fn(examples): generator, global_step, is_final_validation=True, - enable_autocast=enable_autocast, ) accelerator.end_training() diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 2ddd1824bd72..ed56d04efd0d 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -41,6 +41,7 @@ from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig +from contextlib import nullcontext import diffusers from diffusers import ( @@ -979,11 +980,6 @@ def collate_fn(examples): if accelerator.is_main_process: accelerator.init_trackers("text2image-fine-tune", config=vars(args)) - # Some configurations require autocast to be disabled. - enable_autocast = True - if torch.backends.mps.is_available(): - enable_autocast = False - # Train! 
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1209,11 +1205,12 @@ def compute_time_ids(original_size, crops_coords_top_left): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) - with torch.autocast( - accelerator.device.type, - enabled=enable_autocast, - ): + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 851503f07429..363fc208e21c 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -42,6 +42,7 @@ from torchvision.transforms.functional import crop from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig +from contextlib import nullcontext import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel @@ -986,10 +987,10 @@ def unwrap_model(model): model = model._orig_mod if is_compiled_module(model) else model return model - # Some configurations require autocast to be disabled. - enable_autocast = True - if torch.backends.mps.is_available(): - enable_autocast = False + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1224,10 +1225,7 @@ def compute_time_ids(original_size, crops_coords_top_left): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None pipeline_args = {"prompt": args.validation_prompt} - with torch.autocast( - accelerator.device.type, - enabled=enable_autocast, - ): + with autocast_ctx: images = [ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] for _ in range(args.num_validation_images) @@ -1282,7 +1280,8 @@ def compute_time_ids(original_size, crops_coords_top_left): if args.validation_prompt and args.num_validation_images > 0: pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - with torch.autocast(accelerator.device.type, enabled=enable_autocast): + + with autocast_ctx: images = [ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] for _ in range(args.num_validation_images) From 381f73da8537a65698b45e526b71392c3b778a67 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 1 Apr 2024 05:53:21 -0600 Subject: [PATCH 04/11] disable native_amp for dreambooth (sdxl) --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 615c81437cd2..c048a895a395 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -987,6 +987,10 @@ def main(args): kwargs_handlers=[kwargs], ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") From 59b9b41b5df2509cf468aa4aa6e60bc7ac7c3c2a Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 1 Apr 2024 06:40:46 -0600 Subject: [PATCH 05/11] disable native_amp for pix2pix (sdxl) --- examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 816404af4b28..aeaab50edeb5 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -513,6 +513,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) # Make one log on every process with the configuration for debugging. From 35cdecdda52cc66482af9222eac654bb5373f770 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 1 Apr 2024 08:43:26 -0600 Subject: [PATCH 06/11] remove tests from remaining files --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 4 ---- .../pipeline_stable_diffusion_xl_img2img.py | 4 ---- .../pipeline_stable_diffusion_xl_inpaint.py | 4 ---- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 4 ---- 4 files changed, 16 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a5f085d13327..efad9cf6cc1b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1199,10 +1199,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 78f54ee87247..9f6227bc914a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1376,10 +1376,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." 
- ) if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 57e8ef2663e8..378f53ab0844 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1726,10 +1726,6 @@ def denoising_value_valid(dnv): if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." - ) if num_channels_unet == 4: init_latents_proper = image_latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 0f48b1084395..31dc5acc8995 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -924,10 +924,6 @@ def __call__( if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - else: - raise ValueError( - "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/." 
- ) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): From 561bf2d5e4d03538b99838633b25102c7506ce51 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 1 Apr 2024 08:47:54 -0600 Subject: [PATCH 07/11] disable native_amp on huggingface accelerator for every training example that uses it --- examples/amused/train_amused.py | 3 +++ examples/controlnet/train_controlnet.py | 4 ++++ examples/controlnet/train_controlnet_sdxl.py | 4 ++++ examples/custom_diffusion/train_custom_diffusion.py | 4 ++++ examples/dreambooth/train_dreambooth.py | 4 ++++ examples/dreambooth/train_dreambooth_lora.py | 4 ++++ examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++++ .../text_to_image/train_text_to_image_decoder.py | 4 ++++ .../text_to_image/train_text_to_image_lora_decoder.py | 5 +++++ .../text_to_image/train_text_to_image_lora_prior.py | 5 +++++ .../kandinsky2_2/text_to_image/train_text_to_image_prior.py | 4 ++++ .../controlnet/train_controlnet_webdataset.py | 4 ++++ .../research_projects/diffusion_dpo/train_diffusion_dpo.py | 4 ++++ .../diffusion_dpo/train_diffusion_dpo_sdxl.py | 4 ++++ .../diffusion_orpo/train_diffusion_orpo_sdxl_lora.py | 4 ++++ .../diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py | 4 ++++ .../instructpix2pix_lora/train_instruct_pix2pix_lora.py | 4 ++++ .../intel_opts/textual_inversion/textual_inversion_bf16.py | 4 ++++ examples/research_projects/lora/train_text_to_image_lora.py | 5 +++++ .../train_multi_subject_dreambooth.py | 4 ++++ .../multi_token_textual_inversion/textual_inversion.py | 4 ++++ .../onnxruntime/text_to_image/train_text_to_image.py | 4 ++++ .../onnxruntime/textual_inversion/textual_inversion.py | 4 ++++ .../unconditional_image_generation/train_unconditional.py | 4 ++++ examples/t2i_adapter/train_t2i_adapter_sdxl.py | 4 ++++ examples/text_to_image/train_text_to_image.py | 4 ++++ examples/text_to_image/train_text_to_image_lora.py | 5 +++++ examples/text_to_image/train_text_to_image_sdxl.py | 4 ++++ examples/textual_inversion/textual_inversion.py | 4 ++++ examples/textual_inversion/textual_inversion_sdxl.py | 4 ++++ .../text_to_image/train_text_to_image_lora_prior.py | 4 ++++ .../wuerstchen/text_to_image/train_text_to_image_prior.py | 4 ++++ 32 files changed, 131 insertions(+) diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py index 33673b3f7eb7..3ec0503dfdfe 100644 --- a/examples/amused/train_amused.py +++ b/examples/amused/train_amused.py @@ -430,6 +430,9 @@ def main(args): log_with=args.report_to, project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False if accelerator.is_main_process: os.makedirs(args.output_dir, exist_ok=True) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index a56e92de2661..3daca0e3f56b 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -752,6 +752,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. 
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 4e8f94a0f689..ac3d52212be4 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -810,6 +810,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 35d3a59c7231..6858fed8b994 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -676,6 +676,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c40758eebfe9..a18c443e7d4d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -821,6 +821,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9b43b30e0fe1..0d33a0558989 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -749,6 +749,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index b40f22df59df..b2cc36171dc0 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -404,6 +404,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.report_to == "wandb": diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index a2a13398124a..78f9b7f18b87 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -458,6 +458,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index d6fce4937413..eb8ae8cca060 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -343,6 +343,11 @@ def main(): log_with=args.report_to, project_config=accelerator_project_config, ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index b19af1f3e341..e169cf92beb9 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -356,6 +356,11 @@ def main(): log_with=args.report_to, project_config=accelerator_project_config, ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index bbc0960e0f48..bd95aed2939c 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -459,6 +459,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py index 2b397d27d6a2..615eb834ac24 100644 --- a/examples/research_projects/controlnet/train_controlnet_webdataset.py +++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py @@ -916,6 +916,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py index 4bb6b894476a..3cec037e2544 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -484,6 +484,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 24d51658e3b0..0297a06f5b2c 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -526,6 +526,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py index 2fcddad8b63f..cdc096190f08 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py @@ -516,6 +516,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py index fc4e458f90ea..cd1ef265d23e 100644 --- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py +++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py @@ -623,6 +623,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py index 436706c8512d..5ce1cf2d9e3c 100644 --- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py +++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py @@ -410,6 +410,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.report_to == "wandb": diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index 3cfd72821490..ea4a0d255b68 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -378,6 +378,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index 462c3bbd44cf..cf00bf270057 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -411,6 +411,11 @@ def main(): log_with=args.report_to, project_config=accelerator_project_config, ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 53b70abd0115..0f507b26d6a8 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -698,6 +698,10 @@ def main(args): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py index 5fab1b6e9cbc..57ad77477b0d 100644 --- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py +++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py @@ -566,6 +566,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 2045ef4197c1..ee61f033d34d 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -439,6 +439,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 5d774d591d9a..e10564fa59ef 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -581,6 +581,10 @@ def main(): project_config=accelerator_project_config, ) + # Disable AMP for MPS. 
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index cba2c5117bfe..9a00f7cc4a9a 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -295,6 +295,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.logger == "tensorboard":
         if not is_tensorboard_available():
             raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index 973915c96da4..50735ef044a6 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -799,6 +799,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 4138d1b46329..c6dd7b11c38b 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -523,6 +523,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 71b99f1588c3..ab5587469470 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -408,6 +408,11 @@ def main():
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 363fc208e21c..660f3e4698d6 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -604,6 +604,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 0f4bb7604f3c..3760fb81b303 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -600,6 +600,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 460acf9f8009..c24a4c4f4855 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -605,6 +605,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index e7d5898e1118..76eaf6423960 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -460,6 +460,10 @@ def main():
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
index 7aaebed3b085..49cc5d26072d 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -458,6 +458,10 @@ def main():
         project_config=accelerator_project_config,
    )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
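
The change repeated across the scripts above is a single switch: rather than special-casing mixed precision per backend, each script turns off Accelerate's native AMP when it detects an MPS device, since PyTorch's GradScaler/autocast machinery is not supported there. A minimal sketch of the resulting setup, assuming a generic training script (the `mixed_precision` value here is illustrative):

    # Keep Accelerate's mixed-precision handling everywhere except Apple MPS,
    # where native AMP is unsupported; `native_amp` is the same attribute the
    # patch assigns on the Accelerator instance.
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="fp16")

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
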
From 55912ed0968bea29124e4fa0e28b540a6817f574 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 1 Apr 2024 08:54:42 -0600
Subject: [PATCH 08/11] convert more usages of autocast to nullcontext, make style fixes

---
 .../train_dreambooth_lora_sd15_advanced.py     |  9 +++++----
 .../train_dreambooth_lora_sdxl_advanced.py     |  9 ++++-----
 .../train_lcm_distill_lora_sd_wds.py           |  2 +-
 .../train_lcm_distill_lora_sdxl.py             |  2 +-
 .../train_lcm_distill_lora_sdxl_wds.py         |  2 +-
 .../train_lcm_distill_sd_wds.py                |  2 +-
 .../train_lcm_distill_sdxl_wds.py              |  4 ++--
 examples/controlnet/train_controlnet_sdxl.py   |  2 +-
 .../dreambooth/train_dreambooth_lora_sdxl.py   |  2 +-
 .../instruct_pix2pix/train_instruct_pix2pix.py | 10 +++++++---
 .../train_instruct_pix2pix_sdxl.py             |  6 ++----
 .../train_instruct_pix2pix_lora.py             | 10 +++++++---
 .../textual_inversion.py                       |  6 +++---
 .../textual_inversion/textual_inversion.py     |  6 +++---
 examples/text_to_image/train_text_to_image.py  |  8 +++++++-
 .../text_to_image/train_text_to_image_lora.py  | 15 +++++++++++++--
 .../train_text_to_image_lora_sdxl.py           |  2 +-
 .../text_to_image/train_text_to_image_sdxl.py  |  2 +-
 .../textual_inversion/textual_inversion.py     | 14 ++++++++++----
 .../textual_inversion_sdxl.py                  |  6 +++---
 src/diffusers/loaders/lora_conversion_utils.py | 18 +++++++++---------
 .../pipeline_stable_diffusion_depth2img.py     | 11 +++++++++--
 22 files changed, 92 insertions(+), 56 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 200edff94df8..4c6ab506fe91 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -23,6 +23,7 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
 
@@ -67,7 +68,7 @@
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
-from contextlib import nullcontext
+
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")
@@ -743,9 +744,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             .to(dtype=self.dtype)
             * std_token_embedding
         )
-        self.embeddings_settings[
-            f"original_embeddings_{idx}"
-        ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+        self.embeddings_settings[f"original_embeddings_{idx}"] = (
+            text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+        )
         self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
         inu = torch.ones((len(tokenizer),), dtype=torch.bool)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index d3147f1989e3..d6a63f91939d 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import contextlib
 import gc
 import hashlib
 import itertools
@@ -26,9 +25,9 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
-from contextlib import nullcontext
 
 import numpy as np
 import torch
@@ -777,9 +776,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             .to(dtype=self.dtype)
             * std_token_embedding
         )
-        self.embeddings_settings[
-            f"original_embeddings_{idx}"
-        ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+        self.embeddings_settings[f"original_embeddings_{idx}"] = (
+            text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+        )
         self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
         inu = torch.ones((len(tokenizer),), dtype=torch.bool)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index a650ecbab4bd..1e88cb67ee71 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -23,9 +23,9 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
-from contextlib import nullcontext
 
 import accelerate
 import numpy as np
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 4c4191b91cb4..9405c238f937 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -22,8 +22,8 @@
 import os
 import random
 import shutil
-from pathlib import Path
 from contextlib import nullcontext
+from pathlib import Path
 
 import accelerate
 import numpy as np
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index 21750969142b..08d6b23d6deb 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -24,9 +24,9 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
-from contextlib import nullcontext
 
 import accelerate
 import numpy as np
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 00a410806163..d873cb8deb58 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -23,9 +23,9 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
-from contextlib import nullcontext
 
 import accelerate
 import numpy as np
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 0530bac179fc..862777411ccc 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -24,9 +24,9 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
-from contextlib import nullcontext
 
 import accelerate
 import numpy as np
@@ -1366,7 +1366,7 @@ def compute_embeddings(
         else:
             autocast_ctx = torch.autocast(accelerator.device.type)
 
-        with autocast_ctx:
+        with autocast_ctx:
             # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
             cond_teacher_output = teacher_unet(
                 noisy_model_input.to(weight_dtype),
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index ac3d52212be4..62192521a323 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-from contextlib import nullcontext
 import functools
 import gc
 import logging
@@ -22,6 +21,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index c048a895a395..f3e347cd6ac9 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -23,6 +23,7 @@
 import random
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 
 import numpy as np
@@ -46,7 +47,6 @@
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
-from contextlib import nullcontext
 
 import diffusers
 from diffusers import (
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index b2cc36171dc0..9c7cf0847746 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -21,6 +21,7 @@
 import math
 import os
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -947,9 +948,12 @@ def collate_fn(examples):
                     # run inference
                     original_image = download_image(args.val_image_url)
                     edited_images = []
-                    with torch.autocast(
-                        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
-                    ):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         for _ in range(args.num_validation_images):
                             edited_images.append(
                                 pipeline(
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index aeaab50edeb5..1c0cdf04b2d2 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -20,6 +20,7 @@
 import os
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from urllib.parse import urlparse
 
@@ -42,7 +43,6 @@
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
-from contextlib import nullcontext
 
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
@@ -71,9 +71,7 @@
 TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
 
 
-def log_validation(
-    pipeline, args, accelerator, generator, global_step, is_final_validation=False
-):
+def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 5ce1cf2d9e3c..997d448fa281 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -21,6 +21,7 @@
 import math
 import os
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -971,9 +972,12 @@ def collate_fn(examples):
                     # run inference
                     original_image = download_image(args.val_image_url)
                     edited_images = []
-                    with torch.autocast(
-                        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
-                    ):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         for _ in range(args.num_validation_images):
                             edited_images.append(
                                 pipeline(
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 57ad77477b0d..7aad64ecb1dd 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -830,9 +830,9 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = get_mask(tokenizer, accelerator)
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
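
The textual inversion hunks above (and those that follow in the remaining scripts) only reflow one assignment, but that assignment is the heart of the technique: after every optimizer step, every embedding row except the newly added placeholder tokens is overwritten with its saved original value, so only the new tokens effectively train. A sketch of the freeze in isolation, assuming `placeholder_token_ids` and a saved copy `orig_embeds_params` as in the scripts (the helper name is hypothetical):

    # Restore all frozen embedding rows after an optimizer step; only the
    # placeholder-token rows are left free to learn.
    import torch

    def restore_frozen_embeddings(text_encoder, placeholder_token_ids, orig_embeds_params):
        n_tokens = text_encoder.get_input_embeddings().num_embeddings
        # Boolean mask: True for every token row that must stay frozen.
        index_no_updates = torch.ones((n_tokens,), dtype=torch.bool)
        index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
        with torch.no_grad():
            text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
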
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index e10564fa59ef..5f0710e85319 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -886,9 +886,9 @@ def main():
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
 
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index c6dd7b11c38b..84f4c6514cfd 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -20,6 +20,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -164,7 +165,12 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
 
     images = []
     for i in range(len(args.validation_prompts)):
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
 
         images.append(image)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index ab5587469470..7164ac909cb2 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -21,6 +21,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import datasets
@@ -883,7 +884,12 @@ def collate_fn(examples):
                     if args.seed is not None:
                         generator = generator.manual_seed(args.seed)
                     images = []
-                    with torch.cuda.amp.autocast():
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         for _ in range(args.num_validation_images):
                             images.append(
                                 pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
@@ -953,7 +959,12 @@ def collate_fn(examples):
         if args.seed is not None:
             generator = generator.manual_seed(args.seed)
         images = []
-        with torch.cuda.amp.autocast():
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             for _ in range(args.num_validation_images):
                 images.append(
                     pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
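
Every validation hunk in this patch makes the same substitution: the hard-coded `torch.autocast("cuda")` or `torch.cuda.amp.autocast()` context becomes one chosen at runtime from the accelerator's device type, with a no-op context on MPS where autocast is unavailable. Pulled out of the diff, the pattern looks as follows (assuming `accelerator`, `pipeline`, `args`, and `generator` as defined in the surrounding scripts):

    # Device-agnostic autocast for validation: a no-op on Apple MPS,
    # torch.autocast over "cuda"/"cpu"/etc. everywhere else.
    from contextlib import nullcontext

    import torch

    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type)

    with autocast_ctx:
        image = pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
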
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index ed56d04efd0d..0a6a70de2dc7 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -21,6 +21,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import datasets
@@ -41,7 +42,6 @@
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
-from contextlib import nullcontext
 
 import diffusers
 from diffusers import (
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 660f3e4698d6..a341db3aa3d9 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -23,6 +23,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -42,7 +43,6 @@
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
-from contextlib import nullcontext
 
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 3760fb81b303..3ae1e85723ee 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -20,6 +20,7 @@
 import random
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 
 import numpy as np
@@ -143,7 +144,12 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
     generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
     images = []
     for _ in range(args.num_validation_images):
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
         images.append(image)
 
@@ -904,9 +910,9 @@ def main():
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
 
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index c24a4c4f4855..cc020499be8e 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -940,9 +940,9 @@ def main():
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
 
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 11e3311a6402..e4877d495970 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -209,9 +209,9 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
 
             if is_unet_dora_lora:
                 dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
-                unet_state_dict[
-                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
-                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
+                    state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                )
 
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
             if lora_name.startswith(("lora_te_", "lora_te1_")):
@@ -249,13 +249,13 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                     "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                 )
                 if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[
-                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+                        state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    )
                 elif lora_name.startswith("lora_te2_"):
-                    te2_state_dict[
-                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+                        state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                    )
 
         # Rename the alphas so that they can be mapped appropriately.
         if lora_name_alpha in state_dict:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 72b438cd3325..1f822971568f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -548,8 +548,15 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui
             pixel_values = pixel_values.to(device=device)
             # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
             # So we use `torch.autocast` here for half precision inference.
-            context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
-            with context_manger:
+            if torch.backends.mps.is_available():
+                autocast_ctx = contextlib.nullcontext()
+                logger.warning(
+                    "The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS."
+                )
+            else:
+                autocast_ctx = torch.autocast(device.type, dtype=dtype)
+
+            with autocast_ctx:
                 depth_map = self.depth_estimator(pixel_values).predicted_depth
         else:
             depth_map = depth_map.to(device=device, dtype=dtype)
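
The depth2img change just above is the same idea applied inside a pipeline, and it also retires the old ternary (whose misspelled `context_manger` only enabled autocast when `device.type == "cuda"`, silently skipping it on every other backend). DPT-Hybrid contains batch-norm layers that misbehave under plain fp16, so the depth estimator should run under autocast wherever autocast exists; only MPS falls back to a bare call, with a warning. A condensed sketch (`device`, `dtype`, and `pixel_values` follow the pipeline code above; `depth_estimator` stands in for `self.depth_estimator`):

    import contextlib

    import torch

    if torch.backends.mps.is_available():
        # No autocast support on MPS yet: run un-wrapped and accept fp32 inputs.
        autocast_ctx = contextlib.nullcontext()
    else:
        autocast_ctx = torch.autocast(device.type, dtype=dtype)

    with autocast_ctx:
        depth_map = depth_estimator(pixel_values).predicted_depth
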
From 455bb6f5af14c7b24811b17b13f250a031b9f969 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 1 Apr 2024 09:00:21 -0600
Subject: [PATCH 09/11] make style fixes

---
 scripts/convert_svd_to_diffusers.py           | 12 ++++++------
 tests/models/autoencoders/test_models_vae.py  |  6 ++----
 tests/pipelines/amused/test_amused.py         |  3 +--
 tests/pipelines/amused/test_amused_img2img.py |  3 +--
 tests/pipelines/amused/test_amused_inpaint.py |  3 +--
 5 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py
index 3243ce294b26..e46410ccb3bd 100644
--- a/scripts/convert_svd_to_diffusers.py
+++ b/scripts/convert_svd_to_diffusers.py
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
 
         # TODO resnet time_mixer.mix_factor
         if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
-            new_checkpoint[
-                f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
-            ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+            new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+                unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+            )
 
         if len(attentions):
             paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
             )
 
         if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
-            new_checkpoint[
-                f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
-            ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+            new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+                unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+            )
 
     output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
     if ["conv.bias", "conv.weight"] in output_block_list.values():
diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
index b0c24b8d4315..0f45fa0890af 100644
--- a/tests/models/autoencoders/test_models_vae.py
+++ b/tests/models/autoencoders/test_models_vae.py
@@ -384,12 +384,10 @@ def prepare_init_args_and_inputs_for_common(self):
         return self.init_dict, self.inputs_dict()
 
     @unittest.skip
-    def test_training(self):
-        ...
+    def test_training(self): ...
 
     @unittest.skip
-    def test_ema_training(self):
-        ...
+    def test_ema_training(self): ...
 
 
 class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
index f03751e2f830..ed03fef2b0cd 100644
--- a/tests/pipelines/amused/test_amused.py
+++ b/tests/pipelines/amused/test_amused.py
@@ -125,8 +125,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
         self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
 
     @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self):
-        ...
+    def test_inference_batch_single_identical(self): ...
 
 
 @slow
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
index efbca1f437a4..794f23792911 100644
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ b/tests/pipelines/amused/test_amused_img2img.py
@@ -129,8 +129,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
         self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
 
     @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self):
-        ...
+    def test_inference_batch_single_identical(self): ...
 
 
 @slow
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
index d397f8d81297..9c8b1a68b1e1 100644
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ b/tests/pipelines/amused/test_amused_inpaint.py
@@ -133,8 +133,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
         self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
 
     @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self):
-        ...
+    def test_inference_batch_single_identical(self): ...
 
 
 @slow

From 148358111e1d62072fd5627a8f989b76165020fe Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 2 Apr 2024 18:32:18 +0530
Subject: [PATCH 10/11] style.

---
 .../train_dreambooth_lora_sd15_advanced.py     |  6 +++---
 .../train_dreambooth_lora_sdxl_advanced.py     |  6 +++---
 .../textual_inversion.py                       |  6 +++---
 .../textual_inversion/textual_inversion.py     |  6 +++---
 .../textual_inversion/textual_inversion.py     |  6 +++---
 .../textual_inversion_sdxl.py                  |  6 +++---
 scripts/convert_svd_to_diffusers.py            | 12 ++++++------
 src/diffusers/loaders/lora_conversion_utils.py | 18 +++++++++---------
 tests/models/autoencoders/test_models_vae.py   |  6 ++++--
 tests/pipelines/amused/test_amused.py          |  3 ++-
 tests/pipelines/amused/test_amused_img2img.py  |  3 ++-
 tests/pipelines/amused/test_amused_inpaint.py  |  3 ++-
 12 files changed, 43 insertions(+), 38 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 4c6ab506fe91..6cdf2e7b21ab 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -744,9 +744,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             .to(dtype=self.dtype)
             * std_token_embedding
         )
-        self.embeddings_settings[f"original_embeddings_{idx}"] = (
-            text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-        )
+        self.embeddings_settings[
+            f"original_embeddings_{idx}"
+        ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
         self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
         inu = torch.ones((len(tokenizer),), dtype=torch.bool)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index d6a63f91939d..21a84b77245a 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -776,9 +776,9 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
             .to(dtype=self.dtype)
             * std_token_embedding
         )
-        self.embeddings_settings[f"original_embeddings_{idx}"] = (
-            text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
-        )
+        self.embeddings_settings[
+            f"original_embeddings_{idx}"
+        ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
         self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
 
         inu = torch.ones((len(tokenizer),), dtype=torch.bool)
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 7aad64ecb1dd..57ad77477b0d 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -830,9 +830,9 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = get_mask(tokenizer, accelerator)
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
-                        orig_embeds_params[index_no_updates]
-                    )
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                        index_no_updates
+                    ] = orig_embeds_params[index_no_updates]
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index 5f0710e85319..e10564fa59ef 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -886,9 +886,9 @@ def main():
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
 
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
-                        orig_embeds_params[index_no_updates]
-                    )
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                        index_no_updates
+                    ] = orig_embeds_params[index_no_updates]
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 3ae1e85723ee..4922789862b5 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -910,9 +910,9 @@ def main():
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
 
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
-                        orig_embeds_params[index_no_updates]
-                    )
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+                        index_no_updates
+                    ] = orig_embeds_params[index_no_updates]
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index cc020499be8e..c24a4c4f4855 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -940,9 +940,9 @@ def main():
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
 
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
-                        orig_embeds_params[index_no_updates]
-                    )
+                    accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
+                        index_no_updates
+                    ] = orig_embeds_params[index_no_updates]
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py
index e46410ccb3bd..3243ce294b26 100644
--- a/scripts/convert_svd_to_diffusers.py
+++ b/scripts/convert_svd_to_diffusers.py
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
 
         # TODO resnet time_mixer.mix_factor
        if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
-            new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
-                unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
-            )
+            new_checkpoint[
+                f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+            ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
 
         if len(attentions):
             paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
             )
 
         if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
-            new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
-                unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
-            )
+            new_checkpoint[
+                f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
+            ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
 
     output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
     if ["conv.bias", "conv.weight"] in output_block_list.values():
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index e4877d495970..11e3311a6402 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -209,9 +209,9 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
 
             if is_unet_dora_lora:
                 dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
-                unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
-                    state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-                )
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
             if lora_name.startswith(("lora_te_", "lora_te1_")):
@@ -249,13 +249,13 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                     "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                 )
                 if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
-                        state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-                    )
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
                 elif lora_name.startswith("lora_te2_"):
-                    te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
-                        state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-                    )
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
         # Rename the alphas so that they can be mapped appropriately.
         if lora_name_alpha in state_dict:
diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
index 0f45fa0890af..b0c24b8d4315 100644
--- a/tests/models/autoencoders/test_models_vae.py
+++ b/tests/models/autoencoders/test_models_vae.py
@@ -384,10 +384,12 @@ def prepare_init_args_and_inputs_for_common(self):
         return self.init_dict, self.inputs_dict()
 
     @unittest.skip
-    def test_training(self): ...
+    def test_training(self):
+        ...
 
     @unittest.skip
-    def test_ema_training(self): ...
+    def test_ema_training(self):
+        ...
 
 
 class AutoencoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
index ed03fef2b0cd..f03751e2f830 100644
--- a/tests/pipelines/amused/test_amused.py
+++ b/tests/pipelines/amused/test_amused.py
@@ -125,7 +125,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
         self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
 
     @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self): ...
+    def test_inference_batch_single_identical(self):
+        ...
 
 
 @slow
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
index 794f23792911..efbca1f437a4 100644
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ b/tests/pipelines/amused/test_amused_img2img.py
@@ -129,7 +129,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
         self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
 
     @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self): ...
+    def test_inference_batch_single_identical(self):
+        ...
 
 
 @slow
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
index 9c8b1a68b1e1..d397f8d81297 100644
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ b/tests/pipelines/amused/test_amused_inpaint.py
@@ -133,7 +133,8 @@ def test_inference_batch_consistent(self, batch_sizes=[2]):
         self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
 
     @unittest.skip("aMUSEd does not support lists of generators")
-    def test_inference_batch_single_identical(self): ...
+    def test_inference_batch_single_identical(self):
+        ...
 
 
 @slow

From d084ad53844f562067f88c5cfd83accd0561165c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 2 Apr 2024 18:42:14 +0530
Subject: [PATCH 11/11] Empty-Commit