Skip to content

Commit

Permalink
Merge pull request #115 from okotaku/feat/finetune_pretrained_ip_adapter
Browse files Browse the repository at this point in the history
[Feature] Support finetuning a pretrained IP-Adapter
  • Loading branch information
okotaku authored Dec 19, 2023
2 parents 6ea30d5 + 5a71bf8 commit 0ed5f89
Show file tree
Hide file tree
Showing 10 changed files with 517 additions and 99 deletions.
6 changes: 6 additions & 0 deletions configs/ip_adapter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,9 @@ You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../d
![input1](https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg)

![example1](https://github.com/okotaku/diffengine/assets/24734142/723ad39d-9e0f-441b-80f7-cf9bcfd12853)

#### stable_diffusion_xl_pokemon_blip_ip_adapter_plus_pretrained

![input1](https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg)

![example1](https://github.com/okotaku/diffengine/assets/24734142/ace81220-010b-44a5-aa8f-3acdf3f54433)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Finetune a *pretrained* IP-Adapter Plus (rather than training the adapter
# from scratch) with Stable Diffusion XL on the Pokémon BLIP captions dataset.
_base_ = [
    "../_base_/models/stable_diffusion_xl_ip_adapter_plus.py",
    "../_base_/datasets/pokemon_blip_xl_ip_adapter.py",
    "../_base_/schedules/stable_diffusion_xl_50e.py",
    "../_base_/default_runtime.py",
]

# Initialize the adapter from the official h94/IP-Adapter release.
model = dict(
    image_encoder_sub_folder="models/image_encoder",
    pretrained_adapter="h94/IP-Adapter",
    pretrained_adapter_subfolder="sdxl_models",
    pretrained_adapter_weights_name="ip-adapter-plus_sdxl_vit-h.bin",
)

# Small per-GPU batch, compensated by gradient accumulation below.
train_dataloader = dict(batch_size=1)

# Optimizer steps once every 4 forward/backward passes
# (effective batch size = 1 * 4).
optim_wrapper = dict(accumulative_counts=4)

train_cfg = dict(by_epoch=True, max_epochs=10)
18 changes: 6 additions & 12 deletions diffengine/engine/hooks/ip_adapter_save_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS
from torch import nn

from diffengine.models.archs import process_ip_adapter_state_dict


@HOOKS.register_module()
Expand Down Expand Up @@ -33,23 +34,16 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:

ckpt_path = Path(runner.work_dir) / f"step{runner.iter}"
ckpt_path.mkdir(parents=True, exist_ok=True)
adapter_modules = torch.nn.ModuleList([
v if isinstance(v, nn.Module) else nn.Identity(
) for v in model.unet.attn_processors.values()])

adapter_state_dict = process_ip_adapter_state_dict(
model.unet, model.image_projection)

# not save no grad key
new_ckpt = OrderedDict()
proj_ckpt = OrderedDict()
sd_keys = checkpoint["state_dict"].keys()
for k in sd_keys:
if k.startswith("image_projection"):
new_k = k.replace(
"image_projection.", "").replace("image_embeds.", "proj.")
proj_ckpt[new_k] = checkpoint["state_dict"][k]
if ".processor." in k or k.startswith("image_projection"):
new_ckpt[k] = checkpoint["state_dict"][k]
torch.save({"image_proj": proj_ckpt,
"ip_adapter": adapter_modules.state_dict()},
ckpt_path / "ip_adapter.bin")
torch.save(adapter_state_dict, ckpt_path / "ip_adapter.bin")

checkpoint["state_dict"] = new_ckpt
4 changes: 3 additions & 1 deletion diffengine/models/archs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .ip_adapter import (
load_ip_adapter,
process_ip_adapter_state_dict,
set_unet_ip_adapter,
)
from .peft import create_peft_config

__all__ = [
"set_unet_ip_adapter",
"set_unet_ip_adapter", "load_ip_adapter", "process_ip_adapter_state_dict",
"create_peft_config",
]
188 changes: 188 additions & 0 deletions diffengine/models/archs/ip_adapter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from collections import OrderedDict

import torch
import torch.nn.functional as F # noqa
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from diffusers.models.embeddings import ImageProjection, Resampler
from diffusers.utils import _get_model_file
from safetensors import safe_open
from torch import nn


Expand Down Expand Up @@ -47,3 +53,185 @@ def set_unet_ip_adapter(unet: nn.Module) -> None:

key_id += 2
unet.set_attn_processor(attn_procs)


def load_ip_adapter(  # noqa: PLR0915, C901, PLR0912
        unet: nn.Module,
        image_projection: nn.Module,
        pretrained_adapter: str,
        subfolder: str,
        weights_name: str) -> None:
    """Load pretrained IP-Adapter weights into ``unet`` and ``image_projection``.

    Adapted from ``diffusers/loaders/ip_adapter.py`` and
    ``diffusers/loaders/unet.py``.

    Args:
        unet: UNet whose cross-attention processors receive the
            ``ip_adapter`` weights. The processors must already be
            IP-Adapter processors (e.g. set by ``set_unet_ip_adapter``).
        image_projection: Projection module that receives the
            ``image_proj`` weights; the checkpoint keys are remapped to
            whichever architecture this module uses.
        pretrained_adapter: Hub repo id or local path of the adapter.
        subfolder: Subfolder inside the repo holding the weights file.
        weights_name: Weights file name (``.safetensors`` or torch ``.bin``).
    """
    # Resolve (and download, if remote) the checkpoint file.
    model_file = _get_model_file(
        pretrained_adapter,
        subfolder=subfolder,
        weights_name=weights_name,
        cache_dir=None,
        force_download=False,
        resume_download=False,
        proxies=None,
        local_files_only=None,
        token=None,
        revision=None,
        user_agent={
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        })
    if weights_name.endswith(".safetensors"):
        # Safetensors stores a flat key space; re-group it into the two
        # nested dicts the torch ``.bin`` format already uses.
        state_dict: dict = {"image_proj": {}, "ip_adapter": {}}
        with safe_open(model_file, framework="pt", device="cpu") as f:
            for key in f:
                if key.startswith("image_proj."):
                    state_dict["image_proj"][
                        key.replace("image_proj.", "")] = f.get_tensor(key)
                elif key.startswith("ip_adapter."):
                    state_dict["ip_adapter"][
                        key.replace("ip_adapter.", "")] = f.get_tensor(key)
    else:
        # ``.bin`` checkpoints are stored pre-nested as
        # {"image_proj": ..., "ip_adapter": ...}.
        state_dict = torch.load(model_file, map_location="cpu")

    # IP-Adapter checkpoints number processors 1, 3, 5, ... — odd ids for
    # the cross-attention ("attn2") processors only, matching the key_id
    # stride used when the processors were installed.
    key_id = 1
    for name, attn_proc in unet.attn_processors.items():
        cross_attention_dim = None if name.endswith(
            "attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None or "motion_modules" in name:
            # Self-attention (and motion-module) processors carry no
            # IP-Adapter weights; they are skipped without consuming an id.
            continue
        value_dict = {}
        for k in attn_proc.state_dict():
            value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})

        attn_proc.load_state_dict(value_dict)
        key_id += 2

    if "proj.weight" in state_dict["image_proj"]:
        # IP-Adapter (plain linear ImageProjection):
        # checkpoint "proj.*" -> diffusers "image_embeds.*".
        image_proj_state_dict = {}
        image_proj_state_dict.update(
            {
                "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
                "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
                "norm.weight": state_dict["image_proj"]["norm.weight"],
                "norm.bias": state_dict["image_proj"]["norm.bias"],
            },
        )
        image_projection.load_state_dict(image_proj_state_dict)
        del image_proj_state_dict
    elif "proj.3.weight" in state_dict["image_proj"]:
        # NOTE(review): presumably the "IP-Adapter Full" MLP projection —
        # checkpoint "proj.<i>.*" maps onto a feed-forward "ff.net.*"
        # module plus a final norm. Confirm against diffusers'
        # IPAdapterFullImageProjection.
        image_proj_state_dict = {}
        image_proj_state_dict.update(
            {
                "ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
                "ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
                "ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
                "ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
                "norm.weight": state_dict["image_proj"]["proj.3.weight"],
                "norm.bias": state_dict["image_proj"]["proj.3.bias"],
            },
        )
        image_projection.load_state_dict(image_proj_state_dict)
        del image_proj_state_dict
    else:
        # IP-Adapter Plus (Resampler): the checkpoint uses a different
        # sub-module numbering and fuses k/v projections, so keys must be
        # renamed and "to_kv" weights split before loading.
        new_sd = OrderedDict()
        for k, v in state_dict["image_proj"].items():
            # First pass: renumber checkpoint sub-modules to the
            # Resampler's layout (0.to -> 2.to, 1.x -> 3.x, ...).
            if "0.to" in k:
                new_k = k.replace("0.to", "2.to")
            elif "1.0.weight" in k:
                new_k = k.replace("1.0.weight", "3.0.weight")
            elif "1.0.bias" in k:
                new_k = k.replace("1.0.bias", "3.0.bias")
            elif "1.1.weight" in k:
                new_k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
            elif "1.3.weight" in k:
                new_k = k.replace("1.3.weight", "3.1.net.2.weight")
            else:
                new_k = k

            # Second pass: rename norms, split the fused key/value
            # projection, and unwrap the output projection.
            if "norm1" in new_k:
                new_sd[new_k.replace("0.norm1", "0")] = v
            elif "norm2" in new_k:
                new_sd[new_k.replace("0.norm2", "1")] = v
            elif "to_kv" in new_k:
                # Fused kv weight: first half is to_k, second half is to_v.
                v_chunk = v.chunk(2, dim=0)
                new_sd[new_k.replace("to_kv", "to_k")] = v_chunk[0]
                new_sd[new_k.replace("to_kv", "to_v")] = v_chunk[1]
            elif "to_out" in new_k:
                new_sd[new_k.replace("to_out", "to_out.0")] = v
            else:
                new_sd[new_k] = v
        image_projection.load_state_dict(new_sd)
    # Drop the CPU copy of the checkpoint and release cached GPU memory.
    del state_dict
    torch.cuda.empty_cache()


def process_ip_adapter_state_dict(  # noqa: PLR0915, C901, PLR0912
        unet: nn.Module, image_projection: nn.Module) -> dict:
    """Convert trained IP-Adapter weights to the official checkpoint layout.

    Applies the inverse of the key remapping done by ``load_ip_adapter`` so
    the returned dict can be saved as an ``ip_adapter.bin`` that diffusers
    (and ``load_ip_adapter``) can read back.

    Args:
        unet: UNet whose attention processors hold the trained adapter
            weights.
        image_projection: Trained projection module; key remapping is
            selected by its type (``ImageProjection`` or ``Resampler``;
            any other type yields an empty ``image_proj`` dict).

    Returns:
        dict: ``{"image_proj": ..., "ip_adapter": ...}`` state dicts in the
        official IP-Adapter checkpoint format.
    """
    # Wrap every attention processor in a ModuleList so its state dict is
    # keyed by processor index, as the checkpoint format expects.
    # Parameter-less processors become nn.Identity placeholders to keep the
    # numbering aligned.
    adapter_modules = nn.ModuleList([
        proc if isinstance(proc, nn.Module) else nn.Identity()
        for proc in unet.attn_processors.values()])

    # Only trainable adapter keys are exported (no-grad keys are dropped).
    ip_image_projection_state_dict = OrderedDict()
    if isinstance(image_projection, ImageProjection):
        # Plain IP-Adapter: diffusers "image_embeds.*" -> checkpoint "proj.*".
        for k, v in image_projection.state_dict().items():
            new_k = k.replace("image_embeds.", "proj.")
            ip_image_projection_state_dict[new_k] = v
    elif isinstance(image_projection, Resampler):
        # IP-Adapter Plus: undo the Resampler renumbering applied by
        # ``load_ip_adapter``. Hoisted once — the original re-materialized
        # the full state dict inside the loop for every "to_k" key.
        image_proj_sd = image_projection.state_dict()
        for k, v in image_proj_sd.items():
            # First pass: renumber sub-modules back to checkpoint layout.
            # The "layers.3.*" cases must precede the generic "3.0.*" /
            # "3.1.*" cases, which would otherwise match them first.
            if "2.to" in k:
                new_k = k.replace("2.to", "0.to")
            elif "layers.3.0.weight" in k:
                new_k = k.replace("layers.3.0.weight", "layers.3.0.norm1.weight")
            elif "layers.3.0.bias" in k:
                new_k = k.replace("layers.3.0.bias", "layers.3.0.norm1.bias")
            elif "layers.3.1.weight" in k:
                new_k = k.replace("layers.3.1.weight", "layers.3.0.norm2.weight")
            elif "layers.3.1.bias" in k:
                new_k = k.replace("layers.3.1.bias", "layers.3.0.norm2.bias")
            elif "3.0.weight" in k:
                new_k = k.replace("3.0.weight", "1.0.weight")
            elif "3.0.bias" in k:
                new_k = k.replace("3.0.bias", "1.0.bias")
            # (A duplicate, unreachable "3.0.weight" branch was removed here.)
            elif "3.1.net.0.proj.weight" in k:
                new_k = k.replace("3.1.net.0.proj.weight", "1.1.weight")
            elif "3.1.net.2.weight" in k:
                new_k = k.replace("3.1.net.2.weight", "1.3.weight")
            elif "layers.0.0" in k:
                new_k = k.replace("layers.0.0", "layers.0.0.norm1")
            elif "layers.0.1" in k:
                new_k = k.replace("layers.0.1", "layers.0.0.norm2")
            elif "layers.1.0" in k:
                new_k = k.replace("layers.1.0", "layers.1.0.norm1")
            elif "layers.1.1" in k:
                new_k = k.replace("layers.1.1", "layers.1.0.norm2")
            elif "layers.2.0" in k:
                new_k = k.replace("layers.2.0", "layers.2.0.norm1")
            elif "layers.2.1" in k:
                new_k = k.replace("layers.2.1", "layers.2.0.norm2")
            else:
                new_k = k

            # Second pass: rename norms, re-fuse k/v projections, and
            # rewrap the output projection.
            if "norm_cross" in new_k:
                ip_image_projection_state_dict[new_k.replace("norm_cross", "norm1")] = v
            elif "layer_norm" in new_k:
                ip_image_projection_state_dict[new_k.replace("layer_norm", "norm2")] = v
            elif "to_k" in new_k:
                # Re-fuse separate k/v projections into the checkpoint's
                # single "to_kv" weight (k on top, v below).
                ip_image_projection_state_dict[
                    new_k.replace("to_k", "to_kv")] = torch.cat([
                        v, image_proj_sd[k.replace("to_k", "to_v")]], dim=0)
            elif "to_v" in new_k:
                # Already folded into the fused "to_kv" entry above.
                continue
            elif "to_out.0" in new_k:
                ip_image_projection_state_dict[new_k.replace("to_out.0", "to_out")] = v
            else:
                ip_image_projection_state_dict[new_k] = v

    return {"image_proj": ip_image_projection_state_dict,
            "ip_adapter": adapter_modules.state_dict()}
Loading

0 comments on commit 0ed5f89

Please sign in to comment.