apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) #7447

Merged (6 commits) on Mar 28, 2024
Changes from 2 commits
40 changes: 30 additions & 10 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -205,11 +205,20 @@ def log_validation(
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-    inference_ctx = (
-        contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
-    )
-
-    with inference_ctx:
+    enable_autocast = True
+    if (
+        not torch.backends.mps.is_available()
+        or (accelerator.mixed_precision == "fp16"
+        or accelerator.mixed_precision == "bf16")
+    ):
+        enable_autocast = False
+    if "playground" in args.pretrained_model_name_or_path:
+        enable_autocast = False
+
+    with torch.autocast(
+        str(accelerator.device).replace(":0", ""),
+        enabled=enable_autocast,
Member:

Just curious: why do we disable autocast in this script, but use contextlib.nullcontext in the others?

Member:

This is on me.

Contributor Author:

if you want to open this can of worms further, we also need to use

            if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

for deepspeed support.

it's starting to feel like we need an autocast manager wrapper in train utils
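
For illustration, a minimal sketch of what such a wrapper could look like if it lived in a shared training-utilities module. `validation_context` and its arguments are hypothetical names (not part of this PR or of diffusers), and the DeepSpeed branch is a simplification of the ZeRO-3 check quoted above:

    import contextlib

    import torch


    def validation_context(accelerator, args, gather_param=None):
        # Hypothetical sketch: when DeepSpeed ZeRO-3 shards parameters, validation
        # needs them gathered before the pipeline can run.
        if gather_param is not None and accelerator.state.deepspeed_plugin is not None:
            import deepspeed

            return deepspeed.zero.GatheredParameters(gather_param, modifier_rank=None)
        # Playground checkpoints and the MPS backend skip autocast in this script.
        if "playground" in args.pretrained_model_name_or_path or torch.backends.mps.is_available():
            return contextlib.nullcontext()
        # Everything else (CUDA, CPU) gets a plain autocast context.
        return torch.autocast(accelerator.device.type)

Call sites would then read `with validation_context(accelerator, args):` instead of picking a context inline.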

+    ):
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    for tracker in accelerator.trackers:
@@ -227,7 +236,8 @@ def log_validation(
            )

    del pipeline
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()

    return images
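
The same guard around `torch.cuda.empty_cache()` recurs several times in this diff. Purely as a sketch (not part of this PR), a small helper could centralize the device check and also release MPS memory, assuming a PyTorch build that ships `torch.mps.empty_cache()`:

    import gc

    import torch


    def free_memory(device_type: str) -> None:
        # Hypothetical helper: drop dangling Python references, then ask the
        # backend actually in use to release its cached allocations.
        gc.collect()
        if device_type == "cuda":
            torch.cuda.empty_cache()
        elif device_type == "mps":
            torch.mps.empty_cache()

It could be called as `free_memory(accelerator.device.type)` after `del pipeline`.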

@@ -959,6 +969,10 @@ def main(args):
    if args.do_edm_style_training and args.snr_gamma is not None:
        raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")

+    if torch.backends.mps.is_available():
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        args.mixed_precision = "fp16"
+
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -1001,7 +1015,8 @@
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            has_supported_fp16_accelerator = accelerator.device.type == "cuda" or torch.backends.mps.is_available()
+            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
            if args.prior_generation_precision == "fp32":
                torch_dtype = torch.float32
            elif args.prior_generation_precision == "fp16":
@@ -1036,7 +1051,7 @@
                image.save(image_filename)

        del pipeline
-        if torch.cuda.is_available():
+        if accelerator.device.type == "cuda":
            torch.cuda.empty_cache()

    # Handle the repository creation
@@ -1126,6 +1141,10 @@
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

+    if torch.backends.mps.is_available():
+        # due to pytorch#99272, MPS does not yet support bfloat16.
+        weight_dtype = torch.float16
+
    # Move unet, vae and text_encoder to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
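
The fp16 fallback for MPS now appears twice (once for `args.mixed_precision`, once for `weight_dtype`). As a hedged sketch only, the mapping could be factored into a single helper that mirrors the logic of this commit; `resolve_weight_dtype` is a hypothetical name:

    import torch


    def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
        # Hypothetical helper mirroring the checks above: MPS always falls back
        # to float16 because bfloat16 is not supported there (pytorch#99272).
        if torch.backends.mps.is_available():
            return torch.float16
        if mixed_precision == "fp16":
            return torch.float16
        if mixed_precision == "bf16":
            return torch.bfloat16
        return torch.float32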

@@ -1270,7 +1289,7 @@ def load_model_hook(models, input_dir):

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and accelerator.device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
@@ -1447,7 +1466,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
        del tokenizers, text_encoders
        gc.collect()
-        torch.cuda.empty_cache()
+        if accelerator.device.type == "cuda":
+            torch.cuda.empty_cache()

    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
    # pack the statically computed variables appropriately here. This is so that we don't