[Feature] Support ControlNet #31

Merged 2 commits on Aug 7, 2023
2 changes: 2 additions & 0 deletions README.md
@@ -110,6 +110,8 @@ For detailed user guides and advanced guides, please refer to our [Documentation
 - [Run Stable Diffusion XL DreamBooth](docs/source/run_guides/run_dreambooth_xl.md)
 - [Run Stable Diffusion LoRA](docs/source/run_guides/run_lora.md)
 - [Run Stable Diffusion XL LoRA](docs/source/run_guides/run_lora_xl.md)
+- [Run Stable Diffusion ControlNet](docs/source/run_guides/run_controlnet.md)
+- [Run Stable Diffusion XL ControlNet](docs/source/run_guides/run_controlnet_xl.md)
 - [Inference](docs/source/run_guides/inference.md)
 
 </details>
4 changes: 2 additions & 2 deletions configs/_base_/datasets/dog_dreambooth.py
@@ -1,7 +1,7 @@
 train_pipeline = [
     dict(type='torchvision/Resize', size=512, interpolation='bilinear'),
-    dict(type='torchvision/RandomCrop', size=512),
-    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+    dict(type='RandomCrop', size=512),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
     dict(type='PackInputs'),
4 changes: 2 additions & 2 deletions configs/_base_/datasets/dog_dreambooth_xl.py
@@ -1,8 +1,8 @@
 train_pipeline = [
     dict(type='SaveImageShape'),
     dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
-    dict(type='RandomCropWithCropPoint', size=1024),
-    dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+    dict(type='RandomCrop', size=1024),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='ComputeTimeIds'),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
39 changes: 39 additions & 0 deletions configs/_base_/datasets/face_spiga_controlnet.py
@@ -0,0 +1,39 @@
train_pipeline = [
    dict(
        type='torchvision/Resize',
        size=512,
        interpolation='bilinear',
        keys=['img', 'condition_img']),
    dict(type='RandomCrop', size=512, keys=['img', 'condition_img']),
    dict(type='RandomHorizontalFlip', p=0.5, keys=['img', 'condition_img']),
    dict(type='torchvision/ToTensor', keys=['img', 'condition_img']),
    dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'),
    dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
    dict(type='PackInputs', input_keys=['img', 'condition_img', 'text']),
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    dataset=dict(
        type='HFControlNetDataset',
        dataset='multimodalart/facesyntheticsspigacaptioned',
        condition_column='spiga_seg',
        caption_column='image_caption',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type='VisualizationHook',
        prompt=['a close up of a man with a mohawkcut and a purple shirt'] * 4,
        condition_image=[
            'https://datasets-server.huggingface.co/assets/multimodalart/facesyntheticsspigacaptioned/--/multimodalart--facesyntheticsspigacaptioned/train/1/spiga_seg/image.jpg'  # noqa
        ] * 4),
    dict(type='ControlNetSaveHook')
]
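Aside: the `keys=['img', 'condition_img']` arguments on the spatial transforms above are the crux of ControlNet data loading — the image and its condition must be cropped and flipped identically, or the spatial correspondence the model learns from is destroyed. A minimal sketch of that pattern, assuming the custom `RandomCrop`/`RandomHorizontalFlip` transforms behave roughly like this (illustrative only, not diffengine's actual implementation):

```py
# Sketch: paired spatial transforms for ControlNet training data.
# Crop/flip parameters are sampled once and applied to every listed key,
# keeping 'img' and 'condition_img' pixel-aligned. Hypothetical helpers,
# not the actual diffengine transform classes.
import random

import torchvision.transforms.functional as TF
from torchvision import transforms


def paired_random_crop(results, size, keys=('img', 'condition_img')):
    # Sample crop parameters once, then apply them to every listed key.
    top, left, h, w = transforms.RandomCrop.get_params(
        results[keys[0]], output_size=(size, size))
    for key in keys:
        results[key] = TF.crop(results[key], top, left, h, w)
    return results


def paired_random_hflip(results, p=0.5, keys=('img', 'condition_img')):
    if random.random() < p:  # one coin flip shared by all keys
        for key in keys:
            results[key] = TF.hflip(results[key])
    return results
```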
39 changes: 39 additions & 0 deletions configs/_base_/datasets/fill50k_controlnet.py
@@ -0,0 +1,39 @@
train_pipeline = [
    dict(
        type='torchvision/Resize',
        size=512,
        interpolation='bilinear',
        keys=['img', 'condition_img']),
    dict(type='RandomCrop', size=512, keys=['img', 'condition_img']),
    dict(type='RandomHorizontalFlip', p=0.5, keys=['img', 'condition_img']),
    dict(type='torchvision/ToTensor', keys=['img', 'condition_img']),
    dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'),
    dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
    dict(type='PackInputs', input_keys=['img', 'condition_img', 'text']),
]
train_dataloader = dict(
    batch_size=8,
    num_workers=4,
    dataset=dict(
        type='HFControlNetDataset',
        dataset='fusing/fill50k',
        condition_column='conditioning_image',
        caption_column='text',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type='VisualizationHook',
        prompt=['cyan circle with brown floral background'] * 4,
        condition_image=[
            'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'  # noqa
        ] * 4),
    dict(type='ControlNetSaveHook')
]
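For orientation, the column names passed to `HFControlNetDataset` above come straight from the Hugging Face dataset and can be checked with the `datasets` library — a quick sketch (the printed column names are what the config assumes):

```py
# Sketch: inspect the fill50k dataset that the config above consumes.
from datasets import load_dataset

ds = load_dataset('fusing/fill50k', split='train')
print(ds.column_names)  # expected: ['image', 'conditioning_image', 'text']

sample = ds[0]
sample['image'].save('target.png')                # training target
sample['conditioning_image'].save('control.png')  # control signal
print(sample['text'])                             # caption
```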
43 changes: 43 additions & 0 deletions configs/_base_/datasets/fill50k_controlnet_xl.py
@@ -0,0 +1,43 @@
train_pipeline = [
    dict(type='SaveImageShape'),
    dict(
        type='torchvision/Resize',
        size=1024,
        interpolation='bilinear',
        keys=['img', 'condition_img']),
    dict(type='RandomCrop', size=1024, keys=['img', 'condition_img']),
    dict(type='RandomHorizontalFlip', p=0.5, keys=['img', 'condition_img']),
    dict(type='ComputeTimeIds'),
    dict(type='torchvision/ToTensor', keys=['img', 'condition_img']),
    dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'),
    dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
    dict(
        type='PackInputs',
        input_keys=['img', 'condition_img', 'text', 'time_ids']),
]
train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    dataset=dict(
        type='HFControlNetDataset',
        dataset='fusing/fill50k',
        condition_column='conditioning_image',
        caption_column='text',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type='VisualizationHook',
        prompt=['cyan circle with brown floral background'] * 4,
        condition_image=[
            'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'  # noqa
        ] * 4),
    dict(type='ControlNetSaveHook')
]
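The `SaveImageShape` and `ComputeTimeIds` steps exist because SDXL is micro-conditioned on size metadata: the UNet receives the original image size, the crop's top-left corner, and the target size as six extra scalars. A hedged sketch of the usual computation, mirroring how diffusers assembles `add_time_ids` for SDXL (diffengine's transforms may differ in detail):

```py
# Sketch: SDXL "time ids" = original size + crop top-left + target size.
# SaveImageShape records the pre-resize size; ComputeTimeIds packs the six
# values that condition the SDXL UNet.
import torch


def compute_time_ids(original_size, crop_top_left, target_size=(1024, 1024)):
    add_time_ids = list(original_size + crop_top_left + target_size)
    return torch.tensor([add_time_ids])


print(compute_time_ids(original_size=(768, 512), crop_top_left=(0, 128)))
# tensor([[ 768,  512,    0,  128, 1024, 1024]])
```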
4 changes: 2 additions & 2 deletions configs/_base_/datasets/keramer_face_dreambooth_xl.py
@@ -1,8 +1,8 @@
 train_pipeline = [
     dict(type='SaveImageShape'),
     dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
-    dict(type='RandomCropWithCropPoint', size=1024),
-    dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+    dict(type='RandomCrop', size=1024),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='ComputeTimeIds'),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
4 changes: 2 additions & 2 deletions configs/_base_/datasets/pokemon_blip.py
@@ -1,7 +1,7 @@
 train_pipeline = [
     dict(type='torchvision/Resize', size=512, interpolation='bilinear'),
-    dict(type='torchvision/RandomCrop', size=512),
-    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+    dict(type='RandomCrop', size=512),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
     dict(type='PackInputs'),
4 changes: 2 additions & 2 deletions configs/_base_/datasets/pokemon_blip_xl.py
@@ -1,8 +1,8 @@
 train_pipeline = [
     dict(type='SaveImageShape'),
     dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
-    dict(type='RandomCropWithCropPoint', size=1024),
-    dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+    dict(type='RandomCrop', size=1024),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='ComputeTimeIds'),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
4 changes: 2 additions & 2 deletions configs/_base_/datasets/potatohead_dreambooth_xl.py
@@ -1,8 +1,8 @@
 train_pipeline = [
     dict(type='SaveImageShape'),
     dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
-    dict(type='RandomCropWithCropPoint', size=1024),
-    dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+    dict(type='RandomCrop', size=1024),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='ComputeTimeIds'),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
4 changes: 2 additions & 2 deletions configs/_base_/datasets/starbucks_dreambooth_xl.py
@@ -1,8 +1,8 @@
 train_pipeline = [
     dict(type='SaveImageShape'),
     dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
-    dict(type='RandomCropWithCropPoint', size=1024),
-    dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+    dict(type='RandomCrop', size=1024),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='ComputeTimeIds'),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
2 changes: 2 additions & 0 deletions configs/_base_/models/stable_diffusion_v15_controlnet.py
@@ -0,0 +1,2 @@
model = dict(
    type='StableDiffusionControlNet', model='runwayml/stable-diffusion-v1-5')
5 changes: 5 additions & 0 deletions configs/_base_/models/stable_diffusion_xl_controlnet.py
@@ -0,0 +1,5 @@
model = dict(
    type='StableDiffusionXLControlNet',
    model='stabilityai/stable-diffusion-xl-base-1.0',
    vae_model='madebyollin/sdxl-vae-fp16-fix',
    gradient_checkpointing=True)
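A freshly trained ControlNet is conventionally initialized as a trainable copy of the base UNet's encoder with zero-initialized output convolutions. In diffusers that is a single call; presumably the `StableDiffusionControlNet`/`StableDiffusionXLControlNet` wrappers do something equivalent internally (a sketch of the diffusers API, not diffengine's actual code):

```py
# Sketch: deriving a fresh ControlNet from the base model's UNet (diffusers).
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet')
controlnet = ControlNetModel.from_unet(unet)  # copies encoder, adds zero convs
```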
17 changes: 17 additions & 0 deletions configs/_base_/schedules/stable_diffusion_1e.py
@@ -0,0 +1,17 @@
optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',
    optimizer=dict(type='AdamW', lr=1e-4, weight_decay=1e-2),
    clip_grad=dict(max_norm=1.0))

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=1)
val_cfg = None
test_cfg = None

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        max_keep_ckpts=3,
    ), )
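`AmpOptimWrapper` with `dtype='float16'` and `clip_grad=dict(max_norm=1.0)` bundles the usual mixed-precision boilerplate. Roughly what it automates per step in plain PyTorch — a sketch assuming generic `model` and `batch` placeholders, not the wrapper's actual source:

```py
# Sketch: the fp16 training step that AmpOptimWrapper roughly automates.
# `model` and `batch` are placeholders for the trained model and a data batch.
import torch

scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

with torch.autocast('cuda', dtype=torch.float16):
    loss = model(batch)                # forward in fp16
scaler.scale(loss).backward()          # scale to avoid fp16 gradient underflow
scaler.unscale_(optimizer)             # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                 # skips the step if grads overflowed
scaler.update()
optimizer.zero_grad()
```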
17 changes: 17 additions & 0 deletions configs/_base_/schedules/stable_diffusion_3e.py
@@ -0,0 +1,17 @@
optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',
    optimizer=dict(type='AdamW', lr=1e-4, weight_decay=1e-2),
    clip_grad=dict(max_norm=1.0))

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=3)
val_cfg = None
test_cfg = None

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        max_keep_ckpts=3,
    ), )
2 changes: 1 addition & 1 deletion configs/stable_diffusion/README.md
@@ -50,7 +50,7 @@ image = pipe(
 image.save('demo.png')
 ```
 
-You can see more details on [Run Stable Diffusion docs](../../docs/source/run_guides/run_sd.md#inference-with-diffusers).
+You can see more details on [`docs/source/run_guides/run_sd.md`](../../docs/source/run_guides/run_sd.md#inference-with-diffusers).
 
 ## Results Example
 
87 changes: 87 additions & 0 deletions configs/stable_diffusion_controlnet/README.md
@@ -0,0 +1,87 @@
# Stable Diffusion ControlNet

[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543)

## Abstract

We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/97a5d6b7-90b9-4247-936c-c27e26b47cff"/>
</div>

## Citation

```bibtex
@misc{zhang2023adding,
    title={Adding Conditional Control to Text-to-Image Diffusion Models},
    author={Lvmin Zhang and Maneesh Agrawala},
    year={2023},
    eprint={2302.05543},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

## Run Training

```bash
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example
$ mim train diffengine configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py
```

## Inference with diffusers

Once you have trained a model, specify the path where the model was saved, and use it for inference with `diffusers`.

```py
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image

checkpoint = 'work_dirs/stable_diffusion_v15_controlnet_fill50k/step6250'
prompt = 'cyan circle with brown floral background'
condition_image = load_image(
    'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'
)

controlnet = ControlNetModel.from_pretrained(
    checkpoint, subfolder='controlnet', torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', controlnet=controlnet, torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    condition_image,
    num_inference_steps=50,
).images[0]
image.save('demo.png')
```
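On smaller GPUs, the standard diffusers memory savers also apply here (optional, and independent of this PR; the xformers call assumes xformers is installed):

```py
# Optional memory savers (standard diffusers calls).
pipe.enable_model_cpu_offload()  # use instead of pipe.to('cuda')
pipe.enable_xformers_memory_efficient_attention()
```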

You can see more details on [`docs/source/run_guides/run_controlnet.md`](../../docs/source/run_guides/run_controlnet.md#inference-with-diffusers).

## Results Example

#### stable_diffusion_v15_controlnet_fill50k

![input1](https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg)

![example1](https://github.com/okotaku/diffengine/assets/24734142/a14cc9a6-3a40-4577-bd5a-2ddbab60970d)

#### stable_diffusion_v15_controlnet_face_spiga

![input2](https://datasets-server.huggingface.co/assets/multimodalart/facesyntheticsspigacaptioned/--/multimodalart--facesyntheticsspigacaptioned/train/1/spiga_seg/image.jpg)

![example2](https://github.com/okotaku/diffengine/assets/24734142/172b7c7a-a5a0-493a-8bcf-2d6491f44f90)

## Acknowledgement

These experiments are based on [diffusers docs](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README.md) and [blog post `Train your ControlNet with diffusers 🧨`](https://huggingface.co/blog/train-your-controlnet). Thank you for the great articles.
6 changes: 6 additions & 0 deletions configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_face_spiga.py
@@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/stable_diffusion_v15_controlnet.py',
    '../_base_/datasets/face_spiga_controlnet.py',
    '../_base_/schedules/stable_diffusion_3e.py',
    '../_base_/default_runtime.py'
]
6 changes: 6 additions & 0 deletions configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py
@@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/stable_diffusion_v15_controlnet.py',
    '../_base_/datasets/fill50k_controlnet.py',
    '../_base_/schedules/stable_diffusion_1e.py',
    '../_base_/default_runtime.py'
]
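These top-level configs only compose `_base_` fragments; mmengine merges them when the file is loaded. A quick way to inspect the merged result (standard mmengine API; the path matches the training example in the README above):

```py
# Sketch: inspect the fully merged config with standard mmengine tooling.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py')
print(cfg.model)             # merged from _base_/models/...
print(cfg.train_dataloader)  # merged from _base_/datasets/fill50k_controlnet.py
```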
@@ -7,8 +7,8 @@
 
 train_pipeline = [
     dict(type='torchvision/Resize', size=768, interpolation='bilinear'),
-    dict(type='torchvision/RandomCrop', size=768),
-    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+    dict(type='RandomCrop', size=768),
+    dict(type='RandomHorizontalFlip', p=0.5),
     dict(type='torchvision/ToTensor'),
     dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
     dict(type='PackInputs'),