Merge pull request #67 from okotaku/feat/distill_sd
[Feature] Support Distill SDXL training
Showing 11 changed files with 583 additions and 4 deletions.
@@ -0,0 +1,6 @@
model = dict(
    type='DistillSDXL',
    model='stabilityai/stable-diffusion-xl-base-1.0',
    vae_model='madebyollin/sdxl-vae-fp16-fix',
    model_type='sd_small',
    gradient_checkpointing=True)
@@ -0,0 +1,6 @@
model = dict(
    type='DistillSDXL',
    model='stabilityai/stable-diffusion-xl-base-1.0',
    vae_model='madebyollin/sdxl-vae-fp16-fix',
    model_type='sd_tiny',
    gradient_checkpointing=True)
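The two model configs differ only in `model_type`: `sd_small` keeps the U-Net mid block, while `sd_tiny` removes it entirely (see `DistillSDXL._prepare_student` below).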
@@ -0,0 +1,77 @@
# Distill SD XL

[On Architectural Compression of Text-to-Image Diffusion Models](https://arxiv.org/abs/2305.15798)

## Abstract

Exceptional text-to-image (T2I) generation results of Stable Diffusion models (SDMs) come with substantial computational demands. To resolve this issue, recent research on efficient SDMs has prioritized reducing the number of sampling steps and utilizing network quantization. Orthogonal to these directions, this study highlights the power of classical architectural compression for general-purpose T2I synthesis by introducing block-removed knowledge-distilled SDMs (BK-SDMs). We eliminate several residual and attention blocks from the U-Net of SDMs, obtaining over a 30% reduction in the number of parameters, MACs per sampling step, and latency. We conduct distillation-based pretraining with only 0.22M LAION pairs (fewer than 0.1% of the full training pairs) on a single A100 GPU. Despite being trained with limited resources, our compact models can imitate the original SDM by benefiting from transferred knowledge and achieve competitive results against larger multi-billion parameter models on the zero-shot MS-COCO benchmark. Moreover, we demonstrate the applicability of our lightweight pretrained models in personalized generation with DreamBooth finetuning.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/253c0dfb-fa1c-4cbf-81c0-9d6948d40413"/>
</div>
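
In essence, the compression step just deletes blocks from an otherwise unchanged U-Net. Below is a minimal sketch of the idea, assuming diffusers' `UNet2DConditionModel` layout; the actual logic used for training in this PR lives in `DistillSDXL._prepare_student`.

```py
from diffusers import UNet2DConditionModel

# Load the full SDXL U-Net, then prune it into a smaller student.
unet = UNet2DConditionModel.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet')

for down_block in unet.down_blocks:
    # Drop the second resnet (and attention, where present) of each
    # down block; nn.ModuleList entries are addressed by string index.
    delattr(down_block.resnets, '1')
    if hasattr(down_block, 'attentions'):
        delattr(down_block.attentions, '1')
```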

## Citation

## Run Training
```bash
# single gpu
$ mim train diffengine ${CONFIG_FILE}

# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example.
$ mim train diffengine configs/distill_sd/small_sd_xl_pokemon_blip.py
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved model and use it for inference with the `diffusers.pipeline` module.

Before running inference, convert the trained weights to the diffusers format:
```bash
$ mim run diffengine publish_model2diffusers ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
$ mim run diffengine publish_model2diffusers configs/distill_sd/small_sd_xl_pokemon_blip.py work_dirs/small_sd_xl_pokemon_blip/epoch_50.pth work_dirs/small_sd_xl_pokemon_blip --save-keys unet
```
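Only the U-Net is trained in this setup, so `--save-keys unet` converts just those weights; the text encoders and VAE are loaded from the base model at inference time, as in the script below.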

Then we can run inference.
```py
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, AutoencoderKL

prompt = 'yoda pokemon'
checkpoint = 'work_dirs/small_sd_xl_pokemon_blip'

unet = UNet2DConditionModel.from_pretrained(
    checkpoint, subfolder='unet', torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    unet=unet,
    vae=vae,
    torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    num_inference_steps=50,
    width=1024,
    height=1024,
).images[0]
image.save('demo.png')
```
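Note that `madebyollin/sdxl-vae-fp16-fix` is used in place of the stock SDXL VAE because the latter is prone to numerical overflow when decoding in float16.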

## Results Example

#### small_sd_xl_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/da9d56a5-04d7-4fba-9c88-6b13c86adb9f)

#### tiny_sd_xl_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/5ae252e7-ecb2-4af6-bf9a-e68d0f1840ce)
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/small_sd_xl.py', '../_base_/datasets/pokemon_blip_xl.py',
    '../_base_/schedules/stable_diffusion_xl_50e.py',
    '../_base_/default_runtime.py'
]
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/tiny_sd_xl.py', '../_base_/datasets/pokemon_blip_xl.py',
    '../_base_/schedules/stable_diffusion_xl_50e.py',
    '../_base_/default_runtime.py'
]
@@ -0,0 +1,3 @@
from .distill_sd_xl import DistillSDXL

__all__ = ['DistillSDXL']
@@ -0,0 +1,283 @@
import gc
from copy import deepcopy
from typing import Optional

import torch

from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
from diffengine.models.losses.snr_l2_loss import SNRL2Loss
from diffengine.registry import MODELS


@MODELS.register_module()
class DistillSDXL(StableDiffusionXL):
    """Distill Stable Diffusion XL.

    Args:
        model_type (str): The type of model to use. Choice from `sd_tiny`,
            `sd_small`.
    """

    def __init__(self,
                 *args,
                 model_type: str,
                 lora_config: Optional[dict] = None,
                 finetune_text_encoder: bool = False,
                 **kwargs):
        assert lora_config is None, \
            '`lora_config` should be None when training DistillSDXL'
        assert not finetune_text_encoder, \
            '`finetune_text_encoder` should be False when training DistillSDXL'
        assert model_type in ['sd_tiny', 'sd_small'], \
            f'`model_type`={model_type} is not supported in DistillSDXL'

        self.model_type = model_type

        super().__init__(
            *args,
            lora_config=lora_config,
            finetune_text_encoder=finetune_text_encoder,
            **kwargs)

    def set_lora(self):
        """Set LORA for model.

        LoRA is not supported in DistillSDXL, so this is intentionally a
        no-op.
        """
        pass

    def prepare_model(self):
        """Prepare model for training.

        Disable gradient for some models.
        """
        # Keep a frozen copy of the original U-Net as the teacher.
        self.orig_unet = deepcopy(self.unet).requires_grad_(False)

        # prepare student model
        self._prepare_student()
        super().prepare_model()
        self._cast_hook()

    def _prepare_student(self):
        assert len(self.unet.up_blocks) == len(self.unet.down_blocks)
        self.num_blocks = len(self.unet.up_blocks)
        config = self.unet._internal_dict
        config['layers_per_block'] = 1
        setattr(self.unet._internal_dict, 'layers_per_block', 1)
        if self.model_type == 'sd_tiny':
            # `sd_tiny` additionally removes the whole mid block.
            self.unet.mid_block = None
            config['mid_block_type'] = None

        # Commence deletion of resnets/attentions inside the U-net
        # Handle Down Blocks
        for i in range(self.num_blocks):
            delattr(self.unet.down_blocks[i].resnets, '1')
            if hasattr(self.unet.down_blocks[i], 'attentions'):
                # i == 0 does not have attentions
                delattr(self.unet.down_blocks[i].attentions, '1')

        # Handle Up Blocks: keep the first and last of the original three
        # resnets/attentions, dropping the middle one.
        for i in range(self.num_blocks):
            self.unet.up_blocks[i].resnets[1] = self.unet.up_blocks[
                i].resnets[2]
            delattr(self.unet.up_blocks[i].resnets, '2')
            if hasattr(self.unet.up_blocks[i], 'attentions'):
                self.unet.up_blocks[i].attentions[1] = self.unet.up_blocks[
                    i].attentions[2]
                delattr(self.unet.up_blocks[i].attentions, '2')

        torch.cuda.empty_cache()
        gc.collect()

    def _cast_hook(self):
        self.teacher_feats = {}
        self.student_feats = {}

        def getActivation(activation, name, residuals_present):
            # the hook signature
            if residuals_present:
                # Down blocks return (hidden_states, res_samples); keep
                # only the hidden states.
                def hook(model, input, output):
                    activation[name] = output[0]
            else:

                def hook(model, input, output):
                    activation[name] = output

            return hook

        # cast teacher
        for i in range(self.num_blocks):
            self.orig_unet.down_blocks[i].register_forward_hook(
                getActivation(self.teacher_feats, 'd' + str(i), True))
        self.orig_unet.mid_block.register_forward_hook(
            getActivation(self.teacher_feats, 'm', False))
        for i in range(self.num_blocks):
            self.orig_unet.up_blocks[i].register_forward_hook(
                getActivation(self.teacher_feats, 'u' + str(i), False))

        # cast student
        for i in range(self.num_blocks):
            self.unet.down_blocks[i].register_forward_hook(
                getActivation(self.student_feats, 'd' + str(i), True))
        if self.model_type == 'sd_small':
            # `sd_tiny` has no mid block, so only `sd_small` hooks it.
            self.unet.mid_block.register_forward_hook(
                getActivation(self.student_feats, 'm', False))
        for i in range(self.num_blocks):
            self.unet.up_blocks[i].register_forward_hook(
                getActivation(self.student_feats, 'u' + str(i), False))

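    # The hooks above fill `self.teacher_feats` / `self.student_feats` on
    # every forward pass; `forward` below matches them block by block for
    # the feature-level distillation loss.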
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'loss'):
        assert mode == 'loss'
        num_batches = len(inputs['img'])
        if 'result_class_image' in inputs:
            # use prior_loss_weight
            weight = torch.cat([
                torch.ones((num_batches // 2, )),
                torch.ones((num_batches // 2, )) * self.prior_loss_weight
            ]).float().reshape(-1, 1, 1, 1)
        else:
            weight = None

        latents = self.vae.encode(inputs['img']).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        noise = torch.randn_like(latents)

        if self.enable_noise_offset:
            noise = noise + self.noise_offset_weight * torch.randn(
                latents.shape[0], latents.shape[1], 1, 1, device=noise.device)

        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps, (num_batches, ),
            device=self.device)
        timesteps = timesteps.long()

        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        if not self.pre_compute_text_embeddings:
            inputs['text_one'] = self.tokenizer_one(
                inputs['text'],
                max_length=self.tokenizer_one.model_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt').input_ids.to(self.device)
            inputs['text_two'] = self.tokenizer_two(
                inputs['text'],
                max_length=self.tokenizer_two.model_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt').input_ids.to(self.device)
            prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
                inputs['text_one'], inputs['text_two'])
        else:
            prompt_embeds = inputs['prompt_embeds']
            pooled_prompt_embeds = inputs['pooled_prompt_embeds']
        unet_added_conditions = {
            'time_ids': inputs['time_ids'],
            'text_embeds': pooled_prompt_embeds
        }

        if self.scheduler.config.prediction_type == 'epsilon':
            gt = noise
        elif self.scheduler.config.prediction_type == 'v_prediction':
            gt = self.scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError('Unknown prediction type '
                             f'{self.scheduler.config.prediction_type}')

        # student prediction
        model_pred = self.unet(
            noisy_latents,
            timesteps,
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions).sample

        # teacher prediction from the frozen original U-Net
        with torch.no_grad():
            teacher_pred = self.orig_unet(
                noisy_latents,
                timesteps,
                prompt_embeds,
                added_cond_kwargs=unet_added_conditions).sample

        loss_dict = dict()
        # calculate loss in FP32
        if isinstance(self.loss_module, SNRL2Loss):
            # Feature-level distillation: match teacher and student
            # activations captured by the forward hooks.
            loss_features = 0
            num_blocks = (
                self.num_blocks
                if self.model_type == 'sd_small' else self.num_blocks - 1)
            for i in range(num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['d' + str(i)].float(),
                    self.student_feats['d' + str(i)].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)
            if self.model_type == 'sd_small':
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats['m'].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)
            elif self.model_type == 'sd_tiny':
                # `sd_tiny` has no mid block: match the teacher's mid-block
                # output against the student's last down-block output.
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats[f'd{self.num_blocks - 1}'].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)
            for i in range(self.num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['u' + str(i)].float(),
                    self.student_feats['u' + str(i)].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)

            # task loss against the diffusion target
            loss = self.loss_module(
                model_pred.float(),
                gt.float(),
                timesteps,
                self.scheduler.alphas_cumprod,
                weight=weight)
            # output-level distillation against the teacher prediction
            loss_kd = self.loss_module(
                model_pred.float(),
                teacher_pred.float(),
                timesteps,
                self.scheduler.alphas_cumprod,
                weight=weight)
        else:
            loss_features = 0
            num_blocks = (
                self.num_blocks
                if self.model_type == 'sd_small' else self.num_blocks - 1)
            for i in range(num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['d' + str(i)].float(),
                    self.student_feats['d' + str(i)].float(),
                    weight=weight)
            if self.model_type == 'sd_small':
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats['m'].float(),
                    weight=weight)
            elif self.model_type == 'sd_tiny':
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats[f'd{self.num_blocks - 1}'].float(),
                    weight=weight)
            for i in range(self.num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['u' + str(i)].float(),
                    self.student_feats['u' + str(i)].float(),
                    weight=weight)

            loss = self.loss_module(
                model_pred.float(), gt.float(), weight=weight)
            loss_kd = self.loss_module(
                model_pred.float(), teacher_pred.float(), weight=weight)

        # Total training signal: task loss + output KD + feature KD.
        loss_dict['loss_sd'] = loss
        loss_dict['loss_kd'] = loss_kd
        loss_dict['loss_features'] = loss_features
        return loss_dict
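
For reference, a distillation model can be built directly from one of the model configs above. A minimal sketch, assuming diffengine's MMEngine-style registry:

```py
from diffengine.registry import MODELS

# Build the `sd_small` student/teacher pair from the base config shown
# above; DistillSDXL.prepare_model clones the frozen teacher and prunes
# the student U-Net.
model = MODELS.build(
    dict(
        type='DistillSDXL',
        model='stabilityai/stable-diffusion-xl-base-1.0',
        vae_model='madebyollin/sdxl-vae-fp16-fix',
        model_type='sd_small',
        gradient_checkpointing=True))
```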