-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support EDM-style training in DreamBooth LoRA SDXL script #7126
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | ||
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) | ||
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) | ||
timesteps = timesteps.to(accelerator.device) | ||
|
||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
|
||
sigma = sigmas[step_indices].flatten() | ||
while len(sigma.shape) < n_dim: | ||
sigma = sigma.unsqueeze(-1) | ||
return sigma |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For later:
We could think of making this more general by allowing to sample sigmas as presented in the paper, cf https://github.com/NVlabs/edm/blob/main/training/loss.py#L74
)[0] | ||
|
||
model_pred = model_pred * (-sigmas) + noisy_model_input | ||
weighing = sigmas**-2.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For later: this could be made configurable, as there are multiple weighing alternatives. In EDM they use
https://github.com/NVlabs/edm/blob/main/training/loss.py#L75
@patil-suraj could you give this another look? Results are still blank: https://wandb.ai/sayakpaul/dreambooth-lora-playground/runs/i7aq50g0. Experimenting with a lower LR. |
Can't seem to find anything else, will also try to run the script and see what's going on. |
Got a good run: https://wandb.ai/psuraj/dreambooth-lora-playground/runs/j34izml0?workspace=user-psuraj (still going on) What fixed it:
|
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Applied the changes, @patil-suraj. Could you try another run? |
Started a new run here https://wandb.ai/psuraj/dreambooth-lora-playground/runs/ef2qkmre |
@patil-suraj ready for a review. Feel free to test the script too :) @pcuenca feel free to give this a review as well. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good in general. I'd maybe try to avoid hardcoded references to the string "playgroundai"
to make decisions, if possible.
|
||
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364). | ||
|
||
For the SDXL model, simple set: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the SDXL model, simple set: | |
For the standard SDXL model, simply set: |
Does it work with SDXL out of the box? 🤯
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a test that you can check but I haven't done a full-blown training run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For LoRA it might not work, but can def be fine-tuned with EDM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patil-suraj elaborate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good, left some comments. +1 to what pedro said.
|
||
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364). | ||
|
||
For the SDXL model, simple set: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For LoRA it might not work, but can def be fine-tuned with EDM.
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@patil-suraj ready for another review. |
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@pcuenca @patil-suraj I have addressed all your comments. Would appreciate another review. I am going to run with the command from the OP one more time and also with regular SDXL with |
Started a regular SDXL run with EDM: CUDA_VISIBLE_DEVICES=1 accelerate launch train_dreambooth_lora_sdxl.py \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
--instance_data_dir="dog"\
--pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
--output_dir="lora-sdxl-dog" \
--mixed_precision="fp16" \
--use_8bit_adam \
--do_edm_style_training \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" Garbage results: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/bup8u1yc. Let me further tweak around some things. |
@pcuenca @patil-suraj the script now should work out of the box when CUDA_VISIBLE_DEVICES=1 accelerate launch train_dreambooth_lora_sdxl.py \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
--instance_data_dir="dog"\
--pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
--output_dir="lora-sdxl-dog" \
--mixed_precision="fp16" \
--use_8bit_adam \
--do_edm_style_training \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" Feel free to train one yourselves. Here are my results: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/dz77sffl Please review the changes so that we can ship this beast! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing the comments. The script is in a very good state for EDM. I would just suggest to verify the euler bit before adding it here or maybe even do it in another PR. (saw the other comment, all good)
Also are the vae
weights loaded and kept in fp32
?
# There might be other alternatives for weighting as well: | ||
# https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686 | ||
if "EDM" not in scheduler_type: | ||
weighting = (sigmas**-2.0).float() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should verify if this works with euler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/dz77sffl. When do_edm_style_training
is True and the scheduler is not EDM*, we are using EulerDiscrete. The run is from that setting.
Does that work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good!
Yes, that is the case. I have addressed your other comment as well, @patil-suraj. LMK. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great now! Feel free to merge
That is an issue quite unrelated to this PR. |
if args.do_edm_style_training and args.snr_gamma is not None: | ||
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do this earlier, so it doesn't load the model yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's at the beginning:
if args.do_edm_style_training and args.snr_gamma is not None: |
way before the model loading code.
Command example:
WandB: https://wandb.ai/sayakpaul/dreambooth-lora-playground/runs/sxe4bavp