7529 do not disable autocast for cuda devices #7530

Merged
merged 15 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
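The recurring change in this PR swaps the CUDA-only `torch.cuda.amp.autocast()` for an autocast context keyed to whatever device Accelerate selected, with a no-op fallback on Apple's MPS backend, which `torch.autocast` does not support. A minimal sketch of the pattern, assuming the usual Accelerate setup these scripts use; the body of the `with` block is a placeholder:

```python
from contextlib import nullcontext

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Pick autocast for the active device type ("cuda", "cpu", ...) instead of
# hard-coding CUDA; on MPS, where torch.autocast is unsupported, fall back
# to a do-nothing context manager.
if torch.backends.mps.is_available():
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
    pass  # e.g. images = pipeline(**pipeline_args).images
```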
@@ -23,6 +23,7 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional

@@ -1844,7 +1845,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     pipeline_args = {"prompt": args.validation_prompt}

-    with torch.cuda.amp.autocast():
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
         images = [
             pipeline(**pipeline_args, generator=generator).images[0]
             for _ in range(args.num_validation_images)
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and

 import argparse
-import contextlib
 import gc
 import hashlib
 import itertools
@@ -26,6 +25,7 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional

@@ -2192,13 +2192,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
     pipeline_args = {"prompt": args.validation_prompt}
-    inference_ctx = (
-        contextlib.nullcontext()
-        if "playground" in args.pretrained_model_name_or_path
-        else torch.cuda.amp.autocast()
-    )
+    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)

-    with inference_ctx:
+    with autocast_ctx:
         images = [
             pipeline(**pipeline_args, generator=generator).images[0]
             for _ in range(args.num_validation_images)
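This script additionally keeps its existing exemption: checkpoints whose path contains "playground" skip autocast entirely, and the PR folds that into the same guard as the MPS check. A sketch of the combined check as a standalone helper; `make_autocast_ctx` is an illustrative name, not something the diff adds:

```python
from contextlib import nullcontext

import torch


def make_autocast_ctx(device_type: str, model_name_or_path: str):
    # No-op on MPS (torch.autocast does not support it) and for "playground"
    # checkpoints, which this script already exempted before this change.
    if torch.backends.mps.is_available() or "playground" in model_name_or_path:
        return nullcontext()
    return torch.autocast(device_type)


# Roughly equivalent to the inline guard in the diff above:
# autocast_ctx = make_autocast_ctx(accelerator.device.type, args.pretrained_model_name_or_path)
```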
examples/amused/train_amused.py — 3 additions, 0 deletions

@@ -430,6 +430,9 @@ def main(args):
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False

     if accelerator.is_main_process:
         os.makedirs(args.output_dir, exist_ok=True)
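Beyond the context managers, several scripts also disable Accelerate's native AMP when the MPS backend is active, since native AMP would otherwise wrap the training steps in an autocast the backend cannot run. A short sketch of where the line lands; the `mixed_precision` argument is illustrative:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")

# Disable AMP for MPS: with native_amp left on, Accelerate would autocast
# each training step, which torch does not support on Apple Silicon.
if torch.backends.mps.is_available():
    accelerator.native_amp = False
```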
examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py — 13 additions, 3 deletions

@@ -23,6 +23,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union

@@ -238,6 +239,10 @@ def train_dataloader(self):

 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):

     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     ).input_ids.to(accelerator.device)
     uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]

+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
     # 16. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

@@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
             # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
             # solver timestep.
             with torch.no_grad():
-                with torch.autocast("cuda"):
+                with autocast_ctx:
                     # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                     cond_teacher_output = teacher_unet(
                         noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
             # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
             # Note that we do not use a separate target network for LCM-LoRA distillation.
             with torch.no_grad():
-                with torch.autocast("cuda", dtype=weight_dtype):
+                with autocast_ctx:
                     target_noise_pred = unet(
                         x_prev.float(),
                         timesteps,
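A detail worth noting in this script: the training-loop context is now built once, before step 16, and entered twice per iteration (teacher prediction and target prediction). That works because both `nullcontext()` and `torch.autocast(...)` objects are reusable context managers. A sketch under that assumption, with placeholder model calls:

```python
from contextlib import nullcontext

import torch

# weight_dtype stands in for what the script derives from --mixed_precision;
# bfloat16 is used on CPU since CPU autocast historically supports bfloat16.
weight_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
device_type = "cuda" if torch.cuda.is_available() else "cpu"

if torch.backends.mps.is_available():
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(device_type, dtype=weight_dtype)

# The same context object can be entered sequentially any number of times.
with autocast_ctx:
    pass  # e.g. cond_teacher_output = teacher_unet(...)
with autocast_ctx:
    pass  # e.g. target_noise_pred = unet(...)
```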
@@ -22,6 +22,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path

 import accelerate
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin

     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -24,6 +24,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union

@@ -256,6 +257,10 @@ def train_dataloader(self):

 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):

     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1353,7 +1358,12 @@ def compute_embeddings(
             # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
             # solver timestep.
             with torch.no_grad():
-                with torch.autocast("cuda"):
+                if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type)
+
+                with autocast_ctx:
                     # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                     cond_teacher_output = teacher_unet(
                         noisy_model_input.to(weight_dtype),
@@ -1416,7 +1426,12 @@ def compute_embeddings(
             # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
             # Note that we do not use a separate target network for LCM-LoRA distillation.
             with torch.no_grad():
-                with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                with autocast_ctx:
                     target_noise_pred = unet(
                         x_prev.float(),
                         timesteps,
examples/consistency_distillation/train_lcm_distill_sd_wds.py — 19 additions, 3 deletions

@@ -23,6 +23,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union

@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe

     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
             # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
             # solver timestep.
             with torch.no_grad():
-                with torch.autocast("cuda"):
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type)
+
+                with autocast_ctx:
                     # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                     cond_teacher_output = teacher_unet(
                         noisy_model_input.to(weight_dtype),
@@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

             # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
             with torch.no_grad():
-                with torch.autocast("cuda", dtype=weight_dtype):
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                with autocast_ctx:
                     target_noise_pred = target_unet(
                         x_prev.float(),
                         timesteps,
examples/consistency_distillation/train_lcm_distill_sdxl_wds.py — 19 additions, 3 deletions

@@ -24,6 +24,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union

@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe

     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1355,7 +1361,12 @@ def compute_embeddings(
             # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
             # solver timestep.
             with torch.no_grad():
-                with torch.autocast("cuda"):
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type)
+
+                with autocast_ctx:
                     # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                     cond_teacher_output = teacher_unet(
                         noisy_model_input.to(weight_dtype),
@@ -1417,7 +1428,12 @@ def compute_embeddings(

             # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
             with torch.no_grad():
-                with torch.autocast("cuda", dtype=weight_dtype):
+                if torch.backends.mps.is_available():
+                    autocast_ctx = nullcontext()
+                else:
+                    autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                with autocast_ctx:
                     target_noise_pred = target_unet(
                         x_prev.float(),
                         timesteps,
examples/controlnet/train_controlnet.py — 4 additions, 0 deletions

@@ -752,6 +752,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

Review comments on the "# Disable AMP for MPS." line:

Member: @bghira possible to add a more descriptive comment here?

@sayakpaul (Member, Apr 2, 2024): Also, I see this is not added to some scripts such as the advanced diffusion or consistency distillation scripts.
examples/controlnet/train_controlnet_sdxl.py — 10 additions, 7 deletions

@@ -14,14 +14,14 @@
 # See the License for the specific language governing permissions and

 import argparse
-import contextlib
 import functools
 import gc
 import logging
 import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path

 import accelerate
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
     )

     image_logs = []
-    inference_ctx = (
-        contextlib.nullcontext()
-        if (is_final_validation or torch.backends.mps.is_available())
-        else torch.autocast("cuda")
-    )
+    if is_final_validation or torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)

     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
         images = []

         for _ in range(args.num_validation_images):
-            with inference_ctx:
+            with autocast_ctx:
                 image = pipeline(
                     prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
@@ -811,6 +810,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
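The SDXL ControlNet script keeps its extra condition: the final validation pass also runs without autocast (presumably because the pipeline is reloaded in the target dtype at that point; the diff itself does not state the reason), and the MPS check joins the same guard. Sketched as a hypothetical helper; `validation_autocast_ctx` is an illustrative name:

```python
from contextlib import nullcontext

import torch


def validation_autocast_ctx(device_type: str, is_final_validation: bool):
    # Final validation skips autocast, matching this script's pre-existing
    # behavior; MPS skips it because torch.autocast has no MPS support.
    if is_final_validation or torch.backends.mps.is_available():
        return nullcontext()
    return torch.autocast(device_type)
```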
examples/custom_diffusion/train_custom_diffusion.py — 4 additions, 0 deletions

@@ -676,6 +676,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
examples/dreambooth/train_dreambooth.py — 4 additions, 0 deletions

@@ -821,6 +821,10 @@ def main(args):
         project_config=accelerator_project_config,
     )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
examples/dreambooth/train_dreambooth_lora.py — 4 additions, 0 deletions

@@ -749,6 +749,10 @@ def main(args):
         project_config=accelerator_project_config,
    )

+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")