Skip to content

Commit

Permalink
sdxl: support training lora, dreambooth, t2i, pix2pix, and controlnet…
Browse files Browse the repository at this point in the history
… on apple mps
  • Loading branch information
bghira committed Mar 25, 2024
1 parent b242f16 commit a1a777f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 17 deletions.
10 changes: 9 additions & 1 deletion examples/controlnet/train_controlnet_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
)

image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
inference_ctx = (
contextlib.nullcontext()
if (is_final_validation or torch.backends.mps.is_available())
else torch.autocast("cuda")
)

for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
Expand Down Expand Up @@ -792,6 +796,10 @@ def main(args):

logging_dir = Path(args.output_dir, args.logging_dir)

if torch.backends.mps.is_available():
# Due to pytorch#99272, MPS does not yet support bfloat16, so mixed precision is forced to fp16 whenever MPS is available.
args.mixed_precision = "fp16"

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

accelerator = Accelerator(
Expand Down
9 changes: 3 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and

import argparse
import contextlib
import gc
import itertools
import json
Expand Down Expand Up @@ -206,11 +205,9 @@ def log_validation(
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
enable_autocast = True
if (
not torch.backends.mps.is_available()
or (accelerator.mixed_precision == "fp16"
or accelerator.mixed_precision == "bf16")
):
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False
if "playground" in args.pretrained_model_name_or_path:
enable_autocast = False
Expand Down
23 changes: 16 additions & 7 deletions examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,7 @@


def log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=False,
pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
Expand All @@ -96,7 +91,7 @@ def log_validation(
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)

with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=enable_autocast):
edited_images = []
# Run inference
for val_img_idx in range(args.num_validation_images):
Expand Down Expand Up @@ -497,6 +492,11 @@ def main():
),
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)

if torch.backends.mps.is_available():
# Due to pytorch#99272, MPS does not yet support bfloat16, so mixed precision is forced to fp16 whenever MPS is available.
args.mixed_precision = "fp16"

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
Expand Down Expand Up @@ -981,6 +981,13 @@ def collate_fn(examples):
if accelerator.is_main_process:
accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))

# Autocast is disabled when MPS is available, or when the accelerator already runs in fp16/bf16 mixed precision.
enable_autocast = True
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand Down Expand Up @@ -1193,6 +1200,7 @@ def collate_fn(examples):
generator,
global_step,
is_final_validation=False,
enable_autocast=enable_autocast,
)

if args.use_ema:
Expand Down Expand Up @@ -1242,6 +1250,7 @@ def collate_fn(examples):
generator,
global_step,
is_final_validation=True,
enable_autocast=enable_autocast,
)

accelerator.end_training()
Expand Down
16 changes: 15 additions & 1 deletion examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ def main(args):

logging_dir = Path(args.output_dir, args.logging_dir)

if torch.backends.mps.is_available():
# Due to pytorch#99272, MPS does not yet support bfloat16, so mixed precision is forced to fp16 whenever MPS is available.
args.mixed_precision = "fp16"

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
Expand Down Expand Up @@ -973,6 +977,13 @@ def collate_fn(examples):
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))

# Autocast is disabled when MPS is available, or when the accelerator already runs in fp16/bf16 mixed precision.
enable_autocast = True
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand Down Expand Up @@ -1199,7 +1210,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}

with torch.cuda.amp.autocast():
with torch.autocast(
str(accelerator.device).replace(":0", ""),
enabled=enable_autocast,
):
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
Expand Down
18 changes: 16 additions & 2 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,10 @@ def main(args):

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

if torch.backends.mps.is_available():
# Due to pytorch#99272, MPS does not yet support bfloat16, so mixed precision is forced to fp16 whenever MPS is available.
args.mixed_precision = "fp16"

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
Expand Down Expand Up @@ -980,6 +984,13 @@ def unwrap_model(model):
model = model._orig_mod if is_compiled_module(model) else model
return model

# Autocast is disabled when MPS is available, or when the accelerator already runs in fp16/bf16 mixed precision.
enable_autocast = True
if torch.backends.mps.is_available() or (
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
):
enable_autocast = False

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand Down Expand Up @@ -1213,7 +1224,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}

with torch.cuda.amp.autocast():
with torch.autocast(
str(accelerator.device).replace(":0", ""),
enabled=enable_autocast,
):
images = [
pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
for _ in range(args.num_validation_images)
Expand Down Expand Up @@ -1268,7 +1282,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
with torch.cuda.amp.autocast():
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=enable_autocast):
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
Expand Down

0 comments on commit a1a777f

Please sign in to comment.