Merge pull request #99 from okotaku/feat/update_diffusers_ip_adapter
[Feature] Update diffusers IP Adapter
okotaku authored Nov 23, 2023
2 parents 6a4929b + a5b5579 · commit 2888747
Showing 14 changed files with 110 additions and 1,322 deletions.
58 changes: 32 additions & 26 deletions configs/ip_adapter/README.md
````diff
@@ -40,32 +40,38 @@ $ mim train diffengine configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_ad
 Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffengine` module.
 
 ```py
-from PIL import Image
-from mmengine import Config
-from mmengine.registry import init_default_scope
-from mmengine.runner.checkpoint import _load_checkpoint_to_model, _load_checkpoint
-
-from diffengine.registry import MODELS
-
-init_default_scope('diffengine')
-
-prompt = ['']
-example_image = ['https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg']
-config = 'configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py'
-checkpoint = 'work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/epoch_50.pth'
-device = 'cuda'
-
-config = Config.fromfile(config).copy()
-
-StableDiffuser = MODELS.build(config.model)
-StableDiffuser = StableDiffuser.to(device)
-
-checkpoint = _load_checkpoint(checkpoint, map_location='cpu')
-_load_checkpoint_to_model(StableDiffuser, checkpoint['state_dict'],
-                          strict=False)
-
-image = StableDiffuser.infer(prompt, example_image=example_image, width=1024, height=1024)[0]
-Image.fromarray(image).save('demo.png')
+import torch
+from diffusers import DiffusionPipeline, AutoencoderKL
+from diffusers.utils import load_image
+from transformers import CLIPVisionModelWithProjection
+
+prompt = ''
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "h94/IP-Adapter",
+    subfolder="sdxl_models/image_encoder",
+    torch_dtype=torch.float16,
+).to('cuda')
+vae = AutoencoderKL.from_pretrained(
+    'madebyollin/sdxl-vae-fp16-fix',
+    torch_dtype=torch.float16,
+)
+pipe = DiffusionPipeline.from_pretrained(
+    'stabilityai/stable-diffusion-xl-base-1.0',
+    image_encoder=image_encoder,
+    vae=vae, torch_dtype=torch.float16)
+pipe.to('cuda')
+pipe.load_ip_adapter("work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650", subfolder="", weight_name="ip_adapter.bin")
+
+image = load_image("https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg")
+
+image = pipe(
+    prompt,
+    ip_adapter_image=image,
+    height=1024,
+    width=1024,
+).images[0]
+image.save('demo.png')
 ```
 
 You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../docs/source/run_guides/run_ip_adapter.md#inference-with-diffengine).
````
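The directory passed to `load_ip_adapter` above is the per-step folder that the updated save hook (`diffengine/engine/hooks/ip_adapter_save_hook.py`, changed below) writes during training, and `ip_adapter.bin` is the file it saves. Continuing the snippet above, a minimal sketch of tuning the adapter's influence; `set_ip_adapter_scale` is a diffusers `IPAdapterMixin` method available in recent releases (0.23+), and the `0.7` value is an arbitrary example:

```py
# Continues the inference example above: `pipe` and `image` already exist.
# scale=1.0 keeps full image conditioning; lower values favor the text prompt.
pipe.set_ip_adapter_scale(0.7)

out = pipe(
    prompt,
    ip_adapter_image=image,
    height=1024,
    width=1024,
).images[0]
out.save('demo_scaled.png')
```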
30 changes: 17 additions & 13 deletions diffengine/engine/hooks/ip_adapter_save_hook.py
```diff
@@ -1,12 +1,11 @@
-import os.path as osp
+from collections import OrderedDict
+from pathlib import Path
 
-from diffusers.loaders import LoraLoaderMixin
+import torch
 from mmengine.hooks import Hook
 from mmengine.model import is_model_wrapper
 from mmengine.registry import HOOKS
-
-from diffengine.models.archs import unet_attn_processors_state_dict
+from torch import nn
 
 
 @HOOKS.register_module()
@@ -31,21 +30,26 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         model = runner.model
         if is_model_wrapper(model):
             model = model.module
-        unet_ipadapter_layers_to_save = unet_attn_processors_state_dict(
-            model.unet)
-        ckpt_path = osp.join(runner.work_dir, f"step{runner.iter}")
-        LoraLoaderMixin.save_lora_weights(
-            ckpt_path,
-            unet_lora_layers=unet_ipadapter_layers_to_save,
-        )
-        model.image_projection.save_pretrained(
-            osp.join(ckpt_path, "image_projection"))
+        ckpt_path = Path(runner.work_dir) / f"step{runner.iter}"
+        ckpt_path.mkdir(parents=True, exist_ok=True)
+        adapter_modules = torch.nn.ModuleList([
+            v if isinstance(v, nn.Module) else nn.Identity(
+            ) for v in model.unet.attn_processors.values()])
 
         # not save no grad key
         new_ckpt = OrderedDict()
+        proj_ckpt = OrderedDict()
         sd_keys = checkpoint["state_dict"].keys()
         for k in sd_keys:
+            if k.startswith("image_projection"):
+                new_k = k.replace(
+                    "image_projection.", "").replace("image_embeds.", "proj.")
+                proj_ckpt[new_k] = checkpoint["state_dict"][k]
             if ".processor." in k or k.startswith("image_projection"):
                 new_ckpt[k] = checkpoint["state_dict"][k]
+        torch.save({"image_proj": proj_ckpt,
+                    "ip_adapter": adapter_modules.state_dict()},
+                   ckpt_path / "ip_adapter.bin")
 
         checkpoint["state_dict"] = new_ckpt
```
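The rewritten hook saves a single `ip_adapter.bin` whose top-level keys, `image_proj` and `ip_adapter`, follow the layout diffusers' `load_ip_adapter` reads, with the projection weights renamed from the training-time `image_embeds.*` prefix to `proj.*`. A minimal sketch for sanity-checking a saved file (not part of the commit; the path reuses the `step41650` checkpoint from the README example above):

```py
import torch

ckpt = torch.load(
    "work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650/ip_adapter.bin",
    map_location="cpu",
)
print(sorted(ckpt))  # expected: ['image_proj', 'ip_adapter']

# The hook rewrote "image_embeds." to "proj.", so no projection key
# should still carry the old prefix.
assert not any(k.startswith("image_embeds.") for k in ckpt["image_proj"])
for k, v in list(ckpt["image_proj"].items())[:4]:
    print(k, tuple(v.shape))
```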
4 changes: 0 additions & 4 deletions diffengine/models/archs/__init__.py
```diff
@@ -1,13 +1,9 @@
 from .ip_adapter import (
-    set_controlnet_ip_adapter,
     set_unet_ip_adapter,
-    unet_attn_processors_state_dict,
 )
 from .peft import create_peft_config
 
 __all__ = [
     "set_unet_ip_adapter",
-    "set_controlnet_ip_adapter",
     "create_peft_config",
-    "unet_attn_processors_state_dict",
 ]
```
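With the save hook now reading `model.unet.attn_processors` directly, `unet_attn_processors_state_dict` has no remaining callers and is dropped from the public API, along with `set_controlnet_ip_adapter`. For context, a sketch of what a helper like this conventionally does, flattening every attention processor's parameters into one state dict; this is an assumption based on the name and the matching diffusers training-script helper, not necessarily the removed implementation:

```py
from torch import nn


def unet_attn_processors_state_dict(unet) -> dict:
    """Collect all attention-processor parameters into one state dict,
    keyed by each processor's name in unet.attn_processors."""
    state_dict = {}
    for name, processor in unet.attn_processors.items():
        if not isinstance(processor, nn.Module):
            continue  # some processors hold no trainable weights
        for key, value in processor.state_dict().items():
            state_dict[f"{name}.{key}"] = value
    return state_dict
```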