Merge pull request #67 from okotaku/feat/distill_sd
[Feature] Support Distill SDXL training
okotaku authored Oct 11, 2023
2 parents 530e093 + b0e7063 commit 476596d
Showing 11 changed files with 583 additions and 4 deletions.
6 changes: 6 additions & 0 deletions configs/_base_/models/small_sd_xl.py
@@ -0,0 +1,6 @@
model = dict(
    type='DistillSDXL',
    model='stabilityai/stable-diffusion-xl-base-1.0',
    vae_model='madebyollin/sdxl-vae-fp16-fix',
    model_type='sd_small',
    gradient_checkpointing=True)
6 changes: 6 additions & 0 deletions configs/_base_/models/tiny_sd_xl.py
@@ -0,0 +1,6 @@
model = dict(
    type='DistillSDXL',
    model='stabilityai/stable-diffusion-xl-base-1.0',
    vae_model='madebyollin/sdxl-vae-fp16-fix',
    model_type='sd_tiny',
    gradient_checkpointing=True)
77 changes: 77 additions & 0 deletions configs/distill_sd/RAEDME.md
@@ -0,0 +1,77 @@
# Distill SD XL

[On Architectural Compression of Text-to-Image Diffusion Models](https://arxiv.org/abs/2305.15798)

## Abstract

Exceptional text-to-image (T2I) generation results of Stable Diffusion models (SDMs) come with substantial computational demands. To resolve this issue, recent research on efficient SDMs has prioritized reducing the number of sampling steps and utilizing network quantization. Orthogonal to these directions, this study highlights the power of classical architectural compression for general-purpose T2I synthesis by introducing block-removed knowledge-distilled SDMs (BK-SDMs). We eliminate several residual and attention blocks from the U-Net of SDMs, obtaining over a 30% reduction in the number of parameters, MACs per sampling step, and latency. We conduct distillation-based pretraining with only 0.22M LAION pairs (fewer than 0.1% of the full training pairs) on a single A100 GPU. Despite being trained with limited resources, our compact models can imitate the original SDM by benefiting from transferred knowledge and achieve competitive results against larger multi-billion parameter models on the zero-shot MS-COCO benchmark. Moreover, we demonstrate the applicability of our lightweight pretrained models in personalized generation with DreamBooth finetuning.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/253c0dfb-fa1c-4cbf-81c0-9d6948d40413"/>
</div>
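
The removal pattern that this commit implements in `DistillSDXL._prepare_student` (further down in this diff) can be sketched standalone. The snippet below is a rough illustration, not part of the commit: it loads the SDXL U-Net with `diffusers` and deletes one resnet/attention pair per down/up block, which is roughly how the `sd_small` student is derived; `sd_tiny` additionally drops the mid block.

```py
import torch
from diffusers import UNet2DConditionModel

# Load the teacher U-Net from the SDXL base checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    subfolder='unet',
    torch_dtype=torch.float16)

# Drop the second resnet/attention from every down block
# (the first SDXL down block has no attentions).
for block in unet.down_blocks:
    if len(block.resnets) > 1:
        del block.resnets[1]
    if hasattr(block, 'attentions') and len(block.attentions) > 1:
        del block.attentions[1]

# Keep only the first and last resnet/attention of every up block.
for block in unet.up_blocks:
    if len(block.resnets) > 2:
        del block.resnets[1]
    if hasattr(block, 'attentions') and len(block.attentions) > 2:
        del block.attentions[1]

print(sum(p.numel() for p in unet.parameters()))  # fewer parameters than the teacher
```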

## Citation

## Run Training

```bash
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch
# Example.
$ mim train diffengine configs/distill_sd/small_sd_xl_pokemon_blip.py
```
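
The two training configs added later in this diff combine the new model bases with the existing Pokemon BLIP XL dataset and the 50-epoch SDXL schedule. A config for another dataset would follow the same pattern; the dataset base file named below is hypothetical:

```py
# small_sd_xl_my_dataset.py - hypothetical user config; only the dataset
# base differs from configs/distill_sd/small_sd_xl_pokemon_blip.py.
_base_ = [
    '../_base_/models/small_sd_xl.py',
    '../_base_/datasets/my_dataset_xl.py',  # hypothetical dataset base
    '../_base_/schedules/stable_diffusion_xl_50e.py',
    '../_base_/default_runtime.py'
]
```

Training then uses the same `mim train diffengine` command with the new config path.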

## Inference with diffusers

Once you have trained a model, specify the path to the saved checkpoint and use it for inference with the `diffusers` pipeline.

Before running inference, convert the weights to the Diffusers format:

```bash
$ mim run diffengine publish_model2diffusers ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
$ mim run diffengine publish_model2diffusers configs/distill_sd/small_sd_xl_pokemon_blip.py work_dirs/small_sd_xl_pokemon_blip/epoch_50.pth work_dirs/small_sd_xl_pokemon_blip --save-keys unet
```

Then we can run inference.

```py
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, AutoencoderKL

prompt = 'yoda pokemon'
checkpoint = 'work_dirs/small_sd_xl_pokemon_blip'

unet = UNet2DConditionModel.from_pretrained(
    checkpoint, subfolder='unet', torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    unet=unet,
    vae=vae,
    torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    num_inference_steps=50,
    width=1024,
    height=1024,
).images[0]
image.save('demo.png')
```
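
The same snippet works for a `tiny_sd_xl_pokemon_blip` run; only the `checkpoint` path needs to change.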

## Results Example

#### small_sd_xl_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/da9d56a5-04d7-4fba-9c88-6b13c86adb9f)

#### tiny_sd_xl_pokemon_blip

![example1](https://github.com/okotaku/diffengine/assets/24734142/5ae252e7-ecb2-4af6-bf9a-e68d0f1840ce)
5 changes: 5 additions & 0 deletions configs/distill_sd/small_sd_xl_pokemon_blip.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/small_sd_xl.py', '../_base_/datasets/pokemon_blip_xl.py',
    '../_base_/schedules/stable_diffusion_xl_50e.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/distill_sd/tiny_sd_xl_pokemon_blip.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/tiny_sd_xl.py', '../_base_/datasets/pokemon_blip_xl.py',
    '../_base_/schedules/stable_diffusion_xl_50e.py',
    '../_base_/default_runtime.py'
]
1 change: 1 addition & 0 deletions diffengine/models/editors/__init__.py
@@ -1,3 +1,4 @@
from .distill_sd import * # noqa: F401, F403
from .esd import * # noqa: F401, F403
from .ip_adapter import * # noqa: F401, F403
from .stable_diffusion import * # noqa: F401, F403
3 changes: 3 additions & 0 deletions diffengine/models/editors/distill_sd/__init__.py
@@ -0,0 +1,3 @@
from .distill_sd_xl import DistillSDXL

__all__ = ['DistillSDXL']
283 changes: 283 additions & 0 deletions diffengine/models/editors/distill_sd/distill_sd_xl.py
@@ -0,0 +1,283 @@
import gc
from copy import deepcopy
from typing import Optional

import torch

from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
from diffengine.models.losses.snr_l2_loss import SNRL2Loss
from diffengine.registry import MODELS


@MODELS.register_module()
class DistillSDXL(StableDiffusionXL):
"""Distill Stable Diffusion XL.
Args:
model_type (str): The type of model to use. Choice from `sd_tiny`,
`sd_small`.
"""

    def __init__(self,
                 *args,
                 model_type: str,
                 lora_config: Optional[dict] = None,
                 finetune_text_encoder: bool = False,
                 **kwargs):
        assert lora_config is None, \
            '`lora_config` should be None when training DistillSDXL'
        assert not finetune_text_encoder, \
            '`finetune_text_encoder` should be False when training DistillSDXL'
        assert model_type in ['sd_tiny', 'sd_small'], \
            f'`model_type`={model_type} is not supported in DistillSDXL'

        self.model_type = model_type

        super().__init__(
            *args,
            lora_config=lora_config,
            finetune_text_encoder=finetune_text_encoder,
            **kwargs)

    def set_lora(self):
        """Set LORA for model."""
        pass

    def prepare_model(self):
        """Prepare model for training.

        Disable gradient for some models.
        """
        # Keep a frozen copy of the original U-Net as the teacher.
        self.orig_unet = deepcopy(self.unet).requires_grad_(False)

        # prepare student model
        self._prepare_student()
        super().prepare_model()
        self._cast_hook()

    def _prepare_student(self):
        assert len(self.unet.up_blocks) == len(self.unet.down_blocks)
        self.num_blocks = len(self.unet.up_blocks)
        config = self.unet._internal_dict
        config['layers_per_block'] = 1
        setattr(self.unet._internal_dict, 'layers_per_block', 1)
        if self.model_type == 'sd_tiny':
            self.unet.mid_block = None
            config['mid_block_type'] = None

        # Commence deletion of resnets/attentions inside the U-net
        # Handle Down Blocks: keep only the first resnet/attention.
        for i in range(self.num_blocks):
            delattr(self.unet.down_blocks[i].resnets, '1')
            if hasattr(self.unet.down_blocks[i], 'attentions'):
                # i == 0 does not have attentions
                delattr(self.unet.down_blocks[i].attentions, '1')

        # Handle Up Blocks: keep the first and last resnet/attention,
        # dropping the middle one.
        for i in range(self.num_blocks):
            self.unet.up_blocks[i].resnets[1] = self.unet.up_blocks[
                i].resnets[2]
            delattr(self.unet.up_blocks[i].resnets, '2')
            if hasattr(self.unet.up_blocks[i], 'attentions'):
                self.unet.up_blocks[i].attentions[1] = self.unet.up_blocks[
                    i].attentions[2]
                delattr(self.unet.up_blocks[i].attentions, '2')

        torch.cuda.empty_cache()
        gc.collect()

    def _cast_hook(self):
        self.teacher_feats = {}
        self.student_feats = {}

        def getActivation(activation, name, residuals_present):
            # the hook signature
            if residuals_present:

                def hook(model, input, output):
                    activation[name] = output[0]
            else:

                def hook(model, input, output):
                    activation[name] = output

            return hook

        # cast teacher
        for i in range(self.num_blocks):
            self.orig_unet.down_blocks[i].register_forward_hook(
                getActivation(self.teacher_feats, 'd' + str(i), True))
        self.orig_unet.mid_block.register_forward_hook(
            getActivation(self.teacher_feats, 'm', False))
        for i in range(self.num_blocks):
            self.orig_unet.up_blocks[i].register_forward_hook(
                getActivation(self.teacher_feats, 'u' + str(i), False))

        # cast student
        for i in range(self.num_blocks):
            self.unet.down_blocks[i].register_forward_hook(
                getActivation(self.student_feats, 'd' + str(i), True))
        if self.model_type == 'sd_small':
            self.unet.mid_block.register_forward_hook(
                getActivation(self.student_feats, 'm', False))
        for i in range(self.num_blocks):
            self.unet.up_blocks[i].register_forward_hook(
                getActivation(self.student_feats, 'u' + str(i), False))

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'loss'):
        assert mode == 'loss'
        num_batches = len(inputs['img'])
        if 'result_class_image' in inputs:
            # use prior_loss_weight
            weight = torch.cat([
                torch.ones((num_batches // 2, )),
                torch.ones((num_batches // 2, )) * self.prior_loss_weight
            ]).float().reshape(-1, 1, 1, 1)
        else:
            weight = None

        latents = self.vae.encode(inputs['img']).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        noise = torch.randn_like(latents)

        if self.enable_noise_offset:
            noise = noise + self.noise_offset_weight * torch.randn(
                latents.shape[0], latents.shape[1], 1, 1, device=noise.device)

        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps, (num_batches, ),
            device=self.device)
        timesteps = timesteps.long()

        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        if not self.pre_compute_text_embeddings:
            inputs['text_one'] = self.tokenizer_one(
                inputs['text'],
                max_length=self.tokenizer_one.model_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt').input_ids.to(self.device)
            inputs['text_two'] = self.tokenizer_two(
                inputs['text'],
                max_length=self.tokenizer_two.model_max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt').input_ids.to(self.device)
            prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
                inputs['text_one'], inputs['text_two'])
        else:
            prompt_embeds = inputs['prompt_embeds']
            pooled_prompt_embeds = inputs['pooled_prompt_embeds']
        unet_added_conditions = {
            'time_ids': inputs['time_ids'],
            'text_embeds': pooled_prompt_embeds
        }

        if self.scheduler.config.prediction_type == 'epsilon':
            gt = noise
        elif self.scheduler.config.prediction_type == 'v_prediction':
            gt = self.scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError('Unknown prediction type '
                             f'{self.scheduler.config.prediction_type}')

        # student prediction
        model_pred = self.unet(
            noisy_latents,
            timesteps,
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions).sample

        # teacher prediction from the frozen original U-Net
        with torch.no_grad():
            teacher_pred = self.orig_unet(
                noisy_latents,
                timesteps,
                prompt_embeds,
                added_cond_kwargs=unet_added_conditions).sample

        loss_dict = dict()
        # calculate loss in FP32
        if isinstance(self.loss_module, SNRL2Loss):
            loss_features = 0
            num_blocks = (
                self.num_blocks
                if self.model_type == 'sd_small' else self.num_blocks - 1)
            for i in range(num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['d' + str(i)].float(),
                    self.student_feats['d' + str(i)].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)
            if self.model_type == 'sd_small':
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats['m'].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)
            elif self.model_type == 'sd_tiny':
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats[f'd{self.num_blocks - 1}'].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)
            for i in range(self.num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['u' + str(i)].float(),
                    self.student_feats['u' + str(i)].float(),
                    timesteps,
                    self.scheduler.alphas_cumprod,
                    weight=weight)

            loss = self.loss_module(
                model_pred.float(),
                gt.float(),
                timesteps,
                self.scheduler.alphas_cumprod,
                weight=weight)
            loss_kd = self.loss_module(
                model_pred.float(),
                teacher_pred.float(),
                timesteps,
                self.scheduler.alphas_cumprod,
                weight=weight)
        else:
            loss_features = 0
            num_blocks = (
                self.num_blocks
                if self.model_type == 'sd_small' else self.num_blocks - 1)
            for i in range(num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['d' + str(i)].float(),
                    self.student_feats['d' + str(i)].float(),
                    weight=weight)
            if self.model_type == 'sd_small':
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats['m'].float(),
                    weight=weight)
            elif self.model_type == 'sd_tiny':
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['m'].float(),
                    self.student_feats[f'd{self.num_blocks - 1}'].float(),
                    weight=weight)
            for i in range(self.num_blocks):
                loss_features = loss_features + self.loss_module(
                    self.teacher_feats['u' + str(i)].float(),
                    self.student_feats['u' + str(i)].float(),
                    weight=weight)

            loss = self.loss_module(
                model_pred.float(), gt.float(), weight=weight)
            loss_kd = self.loss_module(
                model_pred.float(), teacher_pred.float(), weight=weight)
        loss_dict['loss_sd'] = loss
        loss_dict['loss_kd'] = loss_kd
        loss_dict['loss_features'] = loss_features
        return loss_dict
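
Because `DistillSDXL` is registered in `MODELS`, the config dicts added earlier in this commit are all the runner needs to instantiate it. As a rough sketch (assuming the usual mmengine registry behaviour and that the base class accepts these keys, as the base configs above suggest), the student can also be built directly:

```py
from diffengine.registry import MODELS

# Same dict as configs/_base_/models/tiny_sd_xl.py from this commit.
model = MODELS.build(
    dict(
        type='DistillSDXL',
        model='stabilityai/stable-diffusion-xl-base-1.0',
        vae_model='madebyollin/sdxl-vae-fp16-fix',
        model_type='sd_tiny',
        gradient_checkpointing=True))
```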