7529 do not disable autocast for cuda devices #7530

Merged 15 commits on Apr 2, 2024
Changes from 4 commits
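
Every file in this diff applies the same fix: drop the hard-coded `torch.cuda.amp.autocast()` / `torch.autocast("cuda")` calls, build the autocast context from `accelerator.device.type`, and fall back to `nullcontext()` on Apple MPS (and for "playground" models, which the scripts already special-case). A minimal sketch of the shared pattern, with `device_type` standing in for `accelerator.device.type` so it runs outside the training scripts:

# Minimal sketch of the pattern this PR applies across the example scripts.
# Assumption: `device_type` stands in for `accelerator.device.type`.
from contextlib import nullcontext

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

if torch.backends.mps.is_available():
    # MPS builds of PyTorch have not reliably supported autocast, so skip it.
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(device_type)

with autocast_ctx:
    pass  # validation or teacher inference would run here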
@@ -67,7 +67,7 @@
     is_wandb_available,
 )
 from diffusers.utils.import_utils import is_xformers_available
-
+from contextlib import nullcontext
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.28.0.dev0")
@@ -1844,7 +1844,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
 pipeline_args = {"prompt": args.validation_prompt}
 
-with torch.cuda.amp.autocast():
+if torch.backends.mps.is_available():
+    autocast_ctx = nullcontext()
+else:
+    autocast_ctx = torch.autocast(accelerator.device.type)
+
+with autocast_ctx:
     images = [
         pipeline(**pipeline_args, generator=generator).images[0]
         for _ in range(args.num_validation_images)
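
The MPS branch exists because `torch.autocast` has historically rejected `device_type="mps"`. A hypothetical probe, not part of this PR, for checking which device types the installed PyTorch build accepts:

# Hypothetical probe (not from the PR): report which autocast device types
# this PyTorch build accepts by entering and leaving an empty context.
import torch

for device_type in ("cuda", "cpu", "mps"):
    try:
        with torch.autocast(device_type):
            pass
        print(f"{device_type}: autocast OK")
    except RuntimeError as err:
        print(f"{device_type}: autocast rejected ({err})")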
@@ -28,6 +28,7 @@
 import warnings
 from pathlib import Path
 from typing import List, Optional
+from contextlib import nullcontext
 
 import numpy as np
 import torch
@@ -2192,13 +2193,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 # run inference
 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
 pipeline_args = {"prompt": args.validation_prompt}
-inference_ctx = (
-    contextlib.nullcontext()
-    if "playground" in args.pretrained_model_name_or_path
-    else torch.cuda.amp.autocast()
-)
+if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+    autocast_ctx = nullcontext()
+else:
+    autocast_ctx = torch.autocast(accelerator.device.type)
 
-with inference_ctx:
+with autocast_ctx:
     images = [
         pipeline(**pipeline_args, generator=generator).images[0]
         for _ in range(args.num_validation_images)
16 changes: 13 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -25,6 +25,7 @@
 import shutil
 from pathlib import Path
 from typing import List, Union
+from contextlib import nullcontext
 
 import accelerate
 import numpy as np
@@ -238,6 +239,10 @@ def train_dataloader(self):
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
 
 for _, prompt in enumerate(validation_prompts):
     images = []
-    with torch.autocast("cuda", dtype=weight_dtype):
+    with autocast_ctx:
         images = pipeline(
             prompt=prompt,
             num_inference_steps=4,
@@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
 ).input_ids.to(accelerator.device)
 uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
 
+if torch.backends.mps.is_available():
+    autocast_ctx = nullcontext()
+else:
+    autocast_ctx = torch.autocast(accelerator.device.type)
+
 # 16. Train!
 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
 # solver timestep.
 with torch.no_grad():
-    with torch.autocast("cuda"):
+    with autocast_ctx:
         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
         cond_teacher_output = teacher_unet(
             noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
 # Note that we do not use a separate target network for LCM-LoRA distillation.
 with torch.no_grad():
-    with torch.autocast("cuda", dtype=weight_dtype):
+    with autocast_ctx:
         target_noise_pred = unet(
             x_prev.float(),
             timesteps,
@@ -23,6 +23,7 @@
 import random
 import shutil
 from pathlib import Path
+from contextlib import nullcontext
 
 import accelerate
 import numpy as np
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
 
 for _, prompt in enumerate(validation_prompts):
     images = []
-    with torch.autocast("cuda", dtype=weight_dtype):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+    with autocast_ctx:
         images = pipeline(
             prompt=prompt,
             num_inference_steps=4,
@@ -26,6 +26,7 @@
 import shutil
 from pathlib import Path
 from typing import List, Union
+from contextlib import nullcontext
 
 import accelerate
 import numpy as np
@@ -256,6 +257,10 @@ def train_dataloader(self):
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
 
 for _, prompt in enumerate(validation_prompts):
     images = []
-    with torch.autocast("cuda", dtype=weight_dtype):
+    with autocast_ctx:
         images = pipeline(
             prompt=prompt,
             num_inference_steps=4,
@@ -1353,7 +1358,12 @@ def compute_embeddings(
 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
 # solver timestep.
 with torch.no_grad():
-    with torch.autocast("cuda"):
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
         cond_teacher_output = teacher_unet(
             noisy_model_input.to(weight_dtype),
@@ -1416,7 +1426,12 @@ def compute_embeddings(
 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
 # Note that we do not use a separate target network for LCM-LoRA distillation.
 with torch.no_grad():
-    with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+    with autocast_ctx:
         target_noise_pred = unet(
             x_prev.float(),
             timesteps,
22 changes: 19 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -25,6 +25,7 @@
 import shutil
 from pathlib import Path
 from typing import List, Union
+from contextlib import nullcontext
 
 import accelerate
 import numpy as np
@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
 
 for _, prompt in enumerate(validation_prompts):
     images = []
-    with torch.autocast("cuda"):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         images = pipeline(
             prompt=prompt,
             num_inference_steps=4,
@@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
 # solver timestep.
 with torch.no_grad():
-    with torch.autocast("cuda"):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
         cond_teacher_output = teacher_unet(
             noisy_model_input.to(weight_dtype),
@@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
 
 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
 with torch.no_grad():
-    with torch.autocast("cuda", dtype=weight_dtype):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+    with autocast_ctx:
         target_noise_pred = target_unet(
             x_prev.float(),
             timesteps,
22 changes: 19 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -26,6 +26,7 @@
 import shutil
 from pathlib import Path
 from typing import List, Union
+from contextlib import nullcontext
 
 import accelerate
 import numpy as np
@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
 
 for _, prompt in enumerate(validation_prompts):
     images = []
-    with torch.autocast("cuda"):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         images = pipeline(
             prompt=prompt,
             num_inference_steps=4,
@@ -1355,7 +1361,12 @@ def compute_embeddings(
 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
 # solver timestep.
 with torch.no_grad():
-    with torch.autocast("cuda"):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
         cond_teacher_output = teacher_unet(
             noisy_model_input.to(weight_dtype),
@@ -1417,7 +1428,12 @@ def compute_embeddings(
 
 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
 with torch.no_grad():
-    with torch.autocast("cuda", dtype=weight_dtype):
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+    with autocast_ctx:
         target_noise_pred = target_unet(
             x_prev.float(),
             timesteps,
13 changes: 6 additions & 7 deletions examples/controlnet/train_controlnet_sdxl.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import contextlib
+from contextlib import nullcontext
 import functools
 import gc
 import logging
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
 )
 
 image_logs = []
-inference_ctx = (
-    contextlib.nullcontext()
-    if (is_final_validation or torch.backends.mps.is_available())
-    else torch.autocast("cuda")
-)
+if is_final_validation or torch.backends.mps.is_available():
+    autocast_ctx = nullcontext()
+else:
+    autocast_ctx = torch.autocast(accelerator.device.type)
 
 for validation_prompt, validation_image in zip(validation_prompts, validation_images):
     validation_image = Image.open(validation_image).convert("RGB")
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
 images = []
 
 for _ in range(args.num_validation_images):
-    with inference_ctx:
+    with autocast_ctx:
         image = pipeline(
             prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
         ).images[0]
17 changes: 6 additions & 11 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -46,6 +46,7 @@
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
+from contextlib import nullcontext
 
 import diffusers
 from diffusers import (
@@ -207,18 +208,12 @@ def log_validation(
 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
 # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
 # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-enable_autocast = True
-if torch.backends.mps.is_available() or (
-    accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
-):
-    enable_autocast = False
-if "playground" in args.pretrained_model_name_or_path:
-    enable_autocast = False
+if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+    autocast_ctx = nullcontext()
+else:
+    autocast_ctx = torch.autocast(accelerator.device.type)
 
-with torch.autocast(
-    accelerator.device.type,
-    enabled=enable_autocast,
-):
+with autocast_ctx:
     images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
 
 for tracker in accelerator.trackers:
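
The deleted `mixed_precision` checks are what #7529 targets: with `mixed_precision="fp16"` the pipeline weights are already half precision, and running validation without autocast can then fail on CUDA with a dtype mismatch between fp32 inputs and fp16 weights. A small illustration of the failure mode and the fix; it assumes a CUDA device and is not taken from the PR:

# Illustration of why autocast matters here (assumes a CUDA device; not from
# the PR): autocast casts matmul inputs so fp32 tensors can hit fp16 weights.
import torch

linear = torch.nn.Linear(4, 4).half().cuda()  # fp16 weights, as with mixed_precision="fp16"
x = torch.randn(1, 4, device="cuda")          # fp32 input

try:
    linear(x)  # without autocast this raises a dtype-mismatch RuntimeError
except RuntimeError as err:
    print("no autocast:", err)

with torch.autocast("cuda"):
    y = linear(x)  # under autocast the matmul runs in fp16
print("with autocast:", y.dtype)  # torch.float16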