Skip to content

Commit

Permalink
SDXL Turbo support and example launch (#6473)
Browse files Browse the repository at this point in the history
* support and example launch for sdxl turbo

* White space fixes

* Trailing whitespace character

* ruff format

* fix guidance_scale and steps for turbo mode

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Radames Ajna <radamajna@gmail.com>
  • Loading branch information
3 people authored Mar 6, 2024
1 parent 687bc27 commit eb942b8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 2 deletions.
28 changes: 28 additions & 0 deletions examples/research_projects/diffusion_dpo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ accelerate launch train_diffusion_dpo_sdxl.py \
--push_to_hub
```

## SDXL Turbo training command

```bash
accelerate launch train_diffusion_dpo_sdxl.py \
--pretrained_model_name_or_path=stabilityai/sdxl-turbo \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir="diffusion-sdxl-turbo-dpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--train_batch_size=8 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--checkpointing_steps=500 \
--run_validation --validation_steps=50 \
--seed="0" \
--is_turbo --resolution 512 \
--push_to_hub
```


## Acknowledgements

This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,16 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
images = []
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()

guidance_scale = 5.0
num_inference_steps = 25
if args.is_turbo:
guidance_scale = 0.0
num_inference_steps = 4
for prompt in VALIDATION_PROMPTS:
with context:
image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
image = pipeline(
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
).images[0]
images.append(image)

tracker_key = "test" if is_final_validation else "validation"
Expand All @@ -141,7 +148,10 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
if is_final_validation:
pipeline.disable_lora()
no_lora_images = [
pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
pipeline(
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
).images[0]
for prompt in VALIDATION_PROMPTS
]

for tracker in accelerator.trackers:
Expand Down Expand Up @@ -423,6 +433,11 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--is_turbo",
action="store_true",
help=("Use if tuning SDXL Turbo instead of SDXL"),
)
parser.add_argument(
"--rank",
type=int,
Expand All @@ -444,6 +459,9 @@ def parse_args(input_args=None):
if args.dataset_name is None:
raise ValueError("Must provide a `dataset_name`.")

if args.is_turbo:
assert "turbo" in args.pretrained_model_name_or_path

env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
Expand Down Expand Up @@ -560,6 +578,36 @@ def main(args):

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

def enforce_zero_terminal_snr(scheduler):
    """Rescale the scheduler's noise schedule in place so the last timestep has zero SNR.

    Reads ``scheduler.betas``, shifts and rescales sqrt(alpha_bar) so that its
    final entry becomes zero while its first entry keeps its original value,
    and stores the resulting cumulative products on ``scheduler.alphas_cumprod``.
    Only ``alphas_cumprod`` is overwritten; ``scheduler.betas`` is left as is.
    Returns ``None``.

    Turbo needs zero terminal SNR.
    Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
    Method: https://arxiv.org/pdf/2305.08891.pdf
    Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
    """
    # sqrt of the cumulative product of (1 - beta) over timesteps.
    sqrt_cumprod = torch.cumprod(1 - scheduler.betas, dim=0).sqrt()

    # Remember the endpoints of the original curve before rescaling.
    first_value = sqrt_cumprod[0].clone()
    last_value = sqrt_cumprod[-1].clone()

    # Shift so the final entry is zero, then stretch so the first entry is
    # restored to its original value.
    stretch = first_value / (first_value - last_value)
    sqrt_cumprod = (sqrt_cumprod - last_value) * stretch

    # Recover per-step alphas from consecutive ratios of the rescaled
    # cumulative products; the first step's alpha is the first cumulative value.
    rescaled_cumprod = sqrt_cumprod**2
    per_step_alphas = torch.cat(
        [rescaled_cumprod[0:1], rescaled_cumprod[1:] / rescaled_cumprod[:-1]]
    )

    # Rebuild the cumulative products from the per-step alphas and store them.
    scheduler.alphas_cumprod = torch.cumprod(per_step_alphas, dim=0)

if args.is_turbo:
enforce_zero_terminal_snr(noise_scheduler)

text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
Expand Down Expand Up @@ -909,6 +957,10 @@ def collate_fn(examples):
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
).repeat(2)
if args.is_turbo:
# Learn a 4 timestep schedule
timesteps_0_to_3 = timesteps % 4
timesteps = 250 * timesteps_0_to_3 + 249

# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
Expand Down

0 comments on commit eb942b8

Please sign in to comment.