[Feature] Support InstructPix2Pix #89

Merged (2 commits, Nov 3, 2023)
2 changes: 2 additions & 0 deletions README.md
@@ -123,6 +123,7 @@ For detailed user guides and advanced guides, please refer to our [Documentation
- [Run Stable Diffusion XL ControlNet](https://diffengine.readthedocs.io/en/latest/run_guides/run_controlnet_xl.html)
- [Run IP Adapter](https://diffengine.readthedocs.io/en/latest/run_guides/run_ip_adapter.html)
- [Run T2I Adapter](https://diffengine.readthedocs.io/en/latest/run_guides/run_t2i_adapter.html)
- [Run InstructPix2Pix](https://diffengine.readthedocs.io/en/latest/run_guides/run_instruct_pix2pix.html)
- [Inference](https://diffengine.readthedocs.io/en/latest/run_guides/inference.html)

</details>
@@ -189,6 +190,7 @@ For detailed user guides and advanced guides, please refer to our [Documentation
<li><a href="configs/ip_adapter/README.md">IP-Adapter (2023)</a></li>
<li><a href="configs/esd/README.md">Erasing Concepts from Diffusion Models (2023)</a></li>
<li><a href="configs/ssd_1b/README.md">SSD-1B (2023)</a></li>
<li><a href="configs/instruct_pix2pix/README.md">InstructPix2Pix (2022)</a></li>
</ul>
</td>
<td>
47 changes: 47 additions & 0 deletions configs/_base_/datasets/instructpix2pix_xl.py
@@ -0,0 +1,47 @@
train_pipeline = [
    dict(type="SaveImageShape"),
    dict(
        type="torchvision/Resize",
        size=1024,
        interpolation="bilinear",
        keys=["img", "condition_img"]),
    dict(type="RandomCrop", size=1024, keys=["img", "condition_img"]),
    dict(type="RandomHorizontalFlip", p=0.5, keys=["img", "condition_img"]),
    dict(type="ComputeTimeIds"),
    dict(type="torchvision/ToTensor", keys=["img", "condition_img"]),
    dict(type="DumpImage", max_imgs=10, dump_dir="work_dirs/dump"),
    dict(type="torchvision/Normalize", mean=[0.5], std=[0.5],
         keys=["img", "condition_img"]),
    dict(
        type="PackInputs",
        input_keys=["img", "condition_img", "text", "time_ids"]),
]
train_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dict(
        type="HFControlNetDataset",
        dataset="fusing/instructpix2pix-1000-samples",
        image_column="edited_image",
        condition_column="input_image",
        caption_column="edit_prompt",
        pipeline=train_pipeline),
    sampler=dict(type="DefaultSampler", shuffle=True),
)

val_dataloader = None
val_evaluator = None
test_dataloader = val_dataloader
test_evaluator = val_evaluator

custom_hooks = [
    dict(
        type="VisualizationHook",
        prompt=["make the mountains snowy"] * 4,
        condition_image=[
            'https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png'  # noqa
        ] * 4,
        height=1024,
        width=1024),
    dict(type="SDCheckpointHook"),
]
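The `ComputeTimeIds` step in the pipeline above prepares SDXL's micro-conditioning vector. A minimal sketch of the idea (`compute_time_ids` is a hypothetical helper, assuming the common SDXL convention of concatenating original size, crop offset, and target size):

```python
# Hypothetical sketch of SDXL-style time_ids, assuming the common
# convention: [orig_h, orig_w, crop_top, crop_left, target_h, target_w].
def compute_time_ids(original_size, crop_top_left, target_size):
    """Concatenate the three size/crop tuples into one conditioning vector."""
    return list(original_size) + list(crop_top_left) + list(target_size)

time_ids = compute_time_ids((1024, 1024), (0, 0), (1024, 1024))
print(time_ids)  # [1024, 1024, 0, 0, 1024, 1024]
```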
5 changes: 5 additions & 0 deletions configs/_base_/models/stable_diffusion_xl_instruct_pix2pix.py
@@ -0,0 +1,5 @@
model = dict(
    type="StableDiffusionXLInstructPix2Pix",
    model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    gradient_checkpointing=True)
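`StableDiffusionXLInstructPix2Pix` follows the InstructPix2Pix design: the VAE latents of the input image are concatenated with the noisy latents along the channel axis, so the UNet receives 8 input channels instead of SDXL's usual 4. A shape-only sketch with NumPy (illustrative, not the actual model code):

```python
import numpy as np

# Illustrative shapes only: batch of 2, 4 latent channels, 128x128 latents
# (1024 pixels / VAE downsampling factor of 8).
noisy_latents = np.zeros((2, 4, 128, 128), dtype=np.float32)
image_latents = np.zeros((2, 4, 128, 128), dtype=np.float32)  # encoded input image

# InstructPix2Pix conditioning: stack the two along the channel axis.
unet_input = np.concatenate([noisy_latents, image_latents], axis=1)
print(unet_input.shape)  # (2, 8, 128, 128)
```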
88 changes: 88 additions & 0 deletions configs/instruct_pix2pix/README.md
@@ -0,0 +1,88 @@
# InstructPix2Pix

[InstructPix2Pix: Learning to Follow Image Editing Instructions](https://arxiv.org/abs/2211.09800)

## Abstract

We propose a method for editing images from human instructions: given an input image and a written instruction that tells the model what to do, our model follows these instructions to edit the image. To obtain training data for this problem, we combine the knowledge of two large pretrained models -- a language model (GPT-3) and a text-to-image model (Stable Diffusion) -- to generate a large dataset of image editing examples. Our conditional diffusion model, InstructPix2Pix, is trained on our generated data, and generalizes to real images and user-written instructions at inference time. Since it performs edits in the forward pass and does not require per example fine-tuning or inversion, our model edits images quickly, in a matter of seconds. We show compelling editing results for a diverse collection of input images and written instructions.

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/b9de262c-e316-4df2-88d7-690f863934e3"/>
</div>

## Citation

```
@article{brooks2022instructpix2pix,
title={InstructPix2Pix: Learning to Follow Image Editing Instructions},
author={Brooks, Tim and Holynski, Aleksander and Efros, Alexei A},
journal={arXiv preprint arXiv:2211.09800},
year={2022}
}
```

## Run Training

```bash
# single gpu
$ mim train diffengine ${CONFIG_FILE}
# multi gpus
$ mim train diffengine ${CONFIG_FILE} --gpus 2 --launcher pytorch

# Example
$ mim train diffengine configs/instruct_pix2pix/stable_diffusion_xl_instruct_pix2pix.py
```

## Inference with diffusers

Once you have trained a model, specify the path to the saved checkpoint and use it for inference with the `diffusers` pipeline.

Before running inference, convert the weights to the diffusers format:

```bash
$ mim run diffengine publish_model2diffusers ${CONFIG_FILE} ${INPUT_FILENAME} ${OUTPUT_DIR} --save-keys ${SAVE_KEYS}
# Example
$ mim run diffengine publish_model2diffusers configs/instruct_pix2pix/stable_diffusion_xl_instruct_pix2pix.py work_dirs/stable_diffusion_xl_instruct_pix2pix/epoch_3.pth work_dirs/stable_diffusion_xl_instruct_pix2pix --save-keys unet
```

Then we can run inference.

```py
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline, UNet2DConditionModel, AutoencoderKL
from diffusers.utils import load_image

checkpoint = 'work_dirs/stable_diffusion_xl_instruct_pix2pix'
prompt = 'make the mountains snowy'
condition_image = load_image(
    'https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png'
).resize((1024, 1024))

unet = UNet2DConditionModel.from_pretrained(
    checkpoint, subfolder='unet', torch_dtype=torch.float16)

vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    unet=unet, vae=vae, torch_dtype=torch.float16)
pipe.to('cuda')

image = pipe(
    prompt,
    image=condition_image,
    num_inference_steps=50,
).images[0]
image.save('demo.png')
```

## Results Example

#### stable_diffusion_xl_instruct_pix2pix

![input1](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png)

![example1](https://github.com/okotaku/diffengine/assets/24734142/f66149fd-e375-4f85-bfbf-d4d046cd469a)
16 changes: 16 additions & 0 deletions configs/instruct_pix2pix/stable_diffusion_xl_instruct_pix2pix.py
@@ -0,0 +1,16 @@
_base_ = [
    "../_base_/models/stable_diffusion_xl_instruct_pix2pix.py",
    "../_base_/datasets/instructpix2pix_xl.py",
    "../_base_/schedules/stable_diffusion_3e.py",
    "../_base_/default_runtime.py",
]

optim_wrapper = dict(
    optimizer=dict(
        _delete_=True,
        type="Adafactor",
        lr=3e-5,
        weight_decay=1e-2,
        scale_parameter=False,
        relative_step=False),
    accumulative_counts=4)
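With `batch_size=1` in the dataset config and `accumulative_counts=4` here, gradients accumulate over 4 forward passes before each optimizer update. The resulting effective batch size per update is simple arithmetic (a sketch; `num_gpus=2` is assumed to match the multi-GPU training example above):

```python
# Effective batch size under gradient accumulation (sketch).
batch_size = 1           # per-GPU batch size from the dataset config
accumulative_counts = 4  # optimizer steps once every 4 forward passes
num_gpus = 2             # assumed, matching the multi-GPU example

effective_batch = batch_size * accumulative_counts * num_gpus
print(effective_batch)  # 8
```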
1 change: 1 addition & 0 deletions diffengine/models/editors/__init__.py
@@ -1,6 +1,7 @@
from .deepfloyd_if import * # noqa: F403
from .distill_sd import * # noqa: F403
from .esd import * # noqa: F403
from .instruct_pix2pix import * # noqa: F403
from .ip_adapter import * # noqa: F403
from .ssd_1b import * # noqa: F403
from .stable_diffusion import * # noqa: F403
3 changes: 3 additions & 0 deletions diffengine/models/editors/instruct_pix2pix/__init__.py
@@ -0,0 +1,3 @@
from .instruct_pix2pix_xl import StableDiffusionXLInstructPix2Pix

__all__ = ["StableDiffusionXLInstructPix2Pix"]
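The editor is exported here so the MMEngine-style registry can resolve the `type="StableDiffusionXLInstructPix2Pix"` string in the model config to the class. A toy sketch of that pattern (not diffengine's actual registry implementation):

```python
# Toy registry sketch: a config `type` string resolves to a registered class.
# diffengine uses MMEngine's registry; this only illustrates the idea.
REGISTRY = {}

def register(cls):
    REGISTRY[cls.__name__] = cls
    return cls

@register
class StableDiffusionXLInstructPix2Pix:
    def __init__(self, model, vae_model, gradient_checkpointing=False):
        self.model = model
        self.vae_model = vae_model

cfg = dict(type="StableDiffusionXLInstructPix2Pix",
           model="stabilityai/stable-diffusion-xl-base-1.0",
           vae_model="madebyollin/sdxl-vae-fp16-fix")
editor_cls = REGISTRY[cfg.pop("type")]  # look up the class by its type string
editor = editor_cls(**cfg)              # build it from the remaining kwargs
print(type(editor).__name__)  # StableDiffusionXLInstructPix2Pix
```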