
Commit addac81

7529 do not disable autocast for cuda devices (huggingface#7530)
* 7529 do not disable autocast for cuda devices

* Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue

* add autocast fix to other training examples

* disable native_amp for dreambooth (sdxl)

* disable native_amp for pix2pix (sdxl)

* remove tests from remaining files

* disable native_amp on huggingface accelerator for every training example that uses it

* convert more usages of autocast to nullcontext, make style fixes

* make style fixes

* style.

* Empty-Commit

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
3 people authored and noskill committed Apr 5, 2024
1 parent fb86051 commit addac81
Showing 47 changed files with 312 additions and 118 deletions.
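Every change below follows one pattern: rather than hard-coding `torch.cuda.amp.autocast()`, each script builds its autocast context at runtime from the device that Accelerate reports, falling back to a no-op context on Apple's MPS backend, which these scripts treat as unsupported for autocast. A minimal sketch of that pattern, assuming an `Accelerator` from huggingface/accelerate (the helper name `make_autocast_ctx` is illustrative, not part of the commit):

from contextlib import nullcontext

import torch
from accelerate import Accelerator


def make_autocast_ctx(accelerator, dtype=None):
    # MPS is excluded: validation runs under a no-op context there.
    if torch.backends.mps.is_available():
        return nullcontext()
    # Otherwise autocast on whatever device Accelerate selected (cuda, cpu, ...),
    # optionally pinning the dtype to the training weight dtype.
    return torch.autocast(accelerator.device.type, dtype=dtype)


accelerator = Accelerator()
with make_autocast_ctx(accelerator):
    pass  # validation inference would run here

Scripts that rely on Accelerate's built-in mixed precision instead of an explicit autocast block get the complementary fix, `accelerator.native_amp = False` on MPS, so Accelerate never opens an autocast context of its own.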
@@ -23,6 +23,7 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
 
@@ -1844,7 +1845,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
         pipeline_args = {"prompt": args.validation_prompt}
 
-        with torch.cuda.amp.autocast():
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             images = [
                 pipeline(**pipeline_args, generator=generator).images[0]
                 for _ in range(args.num_validation_images)
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import contextlib
 import gc
 import hashlib
 import itertools
@@ -26,6 +25,7 @@
 import re
 import shutil
 import warnings
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Optional
 
@@ -2192,13 +2192,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         # run inference
         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
         pipeline_args = {"prompt": args.validation_prompt}
-        inference_ctx = (
-            contextlib.nullcontext()
-            if "playground" in args.pretrained_model_name_or_path
-            else torch.cuda.amp.autocast()
-        )
+        if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
 
-        with inference_ctx:
+        with autocast_ctx:
             images = [
                 pipeline(**pipeline_args, generator=generator).images[0]
                 for _ in range(args.num_validation_images)
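Note that this script keeps its pre-existing Playground exclusion: the old `inference_ctx` already skipped autocast when the model name contained "playground", and the rewrite simply adds MPS as a second reason to fall back to `nullcontext()`.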
3 changes: 3 additions & 0 deletions examples/amused/train_amused.py
@@ -430,6 +430,9 @@ def main(args):
         log_with=args.report_to,
         project_config=accelerator_project_config,
     )
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
 
     if accelerator.is_main_process:
         os.makedirs(args.output_dir, exist_ok=True)
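This two-line guard is the entire fix for scripts that use Accelerate's built-in mixed precision rather than an explicit autocast block: with `native_amp` forced off, Accelerate skips its own autocast wrapper on MPS, where it would otherwise fail. The same snippet recurs in the controlnet, custom_diffusion, and dreambooth diffs below.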
16 changes: 13 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -23,6 +23,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -238,6 +239,10 @@ def train_dataloader(self):
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
     ).input_ids.to(accelerator.device)
     uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
 
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
     # 16. Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 # Note that we do not use a separate target network for LCM-LoRA distillation.
                 with torch.no_grad():
-                    with torch.autocast("cuda", dtype=weight_dtype):
+                    with autocast_ctx:
                         target_noise_pred = unet(
                             x_prev.float(),
                             timesteps,
@@ -22,6 +22,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -24,6 +24,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -256,6 +257,10 @@ def train_dataloader(self):
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step):
     logger.info("Running validation... ")
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
 
     unet = accelerator.unwrap_model(unet)
     pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda", dtype=weight_dtype):
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1353,7 +1358,12 @@ def compute_embeddings(
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1416,7 +1426,12 @@
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 # Note that we do not use a separate target network for LCM-LoRA distillation.
                 with torch.no_grad():
-                    with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                    with autocast_ctx:
                         target_noise_pred = unet(
                             x_prev.float(),
                             timesteps,
22 changes: 19 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -23,6 +23,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
 
     for _, prompt in enumerate(validation_prompts):
         images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
 
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 with torch.no_grad():
-                    with torch.autocast("cuda", dtype=weight_dtype):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                    with autocast_ctx:
                         target_noise_pred = target_unet(
                             x_prev.float(),
                             timesteps,
22 changes: 19 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -24,6 +24,7 @@
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 from typing import List, Union
 
@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
 
     for _, prompt in enumerate(validation_prompts):
        images = []
-        with torch.autocast("cuda"):
+        if torch.backends.mps.is_available():
+            autocast_ctx = nullcontext()
+        else:
+            autocast_ctx = torch.autocast(accelerator.device.type)
+
+        with autocast_ctx:
             images = pipeline(
                 prompt=prompt,
                 num_inference_steps=4,
@@ -1355,7 +1361,12 @@ def compute_embeddings(
                 # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
                 # solver timestep.
                 with torch.no_grad():
-                    with torch.autocast("cuda"):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type)
+
+                    with autocast_ctx:
                         # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
                         cond_teacher_output = teacher_unet(
                             noisy_model_input.to(weight_dtype),
@@ -1417,7 +1428,12 @@
 
                 # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
                 with torch.no_grad():
-                    with torch.autocast("cuda", dtype=weight_dtype):
+                    if torch.backends.mps.is_available():
+                        autocast_ctx = nullcontext()
+                    else:
+                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+                    with autocast_ctx:
                         target_noise_pred = target_unet(
                             x_prev.float(),
                             timesteps,
4 changes: 4 additions & 0 deletions examples/controlnet/train_controlnet.py
@@ -752,6 +752,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
17 changes: 10 additions & 7 deletions examples/controlnet/train_controlnet_sdxl.py
@@ -14,14 +14,14 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import contextlib
 import functools
 import gc
 import logging
 import math
 import os
 import random
 import shutil
+from contextlib import nullcontext
 from pathlib import Path
 
 import accelerate
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
     )
 
     image_logs = []
-    inference_ctx = (
-        contextlib.nullcontext()
-        if (is_final_validation or torch.backends.mps.is_available())
-        else torch.autocast("cuda")
-    )
+    if is_final_validation or torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
 
     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
         images = []
 
         for _ in range(args.num_validation_images):
-            with inference_ctx:
+            with autocast_ctx:
                 image = pipeline(
                     prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
@@ -811,6 +810,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
4 changes: 4 additions & 0 deletions examples/custom_diffusion/train_custom_diffusion.py
@@ -676,6 +676,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
4 changes: 4 additions & 0 deletions examples/dreambooth/train_dreambooth.py
@@ -821,6 +821,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
4 changes: 4 additions & 0 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -749,6 +749,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Disable AMP for MPS.
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
[Diff truncated: the remaining changed files are not shown.]
