diff --git a/README.md b/README.md
index 0c3618d..0f67e3e 100644
--- a/README.md
+++ b/README.md
@@ -110,6 +110,8 @@ For detailed user guides and advanced guides, please refer to our [Documentation
- [Run Stable Diffusion XL DreamBooth](docs/source/run_guides/run_dreambooth_xl.md)
- [Run Stable Diffusion LoRA](docs/source/run_guides/run_lora.md)
- [Run Stable Diffusion XL LoRA](docs/source/run_guides/run_lora_xl.md)
+- [Run Stable Diffusion ControlNet](docs/source/run_guides/run_controlnet.md)
+- [Run Stable Diffusion XL ControlNet](docs/source/run_guides/run_controlnet_xl.md)
- [Inference](docs/source/run_guides/inference.md)
diff --git a/configs/_base_/datasets/dog_dreambooth.py b/configs/_base_/datasets/dog_dreambooth.py
index 15c8f6d..f5067da 100644
--- a/configs/_base_/datasets/dog_dreambooth.py
+++ b/configs/_base_/datasets/dog_dreambooth.py
@@ -1,7 +1,7 @@
train_pipeline = [
dict(type='torchvision/Resize', size=512, interpolation='bilinear'),
- dict(type='torchvision/RandomCrop', size=512),
- dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+ dict(type='RandomCrop', size=512),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
dict(type='PackInputs'),
diff --git a/configs/_base_/datasets/dog_dreambooth_xl.py b/configs/_base_/datasets/dog_dreambooth_xl.py
index ced4788..2e2acb8 100644
--- a/configs/_base_/datasets/dog_dreambooth_xl.py
+++ b/configs/_base_/datasets/dog_dreambooth_xl.py
@@ -1,8 +1,8 @@
train_pipeline = [
dict(type='SaveImageShape'),
dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
- dict(type='RandomCropWithCropPoint', size=1024),
- dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+ dict(type='RandomCrop', size=1024),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='ComputeTimeIds'),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
diff --git a/configs/_base_/datasets/face_spiga_controlnet.py b/configs/_base_/datasets/face_spiga_controlnet.py
new file mode 100644
index 0000000..927fe7f
--- /dev/null
+++ b/configs/_base_/datasets/face_spiga_controlnet.py
@@ -0,0 +1,39 @@
+train_pipeline = [
+ dict(
+ type='torchvision/Resize',
+ size=512,
+ interpolation='bilinear',
+ keys=['img', 'condition_img']),
+ dict(type='RandomCrop', size=512, keys=['img', 'condition_img']),
+ dict(type='RandomHorizontalFlip', p=0.5, keys=['img', 'condition_img']),
+ dict(type='torchvision/ToTensor', keys=['img', 'condition_img']),
+ dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'),
+ dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
+ dict(type='PackInputs', input_keys=['img', 'condition_img', 'text']),
+]
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=4,
+ dataset=dict(
+ type='HFControlNetDataset',
+ dataset='multimodalart/facesyntheticsspigacaptioned',
+ condition_column='spiga_seg',
+ caption_column='image_caption',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+)
+
+val_dataloader = None
+val_evaluator = None
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
+
+custom_hooks = [
+ dict(
+ type='VisualizationHook',
+ prompt=['a close up of a man with a mohawkcut and a purple shirt'] * 4,
+ condition_image=[
+ 'https://datasets-server.huggingface.co/assets/multimodalart/facesyntheticsspigacaptioned/--/multimodalart--facesyntheticsspigacaptioned/train/1/spiga_seg/image.jpg' # noqa
+ ] * 4),
+ dict(type='ControlNetSaveHook')
+]
diff --git a/configs/_base_/datasets/fill50k_controlnet.py b/configs/_base_/datasets/fill50k_controlnet.py
new file mode 100644
index 0000000..deb7cb3
--- /dev/null
+++ b/configs/_base_/datasets/fill50k_controlnet.py
@@ -0,0 +1,39 @@
+train_pipeline = [
+ dict(
+ type='torchvision/Resize',
+ size=512,
+ interpolation='bilinear',
+ keys=['img', 'condition_img']),
+ dict(type='RandomCrop', size=512, keys=['img', 'condition_img']),
+ dict(type='RandomHorizontalFlip', p=0.5, keys=['img', 'condition_img']),
+ dict(type='torchvision/ToTensor', keys=['img', 'condition_img']),
+ dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'),
+ dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
+ dict(type='PackInputs', input_keys=['img', 'condition_img', 'text']),
+]
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=4,
+ dataset=dict(
+ type='HFControlNetDataset',
+ dataset='fusing/fill50k',
+ condition_column='conditioning_image',
+ caption_column='text',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+)
+
+val_dataloader = None
+val_evaluator = None
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
+
+custom_hooks = [
+ dict(
+ type='VisualizationHook',
+ prompt=['cyan circle with brown floral background'] * 4,
+ condition_image=[
+ 'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg' # noqa
+ ] * 4),
+ dict(type='ControlNetSaveHook')
+]
diff --git a/configs/_base_/datasets/fill50k_controlnet_xl.py b/configs/_base_/datasets/fill50k_controlnet_xl.py
new file mode 100644
index 0000000..a6457d4
--- /dev/null
+++ b/configs/_base_/datasets/fill50k_controlnet_xl.py
@@ -0,0 +1,43 @@
+train_pipeline = [
+ dict(type='SaveImageShape'),
+ dict(
+ type='torchvision/Resize',
+ size=1024,
+ interpolation='bilinear',
+ keys=['img', 'condition_img']),
+ dict(type='RandomCrop', size=1024, keys=['img', 'condition_img']),
+ dict(type='RandomHorizontalFlip', p=0.5, keys=['img', 'condition_img']),
+ dict(type='ComputeTimeIds'),
+ dict(type='torchvision/ToTensor', keys=['img', 'condition_img']),
+ dict(type='DumpImage', max_imgs=10, dump_dir='work_dirs/dump'),
+ dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
+ dict(
+ type='PackInputs',
+ input_keys=['img', 'condition_img', 'text', 'time_ids']),
+]
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+ dataset=dict(
+ type='HFControlNetDataset',
+ dataset='fusing/fill50k',
+ condition_column='conditioning_image',
+ caption_column='text',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+)
+
+val_dataloader = None
+val_evaluator = None
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
+
+custom_hooks = [
+ dict(
+ type='VisualizationHook',
+ prompt=['cyan circle with brown floral background'] * 4,
+ condition_image=[
+ 'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg' # noqa
+ ] * 4),
+ dict(type='ControlNetSaveHook')
+]
diff --git a/configs/_base_/datasets/keramer_face_dreambooth_xl.py b/configs/_base_/datasets/keramer_face_dreambooth_xl.py
index 5d03e17..25ba84a 100644
--- a/configs/_base_/datasets/keramer_face_dreambooth_xl.py
+++ b/configs/_base_/datasets/keramer_face_dreambooth_xl.py
@@ -1,8 +1,8 @@
train_pipeline = [
dict(type='SaveImageShape'),
dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
- dict(type='RandomCropWithCropPoint', size=1024),
- dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+ dict(type='RandomCrop', size=1024),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='ComputeTimeIds'),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
diff --git a/configs/_base_/datasets/pokemon_blip.py b/configs/_base_/datasets/pokemon_blip.py
index 2f32797..1cb0fc0 100644
--- a/configs/_base_/datasets/pokemon_blip.py
+++ b/configs/_base_/datasets/pokemon_blip.py
@@ -1,7 +1,7 @@
train_pipeline = [
dict(type='torchvision/Resize', size=512, interpolation='bilinear'),
- dict(type='torchvision/RandomCrop', size=512),
- dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+ dict(type='RandomCrop', size=512),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
dict(type='PackInputs'),
diff --git a/configs/_base_/datasets/pokemon_blip_xl.py b/configs/_base_/datasets/pokemon_blip_xl.py
index 910c57d..2f801cb 100644
--- a/configs/_base_/datasets/pokemon_blip_xl.py
+++ b/configs/_base_/datasets/pokemon_blip_xl.py
@@ -1,8 +1,8 @@
train_pipeline = [
dict(type='SaveImageShape'),
dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
- dict(type='RandomCropWithCropPoint', size=1024),
- dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+ dict(type='RandomCrop', size=1024),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='ComputeTimeIds'),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
diff --git a/configs/_base_/datasets/potatohead_dreambooth_xl.py b/configs/_base_/datasets/potatohead_dreambooth_xl.py
index 941babd..d713aee 100644
--- a/configs/_base_/datasets/potatohead_dreambooth_xl.py
+++ b/configs/_base_/datasets/potatohead_dreambooth_xl.py
@@ -1,8 +1,8 @@
train_pipeline = [
dict(type='SaveImageShape'),
dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
- dict(type='RandomCropWithCropPoint', size=1024),
- dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+ dict(type='RandomCrop', size=1024),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='ComputeTimeIds'),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
diff --git a/configs/_base_/datasets/starbucks_dreambooth_xl.py b/configs/_base_/datasets/starbucks_dreambooth_xl.py
index 473ae0f..e9e40de 100644
--- a/configs/_base_/datasets/starbucks_dreambooth_xl.py
+++ b/configs/_base_/datasets/starbucks_dreambooth_xl.py
@@ -1,8 +1,8 @@
train_pipeline = [
dict(type='SaveImageShape'),
dict(type='torchvision/Resize', size=1024, interpolation='bilinear'),
- dict(type='RandomCropWithCropPoint', size=1024),
- dict(type='RandomHorizontalFlipFixCropPoint', p=0.5),
+ dict(type='RandomCrop', size=1024),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='ComputeTimeIds'),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
diff --git a/configs/_base_/models/stable_diffusion_v15_controlnet.py b/configs/_base_/models/stable_diffusion_v15_controlnet.py
new file mode 100644
index 0000000..66551af
--- /dev/null
+++ b/configs/_base_/models/stable_diffusion_v15_controlnet.py
@@ -0,0 +1,2 @@
+model = dict(
+ type='StableDiffusionControlNet', model='runwayml/stable-diffusion-v1-5')
diff --git a/configs/_base_/models/stable_diffusion_xl_controlnet.py b/configs/_base_/models/stable_diffusion_xl_controlnet.py
new file mode 100644
index 0000000..13857ba
--- /dev/null
+++ b/configs/_base_/models/stable_diffusion_xl_controlnet.py
@@ -0,0 +1,5 @@
+model = dict(
+ type='StableDiffusionXLControlNet',
+ model='stabilityai/stable-diffusion-xl-base-1.0',
+ vae_model='madebyollin/sdxl-vae-fp16-fix',
+ gradient_checkpointing=True)
diff --git a/configs/_base_/schedules/stable_diffusion_1e.py b/configs/_base_/schedules/stable_diffusion_1e.py
new file mode 100644
index 0000000..d88c1d1
--- /dev/null
+++ b/configs/_base_/schedules/stable_diffusion_1e.py
@@ -0,0 +1,17 @@
+optim_wrapper = dict(
+ type='AmpOptimWrapper',
+ dtype='float16',
+ optimizer=dict(type='AdamW', lr=1e-4, weight_decay=1e-2),
+ clip_grad=dict(max_norm=1.0))
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=1)
+val_cfg = None
+test_cfg = None
+
+default_hooks = dict(
+ checkpoint=dict(
+ type='CheckpointHook',
+ interval=1,
+ max_keep_ckpts=3,
+ ), )
diff --git a/configs/_base_/schedules/stable_diffusion_3e.py b/configs/_base_/schedules/stable_diffusion_3e.py
new file mode 100644
index 0000000..de9a083
--- /dev/null
+++ b/configs/_base_/schedules/stable_diffusion_3e.py
@@ -0,0 +1,17 @@
+optim_wrapper = dict(
+ type='AmpOptimWrapper',
+ dtype='float16',
+ optimizer=dict(type='AdamW', lr=1e-4, weight_decay=1e-2),
+ clip_grad=dict(max_norm=1.0))
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=3)
+val_cfg = None
+test_cfg = None
+
+default_hooks = dict(
+ checkpoint=dict(
+ type='CheckpointHook',
+ interval=1,
+ max_keep_ckpts=3,
+ ), )
diff --git a/configs/stable_diffusion/README.md b/configs/stable_diffusion/README.md
index 6a22ad5..5dd4ab7 100644
--- a/configs/stable_diffusion/README.md
+++ b/configs/stable_diffusion/README.md
@@ -50,7 +50,7 @@ image = pipe(
image.save('demo.png')
```
-You can see more details on [Run Stable Diffusion docs](../../docs/source/run_guides/run_sd.md#inference-with-diffusers).
+You can see more details on [`docs/source/run_guides/run_sd.md`](../../docs/source/run_guides/run_sd.md#inference-with-diffusers).
## Results Example
diff --git a/configs/stable_diffusion_controlnet/README.md b/configs/stable_diffusion_controlnet/README.md
new file mode 100644
index 0000000..534ce46
--- /dev/null
+++ b/configs/stable_diffusion_controlnet/README.md
@@ -0,0 +1,87 @@
+# Stable Diffusion ControlNet
+
+[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543)
+
+## Abstract
+
+We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (\< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.
+
+
+
![](https://github.com/okotaku/diffengine/assets/24734142/97a5d6b7-90b9-4247-936c-c27e26b47cff)
+
+
+## Citation
+
+```
+@misc{zhang2023adding,
+ title={Adding Conditional Control to Text-to-Image Diffusion Models},
+ author={Lvmin Zhang and Maneesh Agrawala},
+ year={2023},
+ eprint={2302.05543},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+## Run Training
+
+Run Training
+
+```
+# single gpu
+$ mim train diffengine ${CONFIG_FILE}
+# multi gpus
+$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch
+
+# Example.
+$ mim train diffengine configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py
+```
+
+## Inference with diffusers
+
+Once you have trained a model, specify the path to where the model is saved, and use it for inference with the `diffusers`.
+
+```py
+import torch
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+
+checkpoint = 'work_dirs/stable_diffusion_v15_controlnet_fill50k/step6250'
+prompt = 'cyan circle with brown floral background'
+condition_image = load_image(
+ 'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'
+)
+
+controlnet = ControlNetModel.from_pretrained(
+ checkpoint, subfolder='controlnet', torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ 'runwayml/stable-diffusion-v1-5', controlnet=controlnet, torch_dtype=torch.float16)
+pipe.to('cuda')
+
+image = pipe(
+ prompt,
+ condition_image,
+ num_inference_steps=50,
+).images[0]
+image.save('demo.png')
+```
+
+You can see more details on [`docs/source/run_guides/run_controlnet.md`](../../docs/source/run_guides/run_controlnet.md#inference-with-diffusers).
+
+## Results Example
+
+#### stable_diffusion_v15_controlnet_fill50k
+
+![input1](https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg)
+
+![example1](https://github.com/okotaku/diffengine/assets/24734142/a14cc9a6-3a40-4577-bd5a-2ddbab60970d)
+
+#### stable_diffusion_v15_controlnet_face_spiga
+
+![input2](https://datasets-server.huggingface.co/assets/multimodalart/facesyntheticsspigacaptioned/--/multimodalart--facesyntheticsspigacaptioned/train/1/spiga_seg/image.jpg)
+
+![example2](https://github.com/okotaku/diffengine/assets/24734142/172b7c7a-a5a0-493a-8bcf-2d6491f44f90)
+
+## Acknowledgement
+
+These experiments are based on [diffusers docs](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README.md) and [blog post `Train your ControlNet with diffusers 🧨`](https://huggingface.co/blog/train-your-controlnet). Thank you for the great articles.
diff --git a/configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_face_spiga.py b/configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_face_spiga.py
new file mode 100644
index 0000000..9c5c1cd
--- /dev/null
+++ b/configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_face_spiga.py
@@ -0,0 +1,6 @@
+_base_ = [
+ '../_base_/models/stable_diffusion_v15_controlnet.py',
+ '../_base_/datasets/face_spiga_controlnet.py',
+ '../_base_/schedules/stable_diffusion_3e.py',
+ '../_base_/default_runtime.py'
+]
diff --git a/configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py b/configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py
new file mode 100644
index 0000000..cdd857b
--- /dev/null
+++ b/configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py
@@ -0,0 +1,6 @@
+_base_ = [
+ '../_base_/models/stable_diffusion_v15_controlnet.py',
+ '../_base_/datasets/fill50k_controlnet.py',
+ '../_base_/schedules/stable_diffusion_1e.py',
+ '../_base_/default_runtime.py'
+]
diff --git a/configs/stable_diffusion_dreambooth/stable_diffusion_v21_dreambooth_lora_dog.py b/configs/stable_diffusion_dreambooth/stable_diffusion_v21_dreambooth_lora_dog.py
index faa7032..df634dc 100644
--- a/configs/stable_diffusion_dreambooth/stable_diffusion_v21_dreambooth_lora_dog.py
+++ b/configs/stable_diffusion_dreambooth/stable_diffusion_v21_dreambooth_lora_dog.py
@@ -7,8 +7,8 @@
train_pipeline = [
dict(type='torchvision/Resize', size=768, interpolation='bilinear'),
- dict(type='torchvision/RandomCrop', size=768),
- dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+ dict(type='RandomCrop', size=768),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
dict(type='PackInputs'),
diff --git a/configs/stable_diffusion_xl_controlnet/README.md b/configs/stable_diffusion_xl_controlnet/README.md
new file mode 100644
index 0000000..6c60503
--- /dev/null
+++ b/configs/stable_diffusion_xl_controlnet/README.md
@@ -0,0 +1,88 @@
+# Stable Diffusion XL ControlNet
+
+[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543)
+
+## Abstract
+
+We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (\< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.
+
+
+
![](https://github.com/okotaku/diffengine/assets/24734142/97a5d6b7-90b9-4247-936c-c27e26b47cff)
+
+
+## Citation
+
+```
+@misc{zhang2023adding,
+ title={Adding Conditional Control to Text-to-Image Diffusion Models},
+ author={Lvmin Zhang and Maneesh Agrawala},
+ year={2023},
+ eprint={2302.05543},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+## Run Training
+
+Run Training
+
+```
+# single gpu
+$ mim train diffengine ${CONFIG_FILE}
+# multi gpus
+$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch
+
+# Example.
+$ mim train diffengine configs/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet_fill50k.py
+```
+
+## Inference with diffusers
+
+Once you have trained a model, specify the path to where the model is saved, and use it for inference with the `diffusers`.
+
+```py
+import torch
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
+from diffusers.utils import load_image
+
+checkpoint = 'work_dirs/stable_diffusion_xl_controlnet_fill50k/step25000'
+prompt = 'cyan circle with brown floral background'
+condition_image = load_image(
+ 'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'
+)
+
+controlnet = ControlNetModel.from_pretrained(
+ checkpoint, subfolder='controlnet', torch_dtype=torch.float16)
+
+vae = AutoencoderKL.from_pretrained(
+ 'madebyollin/sdxl-vae-fp16-fix',
+ torch_dtype=torch.float16,
+)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ 'stabilityai/stable-diffusion-xl-base-1.0', controlnet=controlnet, vae=vae, torch_dtype=torch.float16)
+pipe.to('cuda')
+
+image = pipe(
+ prompt,
+ condition_image,
+ num_inference_steps=50,
+).images[0]
+image.save('demo.png')
+```
+
+You can see more details on [`docs/source/run_guides/run_controlnet_xl.md`](../../docs/source/run_guides/run_controlnet_xl.md#inference-with-diffusers).
+
+## Results Example
+
+#### stable_diffusion_xl_controlnet_fill50k
+
+![input1](https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg)
+
+![example1](https://github.com/okotaku/diffengine/assets/24734142/a331a413-a9e7-4b9a-aa75-72279c4cc77a)
+
+Note that some of the results are not good. We should attempt further tuning.
+
+## Acknowledgement
+
+These experiments are based on [diffusers docs](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md). Thank you for the great articles.
diff --git a/configs/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet_fill50k.py b/configs/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet_fill50k.py
new file mode 100644
index 0000000..7833d66
--- /dev/null
+++ b/configs/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet_fill50k.py
@@ -0,0 +1,11 @@
+_base_ = [
+ '../_base_/models/stable_diffusion_xl_controlnet.py',
+ '../_base_/datasets/fill50k_controlnet_xl.py',
+ '../_base_/schedules/stable_diffusion_1e.py',
+ '../_base_/default_runtime.py'
+]
+
+optim_wrapper = dict(
+ optimizer=dict(lr=1e-5),
+ accumulative_counts=2,
+)
diff --git a/diffengine/datasets/__init__.py b/diffengine/datasets/__init__.py
index ec3b3b8..26a1162 100644
--- a/diffengine/datasets/__init__.py
+++ b/diffengine/datasets/__init__.py
@@ -1,5 +1,6 @@
+from .hf_controlnet_datasets import HFControlNetDataset
from .hf_datasets import HFDataset
from .hf_dreambooth_datasets import HFDreamBoothDataset
from .transforms import * # noqa: F401, F403
-__all__ = ['HFDataset', 'HFDreamBoothDataset']
+__all__ = ['HFDataset', 'HFDreamBoothDataset', 'HFControlNetDataset']
diff --git a/diffengine/datasets/hf_controlnet_datasets.py b/diffengine/datasets/hf_controlnet_datasets.py
new file mode 100644
index 0000000..7075fa2
--- /dev/null
+++ b/diffengine/datasets/hf_controlnet_datasets.py
@@ -0,0 +1,96 @@
+import os
+import random
+from pathlib import Path
+from typing import Optional, Sequence
+
+import numpy as np
+from datasets import load_dataset
+from mmengine.dataset.base_dataset import Compose
+from PIL import Image
+from torch.utils.data import Dataset
+
+from diffengine.registry import DATASETS
+
+
+@DATASETS.register_module()
+class HFControlNetDataset(Dataset):
+ """Dataset for huggingface datasets.
+
+ Args:
+ dataset (str): Dataset name or path to dataset.
+ image_column (str): Image column name. Defaults to 'image'.
+ condition_column (str): Condition column name for ControlNet.
+ Defaults to 'condition'.
+ caption_column (str): Caption column name. Defaults to 'text'.
+ pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ cache_dir (str, optional): The directory where the downloaded datasets
+ will be stored.Defaults to None.
+ """
+
+ def __init__(self,
+ dataset: str,
+ image_column: str = 'image',
+ condition_column: str = 'condition',
+ caption_column: str = 'text',
+ pipeline: Sequence = (),
+ cache_dir: Optional[str] = None):
+ self.dataset_name = dataset
+ if Path(dataset).exists():
+ # load local folder
+ data_file = os.path.join(dataset, 'metadata.csv')
+ self.dataset = load_dataset(
+ 'csv', data_files=data_file, cache_dir=cache_dir)['train']
+ else:
+ # load huggingface online
+ self.dataset = load_dataset(dataset, cache_dir=cache_dir)['train']
+ self.pipeline = Compose(pipeline)
+
+ self.image_column = image_column
+ self.condition_column = condition_column
+ self.caption_column = caption_column
+
+ def __len__(self) -> int:
+ """Get the length of dataset.
+
+ Returns:
+ int: The length of filtered dataset.
+ """
+ return len(self.dataset)
+
+ def __getitem__(self, idx: int) -> dict:
+ """Get the idx-th image and data information of dataset after
+ ``self.train_transforms`.
+
+ Args:
+ idx (int): The index of self.data_list.
+
+ Returns:
+ dict: The idx-th image and data information of dataset after
+ ``self.train_transforms``.
+ """
+ data_info = self.dataset[idx]
+ image = data_info[self.image_column]
+ if type(image) == str:
+ image = Image.open(os.path.join(self.dataset_name, image))
+ image = image.convert('RGB')
+
+ condition_image = data_info[self.condition_column]
+ if type(condition_image) == str:
+ condition_image = Image.open(
+ os.path.join(self.dataset_name, condition_image))
+ condition_image = condition_image.convert('RGB')
+
+ caption = data_info[self.caption_column]
+ if isinstance(caption, str):
+ pass
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ caption = random.choice(caption)
+ else:
+ raise ValueError(
+ f'Caption column `{self.caption_column}` should contain either'
+ ' strings or lists of strings.')
+ result = dict(img=image, condition_img=condition_image, text=caption)
+ result = self.pipeline(result)
+
+ return result
diff --git a/diffengine/datasets/hf_datasets.py b/diffengine/datasets/hf_datasets.py
index 968df0a..878da4d 100644
--- a/diffengine/datasets/hf_datasets.py
+++ b/diffengine/datasets/hf_datasets.py
@@ -1,7 +1,7 @@
import os
import random
from pathlib import Path
-from typing import Sequence
+from typing import Optional, Sequence
import numpy as np
from datasets import load_dataset
@@ -21,21 +21,25 @@ class HFDataset(Dataset):
image_column (str): Image column name. Defaults to 'image'.
caption_column (str): Caption column name. Defaults to 'text'.
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ cache_dir (str, optional): The directory where the downloaded datasets
+ will be stored.Defaults to None.
"""
def __init__(self,
dataset: str,
image_column: str = 'image',
caption_column: str = 'text',
- pipeline: Sequence = ()):
+ pipeline: Sequence = (),
+ cache_dir: Optional[str] = None):
self.dataset_name = dataset
if Path(dataset).exists():
# load local folder
data_file = os.path.join(dataset, 'metadata.csv')
- self.dataset = load_dataset('csv', data_files=data_file)['train']
+ self.dataset = load_dataset(
+ 'csv', data_files=data_file, cache_dir=cache_dir)['train']
else:
# load huggingface online
- self.dataset = load_dataset(dataset)['train']
+ self.dataset = load_dataset(dataset, cache_dir=cache_dir)['train']
self.pipeline = Compose(pipeline)
self.image_column = image_column
diff --git a/diffengine/datasets/hf_dreambooth_datasets.py b/diffengine/datasets/hf_dreambooth_datasets.py
index c7614a0..dbe34b8 100644
--- a/diffengine/datasets/hf_dreambooth_datasets.py
+++ b/diffengine/datasets/hf_dreambooth_datasets.py
@@ -37,6 +37,8 @@ class HFDreamBoothDataset(Dataset):
class_prompt (Optional[str]): The prompt to specify images in the same
class as provided instance images. Defaults to None.
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+ cache_dir (str, optional): The directory where the downloaded datasets
+ will be stored.Defaults to None.
"""
default_class_image_config: dict = dict(
model='runwayml/stable-diffusion-v1-5',
@@ -56,16 +58,18 @@ def __init__(self,
device='cuda',
),
class_prompt: Optional[str] = None,
- pipeline: Sequence = ()):
+ pipeline: Sequence = (),
+ cache_dir: Optional[str] = None):
if Path(dataset).exists():
# load local folder
data_files = {}
data_files['train'] = '**'
- self.dataset = load_dataset(dataset, data_files)['train']
+ self.dataset = load_dataset(
+ dataset, data_files, cache_dir=cache_dir)['train']
else:
# load huggingface online
- self.dataset = load_dataset(dataset)['train']
+ self.dataset = load_dataset(dataset, cache_dir=cache_dir)['train']
self.pipeline = Compose(pipeline)
self.instance_prompt = instance_prompt
diff --git a/diffengine/datasets/transforms/__init__.py b/diffengine/datasets/transforms/__init__.py
index 7000a2c..e15ff3a 100644
--- a/diffengine/datasets/transforms/__init__.py
+++ b/diffengine/datasets/transforms/__init__.py
@@ -1,11 +1,11 @@
from .base import BaseTransform
+from .dump_image import DumpImage
from .formatting import PackInputs
-from .processing import (TRANSFORMS, CenterCropWithCropPoint, ComputeTimeIds,
- RandomCropWithCropPoint,
- RandomHorizontalFlipFixCropPoint, SaveImageShape)
+from .processing import (TRANSFORMS, CenterCrop, ComputeTimeIds, RandomCrop,
+ RandomHorizontalFlip, SaveImageShape)
__all__ = [
'BaseTransform', 'PackInputs', 'TRANSFORMS', 'SaveImageShape',
- 'RandomCropWithCropPoint', 'CenterCropWithCropPoint',
- 'RandomHorizontalFlipFixCropPoint', 'ComputeTimeIds'
+ 'RandomCrop', 'CenterCrop', 'RandomHorizontalFlip', 'ComputeTimeIds',
+ 'DumpImage'
]
diff --git a/diffengine/datasets/transforms/dump_image.py b/diffengine/datasets/transforms/dump_image.py
new file mode 100644
index 0000000..de583cc
--- /dev/null
+++ b/diffengine/datasets/transforms/dump_image.py
@@ -0,0 +1,59 @@
+from os import path as osp
+
+import cv2
+import mmengine
+import numpy as np
+from torch.multiprocessing import Value
+
+from diffengine.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class DumpImage:
+ """Dump the image processed by the pipeline.
+
+ Args:
+ max_imgs (int): Maximum value of output.
+ dump_dir (str): Dump output directory.
+ """
+
+ def __init__(self, max_imgs: int, dump_dir: str):
+ self.max_imgs = max_imgs
+ self.dump_dir = dump_dir
+ mmengine.mkdir_or_exist(self.dump_dir)
+ self.num_dumped_imgs = Value('i', 0)
+
+ def __call__(self, results):
+ """Dump the input image to the specified directory.
+
+ No changes will be
+ made.
+ Args:
+ results (dict): Result dict from loading pipeline.
+ Returns:
+ results (dict): Result dict from loading pipeline. (same as input)
+ """
+
+ enable_dump = False
+ with self.num_dumped_imgs.get_lock():
+ if self.num_dumped_imgs.value < self.max_imgs:
+ self.num_dumped_imgs.value += 1
+ enable_dump = True
+ dump_id = self.num_dumped_imgs.value
+
+ if enable_dump:
+ img = results['img']
+ if img.shape[0] in [1, 3]:
+ img = img.permute(1, 2, 0) * 255
+ out_file = osp.join(self.dump_dir, f'{dump_id}_image.png')
+ cv2.imwrite(out_file, img.numpy().astype(np.uint8))
+
+ if 'condition_img':
+ condition_img = results['condition_img']
+ if condition_img.shape[0] in [1, 3]:
+ condition_img = condition_img.permute(1, 2, 0) * 255
+ cond_out_file = osp.join(self.dump_dir, f'{dump_id}_cond.png')
+ cv2.imwrite(cond_out_file,
+ condition_img.numpy().astype(np.uint8))
+
+ return results
diff --git a/diffengine/datasets/transforms/processing.py b/diffengine/datasets/transforms/processing.py
index f3859d3..052ee05 100644
--- a/diffengine/datasets/transforms/processing.py
+++ b/diffengine/datasets/transforms/processing.py
@@ -2,7 +2,7 @@
import random
import re
from enum import EnumMeta
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
import torchvision
from torchvision.transforms.functional import crop
@@ -33,8 +33,19 @@ def _interpolation_modes_from_str(t: str):
class TorchVisonTransformWrapper:
+ """TorchVisonTransformWrapper.
- def __init__(self, transform, *args, **kwargs):
+ We can use torchvision.transforms like `dict(type='torchvision/Resize',
+ size=512)`
+
+ Args:
+ transform (str): The name of transform. For example
+ `torchvision/Resize`.
+ keys (List[str]): `keys` to apply augmentation from results.
+ """
+
+ def __init__(self, transform, *args, keys: List[str] = ['img'], **kwargs):
+ self.keys = keys
if 'interpolation' in kwargs and isinstance(kwargs['interpolation'],
str):
kwargs['interpolation'] = _interpolation_modes_from_str(
@@ -44,7 +55,8 @@ def __init__(self, transform, *args, **kwargs):
self.t = transform(*args, **kwargs)
def __call__(self, results):
- results['img'] = self.t(results['img'])
+ for k in self.keys:
+ results[k] = self.t(results[k])
return results
def __repr__(self) -> str:
@@ -99,11 +111,30 @@ def transform(self,
@TRANSFORMS.register_module()
-class RandomCropWithCropPoint(BaseTransform):
- """RandomCrop and save crop top left as 'crop_top_left' in results."""
+class RandomCrop(BaseTransform):
+ """RandomCrop.
+
+ The difference from torchvision/RandomCrop is
+ 1. save crop top left as 'crop_top_left' and `crop_bottom_right` in
+ results
+ 2. apply same random parameters to multiple `keys` like ['img',
+ 'condition_img'].
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made. If provided a sequence of length 1, it will be interpreted
+ as (size[0], size[0])
+ keys (List[str]): `keys` to apply augmentation from results.
+ """
- def __init__(self, *args, size, **kwargs):
+ def __init__(self,
+ *args,
+ size: Union[Sequence[int], int],
+ keys: List[str] = ['img'],
+ **kwargs):
self.size = size
+ self.keys = keys
self.pipeline = torchvision.transforms.RandomCrop(
*args, size, **kwargs)
@@ -114,22 +145,42 @@ def transform(self,
results (dict): The result dict.
Returns:
- dict: 'crop_top_left' key is added as crop point.
+ dict: 'crop_top_left' and `crop_bottom_right` key is added as crop
+ point.
"""
+ assert all(results['img'].size == results[k].size for k in self.keys)
y1, x1, h, w = self.pipeline.get_params(results['img'],
(self.size, self.size))
- results['img'] = crop(results['img'], y1, x1, h, w)
+ for k in self.keys:
+ results[k] = crop(results[k], y1, x1, h, w)
results['crop_top_left'] = [y1, x1]
results['crop_bottom_right'] = [y1 + h, x1 + w]
return results
@TRANSFORMS.register_module()
-class CenterCropWithCropPoint(BaseTransform):
- """CenterCrop and save crop top left as 'crop_top_left' in results."""
+class CenterCrop(BaseTransform):
+ """CenterCrop.
+
+ The difference from torchvision/CenterCrop is
+ 1. save crop top left as 'crop_top_left' and `crop_bottom_right` in
+ results
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made. If provided a sequence of length 1, it will be interpreted
+ as (size[0], size[0])
+ keys (List[str]): `keys` to apply augmentation from results.
+ """
- def __init__(self, *args, size, **kwargs):
+ def __init__(self,
+ *args,
+ size: Union[Sequence[int], int],
+ keys: List[str] = ['img'],
+ **kwargs):
self.size = size
+ self.keys = keys
self.pipeline = torchvision.transforms.CenterCrop(
*args, size, **kwargs)
@@ -142,22 +193,36 @@ def transform(self,
Returns:
dict: 'crop_top_left' key is added as crop points.
"""
+ assert all(results['img'].size == results[k].size for k in self.keys)
y1 = max(0, int(round((results['img'].height - self.size) / 2.0)))
x1 = max(0, int(round((results['img'].width - self.size) / 2.0)))
y2 = max(0, int(round((results['img'].height + self.size) / 2.0)))
x2 = max(0, int(round((results['img'].width + self.size) / 2.0)))
- results['img'] = self.pipeline(results['img'])
+ for k in self.keys:
+ results[k] = self.pipeline(results[k])
results['crop_top_left'] = [y1, x1]
results['crop_bottom_right'] = [y2, x2]
return results
@TRANSFORMS.register_module()
-class RandomHorizontalFlipFixCropPoint(BaseTransform):
- """Apply RandomHorizontalFlip and fix 'crop_top_left' in results."""
+class RandomHorizontalFlip(BaseTransform):
+ """RandomHorizontalFlip.
+
+ The difference from torchvision/RandomHorizontalFlip is
+ 1. update 'crop_top_left' and `crop_bottom_right` if exists.
+ 2. apply same random parameters to multiple `keys` like ['img',
+ 'condition_img'].
+
+ Args:
+ p (float): probability of the image being flipped.
+ Default value is 0.5.
+ keys (List[str]): `keys` to apply augmentation from results.
+ """
- def __init__(self, *args, p, **kwargs):
+ def __init__(self, *args, p: float = 0.5, keys=['img'], **kwargs):
self.p = p
+ self.keys = keys
self.pipeline = torchvision.transforms.RandomHorizontalFlip(
*args, p=1.0, **kwargs)
@@ -171,7 +236,10 @@ def transform(self,
dict: 'crop_top_left' key is fixed.
"""
if random.random() < self.p:
- results['img'] = self.pipeline(results['img'])
+ assert all(results['img'].size == results[k].size
+ for k in self.keys)
+ for k in self.keys:
+ results[k] = self.pipeline(results[k])
if 'crop_top_left' in results:
y1 = results['crop_top_left'][0]
x1 = results['img'].width - results['crop_bottom_right'][1]
diff --git a/diffengine/engine/hooks/__init__.py b/diffengine/engine/hooks/__init__.py
index 1106117..50b2213 100644
--- a/diffengine/engine/hooks/__init__.py
+++ b/diffengine/engine/hooks/__init__.py
@@ -1,8 +1,10 @@
+from .controlnet_save_hook import ControlNetSaveHook
from .lora_save_hook import LoRASaveHook
from .sd_checkpoint_hook import SDCheckpointHook
from .unet_ema_hook import UnetEMAHook
from .visualization_hook import VisualizationHook
__all__ = [
- 'VisualizationHook', 'UnetEMAHook', 'SDCheckpointHook', 'LoRASaveHook'
+ 'VisualizationHook', 'UnetEMAHook', 'SDCheckpointHook', 'LoRASaveHook',
+ 'ControlNetSaveHook'
]
diff --git a/diffengine/engine/hooks/controlnet_save_hook.py b/diffengine/engine/hooks/controlnet_save_hook.py
new file mode 100644
index 0000000..0906f21
--- /dev/null
+++ b/diffengine/engine/hooks/controlnet_save_hook.py
@@ -0,0 +1,34 @@
+import os.path as osp
+from collections import OrderedDict
+
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+from mmengine.registry import HOOKS
+
+
+@HOOKS.register_module()
+class ControlNetSaveHook(Hook):
+ """Save ControlNet weights with diffusers format and pick up ControlNet
+ weights from checkpoint."""
+ priority = 'VERY_LOW'
+
+ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
+ """
+ Args:
+ runner (Runner): The runner of the training, validation or testing
+ process.
+ checkpoint (dict): Model's checkpoint.
+ """
+ model = runner.model
+ if is_model_wrapper(model):
+ model = model.module
+ ckpt_path = osp.join(runner.work_dir, f'step{runner.iter}')
+ model.controlnet.save_pretrained(osp.join(ckpt_path, 'controlnet'))
+
+ # not save no grad key
+ new_ckpt = OrderedDict()
+ sd_keys = checkpoint['state_dict'].keys()
+ for k in sd_keys:
+ if 'controlnet' in k:
+ new_ckpt[k] = checkpoint['state_dict'][k]
+ checkpoint['state_dict'] = new_ckpt
diff --git a/diffengine/engine/hooks/visualization_hook.py b/diffengine/engine/hooks/visualization_hook.py
index 3a9ac06..7ebbeae 100644
--- a/diffengine/engine/hooks/visualization_hook.py
+++ b/diffengine/engine/hooks/visualization_hook.py
@@ -14,19 +14,33 @@ class VisualizationHook(Hook):
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
+ condition_image (`Optional[List[str]]`):
+ The condition image for ControlNet. Defaults to None.
interval (int): Visualization interval (every k iterations).
Defaults to 1.
by_epoch (bool): Whether to visualize by epoch. Defaults to True.
+ height (`int`, *optional*, defaults to
+ `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to
+ `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
"""
priority = 'NORMAL'
def __init__(self,
prompt: List[str],
+ condition_image: Optional[List[str]] = None,
interval: int = 1,
- by_epoch: bool = True):
+ by_epoch: bool = True,
+ height: Optional[int] = None,
+ width: Optional[int] = None):
self.prompt = prompt
+ self.condition_image = condition_image
self.interval = interval
self.by_epoch = by_epoch
+ self.height = height
+ self.width = width
def after_train_iter(self,
runner,
@@ -44,7 +58,12 @@ def after_train_iter(self,
model = runner.model
if is_model_wrapper(model):
model = model.module
- images = model.infer(self.prompt)
+ if self.condition_image is None:
+ images = model.infer(self.prompt, self.height, self.width)
+ else:
+ # controlnet
+ images = model.infer(self.prompt, self.condition_image,
+ self.height, self.width)
for i, image in enumerate(images):
runner.visualizer.add_image(
f'image{i}_step', image, step=runner.iter)
@@ -58,7 +77,12 @@ def after_train_epoch(self, runner) -> None:
model = runner.model
if is_model_wrapper(model):
model = model.module
- images = model.infer(self.prompt)
+ if self.condition_image is None:
+ images = model.infer(self.prompt, self.height, self.width)
+ else:
+ # controlnet
+ images = model.infer(self.prompt, self.condition_image,
+ self.height, self.width)
for i, image in enumerate(images):
runner.visualizer.add_image(
f'image{i}_step', image, step=runner.epoch)
diff --git a/diffengine/models/editors/__init__.py b/diffengine/models/editors/__init__.py
index 7cd39d8..b37baf3 100644
--- a/diffengine/models/editors/__init__.py
+++ b/diffengine/models/editors/__init__.py
@@ -1,2 +1,4 @@
from .stable_diffusion import * # noqa: F401, F403
+from .stable_diffusion_controlnet import * # noqa: F401, F403
from .stable_diffusion_xl import * # noqa: F401, F403
+from .stable_diffusion_xl_controlnet import * # noqa: F401, F403
diff --git a/diffengine/models/editors/stable_diffusion/stable_diffusion.py b/diffengine/models/editors/stable_diffusion/stable_diffusion.py
index 4e1b352..d43d8ab 100644
--- a/diffengine/models/editors/stable_diffusion/stable_diffusion.py
+++ b/diffengine/models/editors/stable_diffusion/stable_diffusion.py
@@ -24,7 +24,8 @@ class StableDiffusion(BaseModel):
Defaults to 'runwayml/stable-diffusion-v1-5'.
loss (dict): Config of loss. Defaults to
``dict(type='L2Loss', loss_weight=1.0)``.
- lora_config (dict): The LoRA config dict. example. dict(rank=4)
+ lora_config (dict, optional): The LoRA config dict.
+ example. dict(rank=4). Defaults to None.
finetune_text_encoder (bool, optional): Whether to fine-tune text
encoder. Defaults to False.
prior_loss_weight (float): The weight of prior preservation loss.
@@ -37,7 +38,7 @@ class StableDiffusion(BaseModel):
checkpointing to save memory at the expense of slower backward
pass. Defaults to False.
data_preprocessor (dict, optional): The pre-process config of
- :class:`BaseDataPreprocessor`.
+ :class:`SDDataPreprocessor`.
"""
def __init__(
@@ -57,6 +58,7 @@ def __init__(
self.lora_config = deepcopy(lora_config)
self.finetune_text_encoder = finetune_text_encoder
self.prior_loss_weight = prior_loss_weight
+ self.gradient_checkpointing = gradient_checkpointing
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
@@ -76,10 +78,6 @@ def __init__(
self.unet = UNet2DConditionModel.from_pretrained(
model, subfolder='unet')
self.prepare_model()
- if gradient_checkpointing:
- self.unet.enable_gradient_checkpointing()
- if self.finetune_text_encoder:
- self.text_encoder.gradient_checkpointing_enable()
self.set_lora()
def set_lora(self):
@@ -96,6 +94,11 @@ def prepare_model(self):
Disable gradient for some models.
"""
+ if self.gradient_checkpointing:
+ self.unet.enable_gradient_checkpointing()
+ if self.finetune_text_encoder:
+ self.text_encoder.gradient_checkpointing_enable()
+
self.vae.requires_grad_(False)
print_log('Set VAE untrainable.', 'current')
if not self.finetune_text_encoder:
diff --git a/diffengine/models/editors/stable_diffusion_controlnet/__init__.py b/diffengine/models/editors/stable_diffusion_controlnet/__init__.py
new file mode 100644
index 0000000..ef75521
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_controlnet/__init__.py
@@ -0,0 +1,4 @@
+from .sd_controlnet_data_preprocessor import SDControlNetDataPreprocessor
+from .stable_diffusion_controlnet import StableDiffusionControlNet
+
+__all__ = ['StableDiffusionControlNet', 'SDControlNetDataPreprocessor']
diff --git a/diffengine/models/editors/stable_diffusion_controlnet/sd_controlnet_data_preprocessor.py b/diffengine/models/editors/stable_diffusion_controlnet/sd_controlnet_data_preprocessor.py
new file mode 100644
index 0000000..457fee4
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_controlnet/sd_controlnet_data_preprocessor.py
@@ -0,0 +1,30 @@
+from typing import Union
+
+import torch
+from mmengine.model.base_model.data_preprocessor import BaseDataPreprocessor
+
+from diffengine.registry import MODELS
+
+
+@MODELS.register_module()
+class SDControlNetDataPreprocessor(BaseDataPreprocessor):
+
+ def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
+ """Preprocesses the data into the model input format.
+
+ After the data pre-processing of :meth:`cast_data`, ``forward``
+ will stack the input tensor list to a batch tensor at the first
+ dimension.
+
+ Args:
+ data (dict): Data returned by dataloader
+ training (bool): Whether to enable training time augmentation.
+
+ Returns:
+ dict or list: Data in the same format as the model input.
+ """
+ assert 'result_class_image' not in data['inputs']
+ data['inputs']['img'] = torch.stack(data['inputs']['img'])
+ data['inputs']['condition_img'] = torch.stack(
+ data['inputs']['condition_img'])
+ return super().forward(data) # type: ignore
diff --git a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py
new file mode 100644
index 0000000..156ed78
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py
@@ -0,0 +1,205 @@
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
+from diffusers.utils import load_image
+from mmengine import print_log
+from PIL import Image
+from torch import nn
+
+from diffengine.models.editors.stable_diffusion import StableDiffusion
+from diffengine.models.losses.snr_l2_loss import SNRL2Loss
+from diffengine.registry import MODELS
+
+
+@MODELS.register_module()
+class StableDiffusionControlNet(StableDiffusion):
+ """Stable Diffusion ControlNet.
+
+ Args:
+ controlnet_model (str, optional): Path to pretrained VAE model with
+ better numerical stability. More details:
+ https://github.com/huggingface/diffusers/pull/4038.
+ Defaults to None.
+ lora_config (dict, optional): The LoRA config dict. This should be
+ `None` when training ControlNet. Defaults to None.
+ finetune_text_encoder (bool, optional): Whether to fine-tune text
+ encoder. This should be `False` when training ControlNet.
+ Defaults to False.
+ data_preprocessor (dict, optional): The pre-process config of
+ :class:`SDControlNetDataPreprocessor`.
+ """
+
+ def __init__(self,
+ *args,
+ controlnet_model: Optional[str] = None,
+ lora_config: Optional[dict] = None,
+ finetune_text_encoder: bool = False,
+ data_preprocessor: Optional[Union[dict, nn.Module]] = dict(
+ type='SDControlNetDataPreprocessor'),
+ **kwargs):
+ assert lora_config is None, \
+ '`lora_config` should be None when training ControlNet'
+ assert not finetune_text_encoder, \
+ '`finetune_text_encoder` should be False when training ControlNet'
+
+ self.controlnet_model = controlnet_model
+
+ super().__init__(
+ *args,
+ lora_config=lora_config,
+ finetune_text_encoder=finetune_text_encoder,
+ data_preprocessor=data_preprocessor,
+ **kwargs)
+
+ def set_lora(self):
+ """Set LORA for model."""
+ pass
+
+ def prepare_model(self):
+ """Prepare model for training.
+
+ Disable gradient for some models.
+ """
+ if self.controlnet_model is not None:
+ self.controlnet = ControlNetModel.from_pretrained(
+ self.controlnet_model)
+ else:
+ self.controlnet = ControlNetModel.from_unet(self.unet)
+
+ if self.gradient_checkpointing:
+ self.controlnet.enable_gradient_checkpointing()
+ self.unet.enable_gradient_checkpointing()
+
+ self.vae.requires_grad_(False)
+ print_log('Set VAE untrainable.', 'current')
+ self.text_encoder.requires_grad_(False)
+ print_log('Set Text Encoder untrainable.', 'current')
+ self.unet.requires_grad_(False)
+ print_log('Set Unet untrainable.', 'current')
+
+ @torch.no_grad()
+ def infer(self,
+ prompt: List[str],
+ condition_image: List[Union[str, Image.Image]],
+ height: Optional[int] = None,
+ width: Optional[int] = None) -> List[np.ndarray]:
+ """Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`List[str]`):
+ The prompt or prompts to guide the image generation.
+ condition_image (`List[Union[str, Image.Image]]`):
+ The condition image for ControlNet.
+ height (`int`, *optional*, defaults to
+ `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to
+ `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ """
+ assert len(prompt) == len(condition_image)
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+ self.model,
+ vae=self.vae,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ unet=self.unet,
+ controlnet=self.controlnet,
+ safety_checker=None,
+ dtype=torch.float16)
+ pipeline.set_progress_bar_config(disable=True)
+ images = []
+ for p, img in zip(prompt, condition_image):
+ if type(img) == str:
+ img = load_image(img)
+ img = img.convert('RGB')
+ image = pipeline(
+ p, img, num_inference_steps=50, height=height,
+ width=width).images[0]
+ images.append(np.array(image))
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: Optional[list] = None,
+ mode: str = 'loss'):
+ assert mode == 'loss'
+ inputs['text'] = self.tokenizer(
+ inputs['text'],
+ max_length=self.tokenizer.model_max_length,
+ padding='max_length',
+ truncation=True,
+ return_tensors='pt').input_ids.to(self.device)
+ num_batches = len(inputs['img'])
+ if 'result_class_image' in inputs:
+ # use prior_loss_weight
+ weight = torch.cat([
+ torch.ones((num_batches // 2, )),
+ torch.ones((num_batches // 2, )) * self.prior_loss_weight
+ ]).to(self.device).float().reshape(-1, 1, 1, 1)
+ else:
+ weight = None
+
+ latents = self.vae.encode(inputs['img']).latent_dist.sample()
+ latents = latents * self.vae.config.scaling_factor
+
+ noise = torch.randn_like(latents)
+
+ if self.enable_noise_offset:
+ noise = noise + self.noise_offset_weight * torch.randn(
+ latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
+
+ num_batches = latents.shape[0]
+ timesteps = torch.randint(
+ 0,
+ self.scheduler.num_train_timesteps, (num_batches, ),
+ device=self.device)
+ timesteps = timesteps.long()
+
+ noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
+
+ encoder_hidden_states = self.text_encoder(inputs['text'])[0]
+
+ if self.scheduler.config.prediction_type == 'epsilon':
+ gt = noise
+ elif self.scheduler.config.prediction_type == 'v_prediction':
+ gt = self.scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError('Unknown prediction type '
+ f'{self.scheduler.config.prediction_type}')
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=inputs['condition_img'],
+ return_dict=False,
+ )
+
+ model_pred = self.unet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample).sample
+
+ loss_dict = dict()
+ # calculate loss in FP32
+ if isinstance(self.loss_module, SNRL2Loss):
+ loss = self.loss_module(
+ model_pred.float(),
+ gt.float(),
+ timesteps,
+ self.scheduler.alphas_cumprod,
+ weight=weight)
+ else:
+ loss = self.loss_module(
+ model_pred.float(), gt.float(), weight=weight)
+ loss_dict['loss'] = loss
+ return loss_dict
diff --git a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py
index a212b18..015c1b0 100644
--- a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py
+++ b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py
@@ -42,14 +42,14 @@ class StableDiffusionXL(BaseModel):
Args:
model (str): pretrained model name of stable diffusion xl.
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
- vae_model (str): Path to pretrained VAE model with better numerical
- stability. More details:
+ vae_model (str, optional): Path to pretrained VAE model with better
+ numerical stability. More details:
https://github.com/huggingface/diffusers/pull/4038.
Defaults to None.
loss (dict): Config of loss. Defaults to
``dict(type='L2Loss', loss_weight=1.0)``.
- lora_config (dict): The LoRA config dict. example. dict(rank=4)
- Defaults to None.
+ lora_config (dict, optional): The LoRA config dict.
+ example. dict(rank=4). Defaults to None.
finetune_text_encoder (bool, optional): Whether to fine-tune text
encoder. Defaults to False.
prior_loss_weight (float): The weight of prior preservation loss.
@@ -81,6 +81,7 @@ def __init__(
self.lora_config = deepcopy(lora_config)
self.finetune_text_encoder = finetune_text_encoder
self.prior_loss_weight = prior_loss_weight
+ self.gradient_checkpointing = gradient_checkpointing
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
@@ -112,11 +113,6 @@ def __init__(
self.unet = UNet2DConditionModel.from_pretrained(
model, subfolder='unet')
self.prepare_model()
- if gradient_checkpointing:
- self.unet.enable_gradient_checkpointing()
- if self.finetune_text_encoder:
- self.text_encoder_one.gradient_checkpointing_enable()
- self.text_encoder_two.gradient_checkpointing_enable()
self.set_lora()
def set_lora(self):
@@ -135,6 +131,12 @@ def prepare_model(self):
Disable gradient for some models.
"""
+ if self.gradient_checkpointing:
+ self.unet.enable_gradient_checkpointing()
+ if self.finetune_text_encoder:
+ self.text_encoder_one.gradient_checkpointing_enable()
+ self.text_encoder_two.gradient_checkpointing_enable()
+
self.vae.requires_grad_(False)
print_log('Set VAE untrainable.', 'current')
if not self.finetune_text_encoder:
diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/__init__.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/__init__.py
new file mode 100644
index 0000000..31d941d
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/__init__.py
@@ -0,0 +1,4 @@
+from .sdxl_controlnet_data_preprocessor import SDXLControlNetDataPreprocessor
+from .stable_diffusion_xl_controlnet import StableDiffusionXLControlNet
+
+__all__ = ['SDXLControlNetDataPreprocessor', 'StableDiffusionXLControlNet']
diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/sdxl_controlnet_data_preprocessor.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/sdxl_controlnet_data_preprocessor.py
new file mode 100644
index 0000000..43e2182
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/sdxl_controlnet_data_preprocessor.py
@@ -0,0 +1,31 @@
+from typing import Union
+
+import torch
+from mmengine.model.base_model.data_preprocessor import BaseDataPreprocessor
+
+from diffengine.registry import MODELS
+
+
+@MODELS.register_module()
+class SDXLControlNetDataPreprocessor(BaseDataPreprocessor):
+
+ def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
+ """Preprocesses the data into the model input format.
+
+ After the data pre-processing of :meth:`cast_data`, ``forward``
+ will stack the input tensor list to a batch tensor at the first
+ dimension.
+
+ Args:
+ data (dict): Data returned by dataloader
+ training (bool): Whether to enable training time augmentation.
+
+ Returns:
+ dict or list: Data in the same format as the model input.
+ """
+ assert 'result_class_image' not in data['inputs']
+ data['inputs']['img'] = torch.stack(data['inputs']['img'])
+ data['inputs']['condition_img'] = torch.stack(
+ data['inputs']['condition_img'])
+ data['inputs']['time_ids'] = torch.stack(data['inputs']['time_ids'])
+ return super().forward(data) # type: ignore
diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py
new file mode 100644
index 0000000..14a4f41
--- /dev/null
+++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py
@@ -0,0 +1,222 @@
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
+from diffusers.utils import load_image
+from mmengine import print_log
+from PIL import Image
+from torch import nn
+
+from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL
+from diffengine.models.losses.snr_l2_loss import SNRL2Loss
+from diffengine.registry import MODELS
+
+
+@MODELS.register_module()
+class StableDiffusionXLControlNet(StableDiffusionXL):
+ """Stable Diffusion XL ControlNet.
+
+ Args:
+ controlnet_model (str, optional): Path to pretrained VAE model with
+ better numerical stability. More details:
+ https://github.com/huggingface/diffusers/pull/4038.
+ Defaults to None.
+ lora_config (dict, optional): The LoRA config dict. This should be
+ `None` when training ControlNet. Defaults to None.
+ finetune_text_encoder (bool, optional): Whether to fine-tune text
+ encoder. This should be `False` when training ControlNet.
+ Defaults to False.
+ data_preprocessor (dict, optional): The pre-process config of
+ :class:`SDControlNetDataPreprocessor`.
+ """
+
+ def __init__(self,
+ *args,
+ controlnet_model: Optional[str] = None,
+ lora_config: Optional[dict] = None,
+ finetune_text_encoder: bool = False,
+ data_preprocessor: Optional[Union[dict, nn.Module]] = dict(
+ type='SDXLControlNetDataPreprocessor'),
+ **kwargs):
+ assert lora_config is None, \
+ '`lora_config` should be None when training ControlNet'
+ assert not finetune_text_encoder, \
+ '`finetune_text_encoder` should be False when training ControlNet'
+
+ self.controlnet_model = controlnet_model
+
+ super().__init__(
+ *args,
+ lora_config=lora_config,
+ finetune_text_encoder=finetune_text_encoder,
+ data_preprocessor=data_preprocessor,
+ **kwargs)
+
+ def set_lora(self):
+ """Set LORA for model."""
+ pass
+
+ def prepare_model(self):
+ """Prepare model for training.
+
+ Disable gradient for some models.
+ """
+ if self.controlnet_model is not None:
+ self.controlnet = ControlNetModel.from_pretrained(
+ self.controlnet_model)
+ else:
+ self.controlnet = ControlNetModel.from_unet(self.unet)
+
+ if self.gradient_checkpointing:
+ self.controlnet.enable_gradient_checkpointing()
+ self.unet.enable_gradient_checkpointing()
+
+ self.vae.requires_grad_(False)
+ print_log('Set VAE untrainable.', 'current')
+ self.text_encoder_one.requires_grad_(False)
+ self.text_encoder_two.requires_grad_(False)
+ print_log('Set Text Encoder untrainable.', 'current')
+ self.unet.requires_grad_(False)
+ print_log('Set Unet untrainable.', 'current')
+
+ @torch.no_grad()
+ def infer(self,
+ prompt: List[str],
+ condition_image: List[Union[str, Image.Image]],
+ height: Optional[int] = None,
+ width: Optional[int] = None) -> List[np.ndarray]:
+ """Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`List[str]`):
+ The prompt or prompts to guide the image generation.
+ condition_image (`List[Union[str, Image.Image]]`):
+ The condition image for ControlNet.
+ height (`int`, *optional*, defaults to
+ `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to
+ `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ """
+ assert len(prompt) == len(condition_image)
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+ self.model,
+ vae=self.vae,
+ text_encoder_one=self.text_encoder_one,
+ text_encoder_two=self.text_encoder_two,
+ tokenizer_one=self.tokenizer_one,
+ tokenizer_two=self.tokenizer_two,
+ unet=self.unet,
+ controlnet=self.controlnet,
+ safety_checker=None,
+ dtype=torch.float16,
+ )
+ pipeline.to(self.device)
+ pipeline.set_progress_bar_config(disable=True)
+ images = []
+ for p, img in zip(prompt, condition_image):
+ if type(img) == str:
+ img = load_image(img)
+ img = img.convert('RGB')
+ image = pipeline(
+ p, p, img, num_inference_steps=50, height=height,
+ width=width).images[0]
+ images.append(np.array(image))
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+ def forward(self,
+ inputs: torch.Tensor,
+ data_samples: Optional[list] = None,
+ mode: str = 'loss'):
+ assert mode == 'loss'
+ inputs['text_one'] = self.tokenizer_one(
+ inputs['text'],
+ max_length=self.tokenizer_one.model_max_length,
+ padding='max_length',
+ truncation=True,
+ return_tensors='pt').input_ids.to(self.device)
+ inputs['text_two'] = self.tokenizer_two(
+ inputs['text'],
+ max_length=self.tokenizer_two.model_max_length,
+ padding='max_length',
+ truncation=True,
+ return_tensors='pt').input_ids.to(self.device)
+ num_batches = len(inputs['img'])
+ if 'result_class_image' in inputs:
+ # use prior_loss_weight
+ weight = torch.cat([
+ torch.ones((num_batches // 2, )),
+ torch.ones((num_batches // 2, )) * self.prior_loss_weight
+ ]).float().reshape(-1, 1, 1, 1)
+ else:
+ weight = None
+
+ latents = self.vae.encode(inputs['img']).latent_dist.sample()
+ latents = latents * self.vae.config.scaling_factor
+
+ noise = torch.randn_like(latents)
+
+ if self.enable_noise_offset:
+ noise = noise + self.noise_offset_weight * torch.randn(
+ latents.shape[0], latents.shape[1], 1, 1, device=noise.device)
+
+ timesteps = torch.randint(
+ 0,
+ self.scheduler.num_train_timesteps, (num_batches, ),
+ device=self.device)
+ timesteps = timesteps.long()
+
+ noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
+
+ prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
+ inputs['text_one'], inputs['text_two'])
+ unet_added_conditions = {
+ 'time_ids': inputs['time_ids'],
+ 'text_embeds': pooled_prompt_embeds
+ }
+
+ if self.scheduler.config.prediction_type == 'epsilon':
+ gt = noise
+ elif self.scheduler.config.prediction_type == 'v_prediction':
+ gt = self.scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError('Unknown prediction type '
+ f'{self.scheduler.config.prediction_type}')
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ noisy_latents,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ controlnet_cond=inputs['condition_img'],
+ return_dict=False,
+ )
+
+ model_pred = self.unet(
+ noisy_latents,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample).sample
+
+ loss_dict = dict()
+ # calculate loss in FP32
+ if isinstance(self.loss_module, SNRL2Loss):
+ loss = self.loss_module(
+ model_pred.float(),
+ gt.float(),
+ timesteps,
+ self.scheduler.alphas_cumprod,
+ weight=weight)
+ else:
+ loss = self.loss_module(
+ model_pred.float(), gt.float(), weight=weight)
+ loss_dict['loss'] = loss
+ return loss_dict
diff --git a/docs/source/index.rst b/docs/source/index.rst
index e03b6c9..b8f5413 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -24,6 +24,11 @@ Welcome to diffengine's documentation!
run_guides/run_sd.md
run_guides/run_sdxl.md
run_guides/run_dreambooth.md
+ run_guides/run_dreambooth_xl.md
+ run_guides/run_lora.md
+ run_guides/run_lora_xl.md
+ run_guides/run_controlnet.md
+ run_guides/run_controlnet_xl.md
run_guides/inference.md
.. toctree::
diff --git a/docs/source/run_guides/run_controlnet.md b/docs/source/run_guides/run_controlnet.md
new file mode 100644
index 0000000..4695c03
--- /dev/null
+++ b/docs/source/run_guides/run_controlnet.md
@@ -0,0 +1,94 @@
+# Stable Diffusion ControlNet Training
+
+You can also check [`configs/stable_diffusion_controlnet/README.md`](../../../configs/stable_diffusion_controlnet/README.md) file.
+
+## Configs
+
+All configuration files are placed under the [`configs/stable_diffusion_controlnet`](../../../configs/stable_diffusion_controlnet/) folder.
+
+Following is the example config fixed from the stable_diffusion_v15_controlnet_fill50k config file in [`configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py`](../../../configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py):
+
+```
+_base_ = [
+ '../_base_/models/stable_diffusion_v15_controlnet.py',
+ '../_base_/datasets/fill50k_controlnet.py',
+ '../_base_/schedules/stable_diffusion_1e.py',
+ '../_base_/default_runtime.py'
+]
+```
+
+#### Finetuning with Min-SNR Weighting Strategy
+
+The script also allows you to finetune with [Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556).
+
+```
+_base_ = [
+ '../_base_/models/stable_diffusion_v15_controlnet.py',
+ '../_base_/datasets/fill50k_controlnet.py',
+ '../_base_/schedules/stable_diffusion_1e.py',
+ '../_base_/default_runtime.py'
+]
+
+model = dict(loss=dict(type='SNRL2Loss', snr_gamma=5.0, loss_weight=1.0)) # setup Min-SNR Weighting Strategy
+```
+
+## Run training
+
+Run train
+
+```
+# single gpu
+$ docker compose exec diffengine mim train diffengine ${CONFIG_FILE}
+# Example
+$ docker compose exec diffengine mim train diffengine configs/stable_diffusion_controlnet/stable_diffusion_v15_controlnet_fill50k.py
+
+# multi gpus
+$ docker compose exec diffengine mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch
+```
+
+## Inference with diffusers
+
+Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module.
+
+```py
+import torch
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+
+checkpoint = 'work_dirs/stable_diffusion_v15_controlnet_fill50k/step6250'
+prompt = 'cyan circle with brown floral background'
+condition_image = load_image(
+ 'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'
+)
+
+controlnet = ControlNetModel.from_pretrained(
+ checkpoint, subfolder='controlnet', torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ 'runwayml/stable-diffusion-v1-5', controlnet=controlnet, torch_dtype=torch.float16)
+pipe.to('cuda')
+
+image = pipe(
+ prompt,
+ condition_image,
+ num_inference_steps=50,
+).images[0]
+image.save('demo.png')
+```
+
+We also provide inference demo scripts:
+
+```
+$ mim run diffengine demo_controlnet ${PROMPT} ${CONDITION_IMAGE} ${CHECKPOINT}
+# Example
+$ mim run diffengine demo_controlnet "cyan circle with brown floral background" https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg work_dirs/stable_diffusion_v15_controlnet_fill50k/step6250
+```
+
+## Results Example
+
+#### stable_diffusion_v15_controlnet_fill50k
+
+![input1](https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg)
+
+![example1](https://github.com/okotaku/diffengine/assets/24734142/a14cc9a6-3a40-4577-bd5a-2ddbab60970d)
+
+You can check [`configs/stable_diffusion_controlnet/README.md`](../../../configs/stable_diffusion_controlnet/README.md#results-example) for more deitals.
diff --git a/docs/source/run_guides/run_controlnet_xl.md b/docs/source/run_guides/run_controlnet_xl.md
new file mode 100644
index 0000000..b610672
--- /dev/null
+++ b/docs/source/run_guides/run_controlnet_xl.md
@@ -0,0 +1,109 @@
+# Stable Diffusion XL ControlNet Training
+
+You can also check [`configs/stable_diffusion_xl_controlnet/README.md`](../../../configs/stable_diffusion_xl_controlnet/README.md) file.
+
+## Configs
+
+All configuration files are placed under the [`configs/stable_diffusion_xl_controlnet`](../../../configs/stable_diffusion_xl_controlnet/) folder.
+
+Following is the example config fixed from the stable_diffusion_xl_controlnet_fill50k config file in [`configs/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet_fill50k.py`](../../../configs/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet_fill50k.py):
+
+```
+_base_ = [
+ '../_base_/models/stable_diffusion_xl_controlnet.py',
+ '../_base_/datasets/fill50k_controlnet_xl.py',
+ '../_base_/schedules/stable_diffusion_1e.py',
+ '../_base_/default_runtime.py'
+]
+
+optim_wrapper = dict(
+ optimizer=dict(lr=1e-5),
+ accumulative_counts=2,
+ )
+```
+
+#### Finetuning with Min-SNR Weighting Strategy
+
+The script also allows you to finetune with [Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556).
+
+```
+_base_ = [
+ '../_base_/models/stable_diffusion_xl_controlnet.py',
+ '../_base_/datasets/fill50k_controlnet_xl.py',
+ '../_base_/schedules/stable_diffusion_1e.py',
+ '../_base_/default_runtime.py'
+]
+
+optim_wrapper = dict(
+ optimizer=dict(lr=1e-5),
+ accumulative_counts=2,
+ )
+
+model = dict(loss=dict(type='SNRL2Loss', snr_gamma=5.0, loss_weight=1.0)) # setup Min-SNR Weighting Strategy
+```
+
+## Run training
+
+Run train
+
+```
+# single gpu
+$ docker compose exec diffengine mim train diffengine ${CONFIG_FILE}
+# Example
+$ docker compose exec diffengine mim train diffengine configs/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet_fill50k.py
+
+# multi gpus
+$ docker compose exec diffengine mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch
+```
+
+## Inference with diffusers
+
+Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffusers.pipeline` module.
+
+```py
+import torch
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
+from diffusers.utils import load_image
+
+checkpoint = 'work_dirs/stable_diffusion_xl_controlnet_fill50k/step25000'
+prompt = 'cyan circle with brown floral background'
+condition_image = load_image(
+ 'https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg'
+)
+
+controlnet = ControlNetModel.from_pretrained(
+ checkpoint, subfolder='controlnet', torch_dtype=torch.float16)
+
+vae = AutoencoderKL.from_pretrained(
+ 'madebyollin/sdxl-vae-fp16-fix',
+ torch_dtype=torch.float16,
+)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ 'stabilityai/stable-diffusion-xl-base-1.0', controlnet=controlnet, vae=vae, torch_dtype=torch.float16)
+pipe.to('cuda')
+
+image = pipe(
+ prompt,
+ condition_image,
+ num_inference_steps=50,
+).images[0]
+image.save('demo.png')
+```
+
+We also provide inference demo scripts, it can run with `--use_sdxl`:
+
+```
+$ mim run diffengine demo_controlnet ${PROMPT} ${CONDITION_IMAGE} ${CHECKPOINT} --sdmodel stabilityai/stable-diffusion-xl-base-1.0 --vaemodel madebyollin/sdxl-vae-fp16-fix --use_sdxl
+# Example
+$ mim run diffengine demo_controlnet "cyan circle with brown floral background" https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg work_dirs/stable_diffusion_xl_controlnet_fill50k/step25000 --sdmodel stabilityai/stable-diffusion-xl-base-1.0 --vaemodel madebyollin/sdxl-vae-fp16-fix --use_sdxl
+```
+
+## Results Example
+
+#### stable_diffusion_xl_controlnet_fill50k
+
+![input1](https://datasets-server.huggingface.co/assets/fusing/fill50k/--/default/train/74/conditioning_image/image.jpg)
+
+![example1](https://github.com/okotaku/diffengine/assets/24734142/a331a413-a9e7-4b9a-aa75-72279c4cc77a)
+
+You can check [`configs/stable_diffusion_xl_controlnet/README.md`](../../../configs/stable_diffusion_xl_controlnet/README.md#results-example) for more deitals.
diff --git a/docs/source/user_guides/config.md b/docs/source/user_guides/config.md
index 0558528..784a8b9 100644
--- a/docs/source/user_guides/config.md
+++ b/docs/source/user_guides/config.md
@@ -69,8 +69,8 @@ Following is the data primitive config of the stable_diffusion_v15 config in [`c
```python
train_pipeline = [ # augmentation settings
dict(type='torchvision/Resize', size=512, interpolation='bilinear'),
- dict(type='torchvision/RandomCrop', size=512),
- dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+ dict(type='RandomCrop', size=512),
+ dict(type='RandomHorizontalFlip', p=0.5),
dict(type='torchvision/ToTensor'),
dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),
dict(type='PackInputs'),
diff --git a/examples/example-dreambooth.ipynb b/examples/example-dreambooth.ipynb
index 3738bdc..cb44335 100644
--- a/examples/example-dreambooth.ipynb
+++ b/examples/example-dreambooth.ipynb
@@ -77,8 +77,8 @@
"\n",
"train_pipeline = [ # fix image size\n",
" dict(type='torchvision/Resize', size=768, interpolation='bilinear'),\n",
- " dict(type='torchvision/RandomCrop', size=768),\n",
- " dict(type='torchvision/RandomHorizontalFlip', p=0.5),\n",
+ " dict(type='RandomCrop', size=768),\n",
+ " dict(type='RandomHorizontalFlip', p=0.5),\n",
" dict(type='torchvision/ToTensor'),\n",
" dict(type='torchvision/Normalize', mean=[0.5], std=[0.5]),\n",
" dict(type='PackInputs'),\n",
diff --git a/tests/test_datasets/test_transforms/test_processing.py b/tests/test_datasets/test_transforms/test_processing.py
index db83565..9e63b50 100644
--- a/tests/test_datasets/test_transforms/test_processing.py
+++ b/tests/test_datasets/test_transforms/test_processing.py
@@ -69,7 +69,7 @@ def test_transform(self):
pipeline_cfg = [
dict(type='torchvision/Resize', size=176),
- dict(type='torchvision/RandomHorizontalFlip'),
+ dict(type='RandomHorizontalFlip'),
dict(type='torchvision/PILToTensor'),
dict(type='torchvision/ConvertImageDtype', dtype='float'),
dict(
@@ -117,19 +117,43 @@ def test_transform(self):
[32, 32, 0, 0, img.height, img.width])
-class TestRandomCropWithCropPoint(TestCase):
+class TestRandomCrop(TestCase):
crop_size = 32
def test_register(self):
- self.assertIn('RandomCropWithCropPoint', TRANSFORMS)
+ self.assertIn('RandomCrop', TRANSFORMS)
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../testdata/color.jpg')
data = {'img': Image.open(img_path)}
+ # test transform
+ trans = TRANSFORMS.build(dict(type='RandomCrop', size=self.crop_size))
+ data = trans(data)
+ self.assertIn('crop_top_left', data)
+ assert len(data['crop_top_left']) == 2
+ assert data['img'].height == data['img'].width == self.crop_size
+ upper, left = data['crop_top_left']
+ lower, right = data['crop_bottom_right']
+ assert lower == upper + self.crop_size
+ assert right == left + self.crop_size
+ np.equal(
+ np.array(data['img']),
+ np.array(Image.open(img_path).crop((left, upper, right, lower))))
+
+ def test_transform_multiple_keys(self):
+ img_path = osp.join(osp.dirname(__file__), '../../testdata/color.jpg')
+ data = {
+ 'img': Image.open(img_path),
+ 'condition_img': Image.open(img_path)
+ }
+
# test transform
trans = TRANSFORMS.build(
- dict(type='RandomCropWithCropPoint', size=self.crop_size))
+ dict(
+ type='RandomCrop',
+ size=self.crop_size,
+ keys=['img', 'condition_img']))
data = trans(data)
self.assertIn('crop_top_left', data)
assert len(data['crop_top_left']) == 2
@@ -141,21 +165,46 @@ def test_transform(self):
np.equal(
np.array(data['img']),
np.array(Image.open(img_path).crop((left, upper, right, lower))))
+ np.equal(np.array(data['img']), np.array(data['condition_img']))
-class TestCenterCropWithCropPoint(TestCase):
+class TestCenterCrop(TestCase):
crop_size = 32
def test_register(self):
- self.assertIn('CenterCropWithCropPoint', TRANSFORMS)
+ self.assertIn('CenterCrop', TRANSFORMS)
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../testdata/color.jpg')
data = {'img': Image.open(img_path)}
+ # test transform
+ trans = TRANSFORMS.build(dict(type='CenterCrop', size=self.crop_size))
+ data = trans(data)
+ self.assertIn('crop_top_left', data)
+ assert len(data['crop_top_left']) == 2
+ assert data['img'].height == data['img'].width == self.crop_size
+ upper, left = data['crop_top_left']
+ lower, right = data['crop_bottom_right']
+ assert lower == upper + self.crop_size
+ assert right == left + self.crop_size
+ np.equal(
+ np.array(data['img']),
+ np.array(Image.open(img_path).crop((left, upper, right, lower))))
+
+ def test_transform_multiple_keys(self):
+ img_path = osp.join(osp.dirname(__file__), '../../testdata/color.jpg')
+ data = {
+ 'img': Image.open(img_path),
+ 'condition_img': Image.open(img_path)
+ }
+
# test transform
trans = TRANSFORMS.build(
- dict(type='CenterCropWithCropPoint', size=self.crop_size))
+ dict(
+ type='CenterCrop',
+ size=self.crop_size,
+ keys=['img', 'condition_img']))
data = trans(data)
self.assertIn('crop_top_left', data)
assert len(data['crop_top_left']) == 2
@@ -167,12 +216,13 @@ def test_transform(self):
np.equal(
np.array(data['img']),
np.array(Image.open(img_path).crop((left, upper, right, lower))))
+ np.equal(np.array(data['img']), np.array(data['condition_img']))
-class TestRandomHorizontalFlipFixCropPoint(TestCase):
+class TestRandomHorizontalFlip(TestCase):
def test_register(self):
- self.assertIn('RandomHorizontalFlipFixCropPoint', TRANSFORMS)
+ self.assertIn('RandomHorizontalFlip', TRANSFORMS)
def test_transform(self):
img_path = osp.join(osp.dirname(__file__), '../../testdata/color.jpg')
@@ -183,8 +233,7 @@ def test_transform(self):
}
# test transform
- trans = TRANSFORMS.build(
- dict(type='RandomHorizontalFlipFixCropPoint', p=1.))
+ trans = TRANSFORMS.build(dict(type='RandomHorizontalFlip', p=1.))
data = trans(data)
self.assertIn('crop_top_left', data)
assert len(data['crop_top_left']) == 2
@@ -201,10 +250,35 @@ def test_transform(self):
'crop_top_left': [0, 0],
'crop_bottom_right': [10, 10]
}
- trans = TRANSFORMS.build(
- dict(type='RandomHorizontalFlipFixCropPoint', p=0.))
+ trans = TRANSFORMS.build(dict(type='RandomHorizontalFlip', p=0.))
data = trans(data)
self.assertIn('crop_top_left', data)
self.assertListEqual(data['crop_top_left'], [0, 0])
np.equal(np.array(data['img']), np.array(Image.open(img_path)))
+
+ def test_transform_multiple_keys(self):
+ img_path = osp.join(osp.dirname(__file__), '../../testdata/color.jpg')
+ data = {
+ 'img': Image.open(img_path),
+ 'condition_img': Image.open(img_path),
+ 'crop_top_left': [0, 0],
+ 'crop_bottom_right': [10, 10]
+ }
+
+ # test transform
+ trans = TRANSFORMS.build(
+ dict(
+ type='RandomHorizontalFlip',
+ p=1.,
+ keys=['img', 'condition_img']))
+ data = trans(data)
+ self.assertIn('crop_top_left', data)
+ assert len(data['crop_top_left']) == 2
+ self.assertListEqual(data['crop_top_left'],
+ [0, data['img'].width - 10])
+
+ np.equal(
+ np.array(data['img']),
+ np.array(Image.open(img_path).transpose(Image.FLIP_LEFT_RIGHT)))
+ np.equal(np.array(data['img']), np.array(data['condition_img']))
diff --git a/tests/test_engine/test_hooks/test_controlnet_save_hook.py b/tests/test_engine/test_hooks/test_controlnet_save_hook.py
new file mode 100644
index 0000000..0e9ba62
--- /dev/null
+++ b/tests/test_engine/test_hooks/test_controlnet_save_hook.py
@@ -0,0 +1,68 @@
+import copy
+import os.path as osp
+import shutil
+from pathlib import Path
+
+import torch.nn as nn
+from mmengine.model import BaseModel
+from mmengine.registry import MODELS
+from mmengine.testing import RunnerTestCase
+
+from diffengine.engine.hooks import ControlNetSaveHook
+from diffengine.models.editors import (SDControlNetDataPreprocessor,
+ StableDiffusionControlNet)
+from diffengine.models.losses import L2Loss
+
+
+class DummyWrapper(BaseModel):
+
+ def __init__(self, model):
+ super().__init__()
+ if not isinstance(model, nn.Module):
+ model = MODELS.build(model)
+ self.module = model
+
+ def forward(self, *args, **kwargs):
+ return self.module(*args, **kwargs)
+
+
+class TestLoRASaveHook(RunnerTestCase):
+
+ def setUp(self) -> None:
+ MODELS.register_module(name='DummyWrapper', module=DummyWrapper)
+ MODELS.register_module(
+ name='StableDiffusionControlNet', module=StableDiffusionControlNet)
+ MODELS.register_module(
+ name='SDControlNetDataPreprocessor',
+ module=SDControlNetDataPreprocessor)
+ MODELS.register_module(name='L2Loss', module=L2Loss)
+ return super().setUp()
+
+ def tearDown(self):
+ MODELS.module_dict.pop('DummyWrapper')
+ MODELS.module_dict.pop('StableDiffusionControlNet')
+ MODELS.module_dict.pop('SDControlNetDataPreprocessor')
+ MODELS.module_dict.pop('L2Loss')
+ return super().tearDown()
+
+ def test_init(self):
+ ControlNetSaveHook()
+
+ def test_before_save_checkpoint(self):
+ cfg = copy.deepcopy(self.epoch_based_cfg)
+ cfg.model.type = 'StableDiffusionControlNet'
+ cfg.model.model = 'diffusers/tiny-stable-diffusion-torch'
+ runner = self.build_runner(cfg)
+ checkpoint = dict(
+ state_dict=StableDiffusionControlNet(
+ model='diffusers/tiny-stable-diffusion-torch').state_dict())
+ hook = ControlNetSaveHook()
+ hook.before_save_checkpoint(runner, checkpoint)
+
+ assert Path(
+ osp.join(runner.work_dir, f'step{runner.iter}', 'controlnet',
+ 'diffusion_pytorch_model.bin')).exists
+ shutil.rmtree(osp.join(runner.work_dir), ignore_errors=True)
+
+ for key in checkpoint['state_dict'].keys():
+ assert key.startswith('controlnet')
diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py
index bec1f40..1b1c26b 100644
--- a/tests/test_engine/test_hooks/test_visualization_hook.py
+++ b/tests/test_engine/test_hooks/test_visualization_hook.py
@@ -6,7 +6,9 @@
from mmengine.testing import RunnerTestCase
from diffengine.engine.hooks import VisualizationHook
-from diffengine.models.editors import SDDataPreprocessor, StableDiffusion
+from diffengine.models.editors import (SDControlNetDataPreprocessor,
+ SDDataPreprocessor, StableDiffusion,
+ StableDiffusionControlNet)
from diffengine.models.losses import L2Loss
@@ -16,12 +18,19 @@ def setUp(self) -> None:
MODELS.register_module(name='StableDiffusion', module=StableDiffusion)
MODELS.register_module(
name='SDDataPreprocessor', module=SDDataPreprocessor)
+ MODELS.register_module(
+ name='StableDiffusionControlNet', module=StableDiffusionControlNet)
+ MODELS.register_module(
+ name='SDControlNetDataPreprocessor',
+ module=SDControlNetDataPreprocessor)
MODELS.register_module(name='L2Loss', module=L2Loss)
return super().setUp()
def tearDown(self):
MODELS.module_dict.pop('StableDiffusion')
MODELS.module_dict.pop('SDDataPreprocessor')
+ MODELS.module_dict.pop('StableDiffusionControlNet')
+ MODELS.module_dict.pop('SDControlNetDataPreprocessor')
MODELS.module_dict.pop('L2Loss')
return super().tearDown()
@@ -34,6 +43,16 @@ def test_after_train_epoch(self):
hook = VisualizationHook(prompt=['a dog'])
hook.after_train_epoch(runner)
+ def test_after_train_epoch_with_condition(self):
+ runner = MagicMock()
+
+ # test epoch-based
+ runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
+ runner.epoch = 5
+ hook = VisualizationHook(
+ prompt=['a dog'], condition_image=['testdata/color.jpg'])
+ hook.after_train_epoch(runner)
+
def test_after_train_iter(self):
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.train_cfg.max_iters = 100
@@ -44,3 +63,20 @@ def test_after_train_iter(self):
for i in range(10):
hook.after_train_iter(runner, i)
runner.train_loop._iter += 1
+
+ def test_after_train_iter_with_condition(self):
+ cfg = copy.deepcopy(self.iter_based_cfg)
+ cfg.train_cfg.max_iters = 100
+ cfg.model.type = 'StableDiffusionControlNet'
+ cfg.model.model = 'hf-internal-testing/tiny-stable-diffusion-pipe'
+ cfg.model.controlnet_model = 'hf-internal-testing/tiny-controlnet'
+ runner = self.build_runner(cfg)
+ hook = VisualizationHook(
+ prompt=['a dog'],
+ condition_image=['tests/testdata/cond.jpg'],
+ height=64,
+ width=64,
+ by_epoch=False)
+ for i in range(10):
+ hook.after_train_iter(runner, i)
+ runner.train_loop._iter += 1
diff --git a/tests/test_models/test_editors/test_stable_diffusion_controlnet/test_stable_diffusion_controlnet.py b/tests/test_models/test_editors/test_stable_diffusion_controlnet/test_stable_diffusion_controlnet.py
new file mode 100644
index 0000000..687dd79
--- /dev/null
+++ b/tests/test_models/test_editors/test_stable_diffusion_controlnet/test_stable_diffusion_controlnet.py
@@ -0,0 +1,103 @@
+from unittest import TestCase
+
+import torch
+from mmengine.optim import OptimWrapper
+from torch.optim import SGD
+
+from diffengine.models.editors import (SDControlNetDataPreprocessor,
+ StableDiffusionControlNet)
+from diffengine.models.losses import L2Loss
+
+
+class TestStableDiffusionControlNet(TestCase):
+
+ def test_init(self):
+ with self.assertRaisesRegex(AssertionError,
+ '`lora_config` should be None'):
+ _ = StableDiffusionControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet',
+ data_preprocessor=SDControlNetDataPreprocessor(),
+ lora_config=dict(rank=4))
+
+ with self.assertRaisesRegex(AssertionError,
+ '`finetune_text_encoder` should be False'):
+ _ = StableDiffusionControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet',
+ data_preprocessor=SDControlNetDataPreprocessor(),
+ finetune_text_encoder=True)
+
+ def test_infer(self):
+ StableDiffuser = StableDiffusionControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet',
+ data_preprocessor=SDControlNetDataPreprocessor())
+
+ # test infer
+ result = StableDiffuser.infer(
+ ['an insect robot preparing a delicious meal'],
+ ['tests/testdata/color.jpg'],
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert result[0].shape == (64, 64, 3)
+
+ # test device
+ assert StableDiffuser.device.type == 'cpu'
+
+ def test_train_step(self):
+ # test load with loss module
+ StableDiffuser = StableDiffusionControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet',
+ loss=L2Loss(),
+ data_preprocessor=SDControlNetDataPreprocessor())
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((3, 64, 64))],
+ text=['a dog'],
+ condition_img=[torch.zeros((3, 64, 64))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ self.assertIsInstance(log_vars['loss'], torch.Tensor)
+
+ def test_train_step_with_gradient_checkpointing(self):
+ # test load with loss module
+ StableDiffuser = StableDiffusionControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet',
+ loss=L2Loss(),
+ data_preprocessor=SDControlNetDataPreprocessor(),
+ gradient_checkpointing=True)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((3, 64, 64))],
+ text=['a dog'],
+ condition_img=[torch.zeros((3, 64, 64))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ self.assertIsInstance(log_vars['loss'], torch.Tensor)
+
+ def test_val_and_test_step(self):
+ StableDiffuser = StableDiffusionControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet',
+ loss=L2Loss(),
+ data_preprocessor=SDControlNetDataPreprocessor())
+
+ # test val_step
+ with self.assertRaisesRegex(NotImplementedError, 'val_step is not'):
+ StableDiffuser.val_step(torch.zeros((1, )))
+
+ # test test_step
+ with self.assertRaisesRegex(NotImplementedError, 'test_step is not'):
+ StableDiffuser.test_step(torch.zeros((1, )))
diff --git a/tests/test_models/test_editors/test_stable_diffusion_xl_controlnet/test_stable_diffusion_xl_controlnet.py b/tests/test_models/test_editors/test_stable_diffusion_xl_controlnet/test_stable_diffusion_xl_controlnet.py
new file mode 100644
index 0000000..7144e1f
--- /dev/null
+++ b/tests/test_models/test_editors/test_stable_diffusion_xl_controlnet/test_stable_diffusion_xl_controlnet.py
@@ -0,0 +1,105 @@
+from unittest import TestCase
+
+import torch
+from mmengine.optim import OptimWrapper
+from torch.optim import SGD
+
+from diffengine.models.editors import (SDXLControlNetDataPreprocessor,
+ StableDiffusionXLControlNet)
+from diffengine.models.losses import L2Loss
+
+
+class TestStableDiffusionXLControlNet(TestCase):
+
+ def test_init(self):
+ with self.assertRaisesRegex(AssertionError,
+ '`lora_config` should be None'):
+ _ = StableDiffusionXLControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-xl-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet-sdxl',
+ data_preprocessor=SDXLControlNetDataPreprocessor(),
+ lora_config=dict(rank=4))
+
+ with self.assertRaisesRegex(AssertionError,
+ '`finetune_text_encoder` should be False'):
+ _ = StableDiffusionXLControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-xl-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet-sdxl',
+ data_preprocessor=SDXLControlNetDataPreprocessor(),
+ finetune_text_encoder=True)
+
+ def test_infer(self):
+ StableDiffuser = StableDiffusionXLControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-xl-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet-sdxl',
+ data_preprocessor=SDXLControlNetDataPreprocessor())
+
+ # test infer
+ result = StableDiffuser.infer(
+ ['an insect robot preparing a delicious meal'],
+ ['tests/testdata/color.jpg'],
+ height=64,
+ width=64)
+ assert len(result) == 1
+ assert result[0].shape == (64, 64, 3)
+
+ # test device
+ assert StableDiffuser.device.type == 'cpu'
+
+ def test_train_step(self):
+ # test load with loss module
+ StableDiffuser = StableDiffusionXLControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-xl-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet-sdxl',
+ loss=L2Loss(),
+ data_preprocessor=SDXLControlNetDataPreprocessor())
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((3, 64, 64))],
+ text=['a dog'],
+ time_ids=[torch.zeros((1, 6))],
+ condition_img=[torch.zeros((3, 64, 64))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ self.assertIsInstance(log_vars['loss'], torch.Tensor)
+
+ def test_train_step_with_gradient_checkpointing(self):
+ # test load with loss module
+ StableDiffuser = StableDiffusionXLControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-xl-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet-sdxl',
+ loss=L2Loss(),
+ data_preprocessor=SDXLControlNetDataPreprocessor(),
+ gradient_checkpointing=True)
+
+ # test train step
+ data = dict(
+ inputs=dict(
+ img=[torch.zeros((3, 64, 64))],
+ text=['a dog'],
+ time_ids=[torch.zeros((1, 6))],
+ condition_img=[torch.zeros((3, 64, 64))]))
+ optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
+ optim_wrapper = OptimWrapper(optimizer)
+ log_vars = StableDiffuser.train_step(data, optim_wrapper)
+ assert log_vars
+ self.assertIsInstance(log_vars['loss'], torch.Tensor)
+
+ def test_val_and_test_step(self):
+ StableDiffuser = StableDiffusionXLControlNet(
+ 'hf-internal-testing/tiny-stable-diffusion-xl-pipe',
+ controlnet_model='hf-internal-testing/tiny-controlnet-sdxl',
+ loss=L2Loss(),
+ data_preprocessor=SDXLControlNetDataPreprocessor())
+
+ # test val_step
+ with self.assertRaisesRegex(NotImplementedError, 'val_step is not'):
+ StableDiffuser.val_step(torch.zeros((1, )))
+
+ # test test_step
+ with self.assertRaisesRegex(NotImplementedError, 'test_step is not'):
+ StableDiffuser.test_step(torch.zeros((1, )))
diff --git a/tests/testdata/cond.jpg b/tests/testdata/cond.jpg
new file mode 100644
index 0000000..7f61e05
Binary files /dev/null and b/tests/testdata/cond.jpg differ
diff --git a/tools/demo.py b/tools/demo.py
index 1b5ffbd..2350563 100644
--- a/tools/demo.py
+++ b/tools/demo.py
@@ -8,7 +8,7 @@
def main():
parser = ArgumentParser()
parser.add_argument('prompt', help='Prompt text')
- parser.add_argument('checkpoint', help='Prompt text')
+ parser.add_argument('checkpoint', help='Path to checkpoint.')
parser.add_argument(
'--sdmodel',
help='Stable Diffusion model name',
diff --git a/tools/demo_controlnet.py b/tools/demo_controlnet.py
new file mode 100644
index 0000000..c4cb126
--- /dev/null
+++ b/tools/demo_controlnet.py
@@ -0,0 +1,78 @@
+from argparse import ArgumentParser
+
+import torch
+from diffusers import (AutoencoderKL, ControlNetModel,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionXLControlNetPipeline)
+from diffusers.utils import load_image
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('prompt', help='Prompt text.')
+ parser.add_argument('condition_image', help='Path to condition image.')
+ parser.add_argument('checkpoint', help='Path to controlnet weight.')
+ parser.add_argument(
+ '--sdmodel',
+ help='Stable Diffusion model name',
+ default='runwayml/stable-diffusion-v1-5')
+ parser.add_argument(
+ '--vaemodel',
+ type=str,
+ default=None,
+ help='Path to pretrained VAE model with better numerical stability. '
+ 'More details: https://github.com/huggingface/diffusers/pull/4038.',
+ )
+ parser.add_argument(
+ '--use_sdxl',
+ action='store_true',
+ help='Whether to use SDXL as base model.')
+ parser.add_argument('--out', help='Output path', default='demo.png')
+ parser.add_argument(
+ '--device', help='Device used for inference', default='cuda')
+ args = parser.parse_args()
+
+ controlnet = ControlNetModel.from_pretrained(
+ args.checkpoint, subfolder='controlnet', torch_dtype=torch.float16)
+ if args.use_sdxl:
+ controlnet_cls = StableDiffusionXLControlNetPipeline
+ else:
+ controlnet_cls = StableDiffusionControlNetPipeline
+ print(controlnet_cls)
+ if args.vaemodel is not None:
+ vae = AutoencoderKL.from_pretrained(
+ args.vaemodel,
+ torch_dtype=torch.float16,
+ )
+ pipe = controlnet_cls.from_pretrained(
+ args.sdmodel,
+ controlnet=controlnet,
+ vae=vae,
+ torch_dtype=torch.float16,
+ safety_checker=None)
+ else:
+ pipe = controlnet_cls.from_pretrained(
+ args.sdmodel,
+ controlnet=controlnet,
+ torch_dtype=torch.float16,
+ safety_checker=None)
+ pipe.to(args.device)
+
+ if args.use_sdxl:
+ image = pipe(
+ args.prompt,
+ args.prompt,
+ load_image(args.condition_image),
+ num_inference_steps=50,
+ ).images[0]
+ else:
+ image = pipe(
+ args.prompt,
+ load_image(args.condition_image),
+ num_inference_steps=50,
+ ).images[0]
+ image.save(args.out)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/demo_lora.py b/tools/demo_lora.py
index a285105..6b1db3b 100644
--- a/tools/demo_lora.py
+++ b/tools/demo_lora.py
@@ -6,8 +6,8 @@
def main():
parser = ArgumentParser()
- parser.add_argument('prompt', help='Prompt text')
- parser.add_argument('checkpoint', help='Prompt text')
+ parser.add_argument('prompt', help='Prompt text.')
+ parser.add_argument('checkpoint', help='Path to LoRA weight.')
parser.add_argument(
'--sdmodel',
help='Stable Diffusion model name',