Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apple mps: training support for SDXL (ControlNet, LoRA, Dreambooth, T2I) #7447

Merged
merged 6 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion examples/controlnet/train_controlnet_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
)

image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
inference_ctx = (
contextlib.nullcontext()
if (is_final_validation or torch.backends.mps.is_available())
else torch.autocast("cuda")
)

for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
Expand Down Expand Up @@ -792,6 +796,12 @@ def main(args):

logging_dir = Path(args.output_dir, args.logging_dir)

if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

accelerator = Accelerator(
Expand Down
39 changes: 30 additions & 9 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and

import argparse
import contextlib
import gc
import itertools
import json
Expand Down Expand Up @@ -208,11 +207,18 @@ def log_validation(
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
inference_ctx = (
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
)
enable_autocast = True
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False
if "playground" in args.pretrained_model_name_or_path:
enable_autocast = False

with inference_ctx:
with torch.autocast(
accelerator.device.type,
enabled=enable_autocast,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: why do we disable autocast via `torch.autocast(..., enabled=False)` in this script, but use `contextlib.nullcontext()` in the others?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is on me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to open this can of worms further, we would also need to use

            if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

for DeepSpeed support.

It's starting to feel like we need a shared autocast context-manager wrapper in the training utils.

):
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
Expand All @@ -230,7 +236,8 @@ def log_validation(
)

del pipeline
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

return images

Expand Down Expand Up @@ -967,6 +974,12 @@ def main(args):
if args.do_edm_style_training and args.snr_gamma is not None:
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")

if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

logging_dir = Path(args.output_dir, args.logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
Expand Down Expand Up @@ -1009,7 +1022,8 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))

if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
Expand Down Expand Up @@ -1134,6 +1148,12 @@ def main(args):
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)

Expand Down Expand Up @@ -1278,7 +1298,7 @@ def load_model_hook(models, input_dir):

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:
Expand Down Expand Up @@ -1455,7 +1475,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down
25 changes: 18 additions & 7 deletions examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,7 @@


def log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=False,
pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
Expand All @@ -96,7 +91,7 @@ def log_validation(
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)

with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
with torch.autocast(accelerator.device.type, enabled=enable_autocast):
edited_images = []
# Run inference
for val_img_idx in range(args.num_validation_images):
Expand Down Expand Up @@ -497,6 +492,13 @@ def main():
),
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)

if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
Expand Down Expand Up @@ -981,6 +983,13 @@ def collate_fn(examples):
if accelerator.is_main_process:
accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))

# Some configurations require autocast to be disabled.
enable_autocast = True
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand Down Expand Up @@ -1193,6 +1202,7 @@ def collate_fn(examples):
generator,
global_step,
is_final_validation=False,
enable_autocast=enable_autocast,
)

if args.use_ema:
Expand Down Expand Up @@ -1242,6 +1252,7 @@ def collate_fn(examples):
generator,
global_step,
is_final_validation=True,
enable_autocast=enable_autocast,
)

accelerator.end_training()
Expand Down
18 changes: 17 additions & 1 deletion examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,12 @@ def main(args):

logging_dir = Path(args.output_dir, args.logging_dir)

if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
Expand Down Expand Up @@ -973,6 +979,13 @@ def collate_fn(examples):
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))

# Some configurations require autocast to be disabled.
enable_autocast = True
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand Down Expand Up @@ -1199,7 +1212,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}

with torch.cuda.amp.autocast():
with torch.autocast(
accelerator.device.type,
enabled=enable_autocast,
):
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
Expand Down
20 changes: 18 additions & 2 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,12 @@ def main(args):

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
Expand Down Expand Up @@ -980,6 +986,13 @@ def unwrap_model(model):
model = model._orig_mod if is_compiled_module(model) else model
return model

# Some configurations require autocast to be disabled.
enable_autocast = True
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand Down Expand Up @@ -1213,7 +1226,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}

with torch.cuda.amp.autocast():
with torch.autocast(
accelerator.device.type,
enabled=enable_autocast,
):
images = [
pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
for _ in range(args.num_validation_images)
Expand Down Expand Up @@ -1268,7 +1284,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
with torch.cuda.amp.autocast():
with torch.autocast(accelerator.device.type, enabled=enable_autocast):
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
Expand Down
Loading