Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Examples] Multiple enhancements to the ControlNet training scripts #7096

Merged
merged 10 commits into from
Feb 27, 2024
12 changes: 11 additions & 1 deletion examples/controlnet/README_sdxl.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,14 @@ image.save("./output.png")

### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved

If you're using this VAE during training, you need to ensure you're using it during inference too. You do so by:

```diff
+ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
+ vae=vae,
)
39 changes: 34 additions & 5 deletions examples/controlnet/train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and

import argparse
import contextlib
import gc
import logging
import math
import os
Expand Down Expand Up @@ -74,10 +76,15 @@ def image_grid(imgs, rows, cols):
return grid


def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
def log_validation(
vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
):
logger.info("Running validation... ")

controlnet = accelerator.unwrap_model(controlnet)
if not is_final_validation:
controlnet = accelerator.unwrap_model(controlnet)
else:
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)

pipeline = StableDiffusionControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
Expand Down Expand Up @@ -118,14 +125,15 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
)

image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")

for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")

images = []

for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
with inference_ctx:
image = pipeline(
validation_prompt, validation_image, num_inference_steps=20, generator=generator
).images[0]
Expand All @@ -136,6 +144,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
)

tracker_key = "test" if is_final_validation else "validation"
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for log in image_logs:
Expand Down Expand Up @@ -167,10 +176,14 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)

tracker.log({"validation": formatted_images})
tracker.log({tracker_key: formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")

del pipeline
gc.collect()
torch.cuda.empty_cache()

return image_logs


Expand All @@ -197,7 +210,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n"
img_str = "You can find some example images below.\n\n"
for i, log in enumerate(image_logs):
images = log["images"]
validation_prompt = log["validation_prompt"]
Expand Down Expand Up @@ -1131,6 +1144,22 @@ def load_model_hook(models, input_dir):
controlnet = unwrap_model(controlnet)
controlnet.save_pretrained(args.output_dir)

# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
image_logs = log_validation(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=None,
args=args,
accelerator=accelerator,
weight_dtype=weight_dtype,
step=global_step,
is_final_validation=True,
)

if args.push_to_hub:
save_model_card(
repo_id,
Expand Down
72 changes: 57 additions & 15 deletions examples/controlnet/train_controlnet_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and

import argparse
import contextlib
import functools
import gc
import logging
Expand Down Expand Up @@ -65,20 +66,38 @@
logger = get_logger(__name__)


def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
logger.info("Running validation... ")

controlnet = accelerator.unwrap_model(controlnet)
if not is_final_validation:
controlnet = accelerator.unwrap_model(controlnet)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=unet,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
else:
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
if args.pretrained_vae_model_name_or_path is not None:
vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
else:
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
)

pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)

pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=unet,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
Expand Down Expand Up @@ -106,6 +125,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
)

image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")

for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
Expand All @@ -114,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
images = []

for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
with inference_ctx:
image = pipeline(
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
).images[0]
Expand All @@ -124,6 +144,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
)

tracker_key = "test" if is_final_validation else "validation"
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for log in image_logs:
Expand Down Expand Up @@ -155,7 +176,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)

tracker.log({"validation": formatted_images})
tracker.log({tracker_key: formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")

Expand Down Expand Up @@ -189,7 +210,7 @@ def import_model_class_from_model_name_or_path(
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n"
img_str = "You can find some example images below.\n\n"
for i, log in enumerate(image_logs):
images = log["images"]
validation_prompt = log["validation_prompt"]
Expand Down Expand Up @@ -1228,7 +1249,13 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer

if args.validation_prompt is not None and global_step % args.validation_steps == 0:
image_logs = log_validation(
vae, unet, controlnet, args, accelerator, weight_dtype, global_step
vae=vae,
unet=unet,
controlnet=controlnet,
args=args,
accelerator=accelerator,
weight_dtype=weight_dtype,
step=global_step,
)

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
Expand All @@ -1244,6 +1271,21 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
controlnet = unwrap_model(controlnet)
controlnet.save_pretrained(args.output_dir)

# Run a final round of validation.
    # Pass `vae`, `unet`, and `controlnet` as None so that `log_validation` reloads them itself:
    # the ControlNet from `args.output_dir`, and the VAE/UNet from the pretrained model paths.
image_logs = None
if args.validation_prompt is not None:
image_logs = log_validation(
vae=None,
unet=None,
controlnet=None,
args=args,
accelerator=accelerator,
weight_dtype=weight_dtype,
step=global_step,
is_final_validation=True,
)

if args.push_to_hub:
save_model_card(
repo_id,
Expand Down
Loading