From 0e0697ed7abc38d7ac65ad01528bcc117c4993f1 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 23 Nov 2023 23:08:20 +0000 Subject: [PATCH 1/2] Support diffusers IP-Adapter --- configs/ip_adapter/README.md | 58 +- .../engine/hooks/ip_adapter_save_hook.py | 30 +- diffengine/models/archs/__init__.py | 4 - diffengine/models/archs/ip_adapter.py | 524 +----------------- .../models/editors/ip_adapter/__init__.py | 3 - .../editors/ip_adapter/image_projection.py | 38 -- .../editors/ip_adapter/ip_adapter_pipeline.py | 267 --------- .../editors/ip_adapter/ip_adapter_xl.py | 22 +- .../test_hooks/test_ip_adapter_save_hook.py | 5 +- tests/test_models/test_archs.py | 31 +- .../test_ip_adapter_pipeline.py | 166 ------ .../test_ip_adapter_plus_pipeline.py | 166 ------ 12 files changed, 78 insertions(+), 1236 deletions(-) delete mode 100644 diffengine/models/editors/ip_adapter/image_projection.py delete mode 100644 diffengine/models/editors/ip_adapter/ip_adapter_pipeline.py delete mode 100644 tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_pipeline.py delete mode 100644 tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_plus_pipeline.py diff --git a/configs/ip_adapter/README.md b/configs/ip_adapter/README.md index 078167d..5f08269 100644 --- a/configs/ip_adapter/README.md +++ b/configs/ip_adapter/README.md @@ -40,32 +40,38 @@ $ mim train diffengine configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_ad Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffengine` module. ```py -from PIL import Image -from mmengine import Config -from mmengine.registry import init_default_scope -from mmengine.runner.checkpoint import _load_checkpoint_to_model, _load_checkpoint - -from diffengine.registry import MODELS - -init_default_scope('diffengine') - -prompt = [''] -example_image = ['https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg'] -config = 'configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py' -checkpoint = 'work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/epoch_50.pth' -device = 'cuda' - -config = Config.fromfile(config).copy() - -StableDiffuser = MODELS.build(config.model) -StableDiffuser = StableDiffuser.to(device) - -checkpoint = _load_checkpoint(checkpoint, map_location='cpu') -_load_checkpoint_to_model(StableDiffuser, checkpoint['state_dict'], - strict=False) - -image = StableDiffuser.infer(prompt, example_image=example_image, width=1024, height=1024)[0] -Image.fromarray(image).save('demo.png') +import torch +from diffusers import DiffusionPipeline, AutoencoderKL +from diffusers.utils import load_image +from transformers import CLIPVisionModelWithProjection + +prompt = '' + +image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", + subfolder="sdxl_models/image_encoder", + torch_dtype=torch.float16, +).to('cuda') +vae = AutoencoderKL.from_pretrained( + 'madebyollin/sdxl-vae-fp16-fix', + torch_dtype=torch.float16, +) +pipe = DiffusionPipeline.from_pretrained( + 'stabilityai/stable-diffusion-xl-base-1.0', + image_encoder=image_encoder, + vae=vae, torch_dtype=torch.float16) +pipe.to('cuda') +pipe.load_ip_adapter("work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650", subfolder="", weight_name="ip_adapter.bin") + +image = load_image("https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg") + +image = pipe( + prompt, + ip_adapter_image=image, + 
height=1024, + width=1024, +).images[0] +image.save('demo.png') ``` You can see more details on [`docs/source/run_guides/run_ip_adapter.md`](../../docs/source/run_guides/run_ip_adapter.md#inference-with-diffengine). diff --git a/diffengine/engine/hooks/ip_adapter_save_hook.py b/diffengine/engine/hooks/ip_adapter_save_hook.py index 822c83f..3d0d609 100644 --- a/diffengine/engine/hooks/ip_adapter_save_hook.py +++ b/diffengine/engine/hooks/ip_adapter_save_hook.py @@ -1,12 +1,11 @@ -import os.path as osp from collections import OrderedDict +from pathlib import Path -from diffusers.loaders import LoraLoaderMixin +import torch from mmengine.hooks import Hook from mmengine.model import is_model_wrapper from mmengine.registry import HOOKS - -from diffengine.models.archs import unet_attn_processors_state_dict +from torch import nn @HOOKS.register_module() @@ -31,21 +30,26 @@ def before_save_checkpoint(self, runner, checkpoint: dict) -> None: model = runner.model if is_model_wrapper(model): model = model.module - unet_ipadapter_layers_to_save = unet_attn_processors_state_dict( - model.unet) - ckpt_path = osp.join(runner.work_dir, f"step{runner.iter}") - LoraLoaderMixin.save_lora_weights( - ckpt_path, - unet_lora_layers=unet_ipadapter_layers_to_save, - ) - model.image_projection.save_pretrained( - osp.join(ckpt_path, "image_projection")) + ckpt_path = Path(runner.work_dir) / f"step{runner.iter}" + ckpt_path.mkdir(parents=True, exist_ok=True) + adapter_modules = torch.nn.ModuleList([ + v if isinstance(v, nn.Module) else nn.Identity( + ) for v in model.unet.attn_processors.values()]) # not save no grad key new_ckpt = OrderedDict() + proj_ckpt = OrderedDict() sd_keys = checkpoint["state_dict"].keys() for k in sd_keys: + if k.startswith("image_projection"): + new_k = k.replace( + "image_projection.", "").replace("image_embeds.", "proj.") + proj_ckpt[new_k] = checkpoint["state_dict"][k] if ".processor." in k or k.startswith("image_projection"): new_ckpt[k] = checkpoint["state_dict"][k] + torch.save({"image_proj": proj_ckpt, + "ip_adapter": adapter_modules.state_dict()}, + ckpt_path / "ip_adapter.bin") + checkpoint["state_dict"] = new_ckpt diff --git a/diffengine/models/archs/__init__.py b/diffengine/models/archs/__init__.py index 8d4c119..08a4cc7 100644 --- a/diffengine/models/archs/__init__.py +++ b/diffengine/models/archs/__init__.py @@ -1,13 +1,9 @@ from .ip_adapter import ( - set_controlnet_ip_adapter, set_unet_ip_adapter, - unet_attn_processors_state_dict, ) from .peft import create_peft_config __all__ = [ "set_unet_ip_adapter", - "set_controlnet_ip_adapter", "create_peft_config", - "unet_attn_processors_state_dict", ] diff --git a/diffengine/models/archs/ip_adapter.py b/diffengine/models/archs/ip_adapter.py index 047eaa8..8e45d96 100644 --- a/diffengine/models/archs/ip_adapter.py +++ b/diffengine/models/archs/ip_adapter.py @@ -1,473 +1,13 @@ -import torch import torch.nn.functional as F # noqa from diffusers.models.attention_processor import ( - Attention, AttnProcessor, AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, ) from torch import nn -class IPAttnProcessor(nn.Module): - """Attention processor for IP-Adapater. - - Args: - ---- - hidden_size (int): - The hidden size of the attention layer. - cross_attention_dim (int, optional): - The number of channels in the `encoder_hidden_states`. - Defaults to None. - text_context_len (int): - The context length of the text features. Defaults to 77. 
- """ - - def __init__(self, - hidden_size: int, - cross_attention_dim: int | None = None, - text_context_len: int = 77) -> None: - super().__init__() - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.text_context_len = text_context_len - - self.to_k_ip = nn.Linear( - cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear( - cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - temb: torch.Tensor | None = None, - scale: float = 1.0, - ) -> torch.Tensor: - """Call forward.""" - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - image_input_ndim = 4 - - if input_ndim == image_input_ndim: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, - height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None else encoder_hidden_states.shape) - attention_mask = attn.prepare_attention_mask(attention_mask, - sequence_length, - batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose( - 1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states) - - # split hidden states - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :self.text_context_len, :], - encoder_hidden_states[:, self.text_context_len:, :], - ) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - - hidden_states = hidden_states + scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == image_input_ndim: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - return hidden_states / attn.rescale_output_factor - - -class IPAttnProcessor2_0(nn.Module): # noqa - """Attention processor for IP-Adapater for PyTorch 2.0. - - Args: - ---- - hidden_size (int): - The hidden size of the attention layer. - cross_attention_dim (int, optional): - The number of channels in the `encoder_hidden_states`. - Defaults to None. - text_context_len (int): - The context length of the text features. Defaults to 77. 
- """ - - def __init__(self, - hidden_size: int, - cross_attention_dim: int | None = None, - text_context_len: int = 77) -> None: - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - msg = ("AttnProcessor2_0 requires PyTorch 2.0, to use it," - " please upgrade PyTorch to 2.0.") - raise ImportError(msg) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.text_context_len = text_context_len - - self.to_k_ip = nn.Linear( - cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear( - cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - temb: torch.Tensor | None = None, - scale: float = 1.0, - ) -> torch.Tensor: - """Call forward.""" - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - image_input_ndim = 4 - - if input_ndim == image_input_ndim: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, - height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None else encoder_hidden_states.shape) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, - attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose( - 1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states) - - # split hidden states - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :self.text_context_len, :], - encoder_hidden_states[:, self.text_context_len:, :], - ) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO(takuoko): add support for attn.scale when we move to Torch 2.1 # noqa - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO(takuoko): add support for attn.scale when we move to Torch 2.1 # noqa - ip_hidden_states = 
F.scaled_dot_product_attention( - query, - ip_key, - ip_value, - attn_mask=None, - dropout_p=0.0, - is_causal=False) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == image_input_ndim: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - return hidden_states / attn.rescale_output_factor - - -class CNAttnProcessor: - """Default processor for performing attention-related computations. - - Args: - ---- - clip_extra_context_tokens (int): The number of expansion ratio of proj - network hidden layer channels Defaults to 4. - """ - - def __init__(self, clip_extra_context_tokens: int = 4) -> None: - self.clip_extra_context_tokens = clip_extra_context_tokens - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - temb: torch.Tensor | None = None, - scale: float = 1.0, # noqa - ) -> torch.Tensor: - """Call forward.""" - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - image_input_ndim = 4 - - if input_ndim == image_input_ndim: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, - height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None else encoder_hidden_states.shape) - attention_mask = attn.prepare_attention_mask(attention_mask, - sequence_length, - batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose( - 1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - # only use text - encoder_hidden_states = ( - encoder_hidden_states[:, :self.clip_extra_context_tokens]) - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == image_input_ndim: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - return hidden_states / attn.rescale_output_factor - - -class CNAttnProcessor2_0: # noqa - """Controlnet Attention Processor for PyTorch 2.0. - - Args: - ---- - clip_extra_context_tokens (int): The number of expansion ratio of proj - network hidden layer channels Defaults to 4. 
- """ - - def __init__(self, clip_extra_context_tokens: int = 4) -> None: - if not hasattr(F, "scaled_dot_product_attention"): - msg = ("AttnProcessor2_0 requires PyTorch 2.0, to use it," - " please upgrade PyTorch to 2.0.") - raise ImportError(msg) - self.clip_extra_context_tokens = clip_extra_context_tokens - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - temb: torch.Tensor | None = None, - scale: float = 1.0, # noqa - ) -> torch.Tensor: - """Call forward.""" - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - image_input_ndim = 4 - - if input_ndim == image_input_ndim: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, - height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None else encoder_hidden_states.shape) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, - attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose( - 1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = ( - encoder_hidden_states[:, :self.clip_extra_context_tokens]) - encoder_hidden_states = attn.norm_encoder_hidden_states( - encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, - head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO(takuoko): add support for attn.scale when we move to Torch 2.1 # noqa - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == image_input_ndim: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - return hidden_states / attn.rescale_output_factor - - def set_unet_ip_adapter(unet: nn.Module) -> None: """Set IP-Adapter for Unet. @@ -476,6 +16,7 @@ def set_unet_ip_adapter(unet: nn.Module) -> None: unet (nn.Module): The unet to set IP-Adapter. 
""" attn_procs = {} + key_id = 1 for name in unet.attn_processors: cross_attention_dim = None if name.endswith( "attn1.processor") else unet.config.cross_attention_dim @@ -483,59 +24,26 @@ def set_unet_ip_adapter(unet: nn.Module) -> None: hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed( - unet.config.block_out_channels))[block_id] + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - - if cross_attention_dim is None: + if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = ( - AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") - else AttnProcessor) + AttnProcessor2_0 if hasattr( + F, "scaled_dot_product_attention") else AttnProcessor + ) attn_procs[name] = attn_processor_class() else: attn_processor_class = ( - IPAttnProcessor2_0 if hasattr( - F, "scaled_dot_product_attention") else IPAttnProcessor) + IPAdapterAttnProcessor2_0 if hasattr( + F, "scaled_dot_product_attention", + ) else IPAdapterAttnProcessor + ) attn_procs[name] = attn_processor_class( hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim) - unet.set_attn_processor(attn_procs) - - -def set_controlnet_ip_adapter(controlnet, - clip_extra_context_tokens: int = 4) -> None: - """Set IP-Adapter for Unet. - - Args: - ---- - controlnet (nn.Module): The ControlNet to set IP-Adapter. - clip_extra_context_tokens (int): The number of expansion ratio of proj - network hidden layer channels Defaults to 4. - """ - attn_processor_class = ( - CNAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") else CNAttnProcessor) - controlnet.set_attn_processor( - attn_processor_class( - clip_extra_context_tokens=clip_extra_context_tokens)) - - -def unet_attn_processors_state_dict(unet) -> dict[str, torch.tensor]: - """Unet attention processors state dict. - - Returns a state dict containing just the attention processor parameters. 
- """ - attn_processors = unet.attn_processors - - attn_processors_state_dict = {} + cross_attention_dim=cross_attention_dim, scale=1.0, + ).to(dtype=unet.dtype, device=unet.device) - for attn_processor_key, attn_processor in attn_processors.items(): - # skip 'AttnProcessor2_0' - if hasattr(attn_processor, "state_dict"): - for parameter_key, parameter in attn_processor.state_dict().items(): - attn_processors_state_dict[ - f"{attn_processor_key}.{parameter_key}"] = parameter - - return attn_processors_state_dict + key_id += 2 + unet.set_attn_processor(attn_procs) diff --git a/diffengine/models/editors/ip_adapter/__init__.py b/diffengine/models/editors/ip_adapter/__init__.py index 380725b..9e6d2c4 100644 --- a/diffengine/models/editors/ip_adapter/__init__.py +++ b/diffengine/models/editors/ip_adapter/__init__.py @@ -1,4 +1,3 @@ -from .ip_adapter_pipeline import IPAdapterXLPipeline, IPAdapterXLPlusPipeline from .ip_adapter_xl import IPAdapterXL, IPAdapterXLPlus from .ip_adapter_xl_data_preprocessor import IPAdapterXLDataPreprocessor @@ -6,6 +5,4 @@ "IPAdapterXL", "IPAdapterXLPlus", "IPAdapterXLDataPreprocessor", - "IPAdapterXLPipeline", - "IPAdapterXLPlusPipeline", ] diff --git a/diffengine/models/editors/ip_adapter/image_projection.py b/diffengine/models/editors/ip_adapter/image_projection.py deleted file mode 100644 index eae57b2..0000000 --- a/diffengine/models/editors/ip_adapter/image_projection.py +++ /dev/null @@ -1,38 +0,0 @@ -import torch -from diffusers.configuration_utils import ConfigMixin -from diffusers.models.modeling_utils import ModelMixin -from torch import nn - - -class ImageProjModel(ModelMixin, ConfigMixin): - """Projection Model. - - Args: - ---- - cross_attention_dim (int): The number of channels in the - `unet.config.cross_attention_dim`. Defaults to 1024. - clip_embeddings_dim (int): The number of channels in the - `image_encoder.config.projection_dim`. Defaults to 1024. - clip_extra_context_tokens (int): The number of expansion ratio of proj - network hidden layer channels Defaults to 4. 
- """ - - def __init__(self, - cross_attention_dim: int = 1024, - clip_embeddings_dim: int = 1024, - clip_extra_context_tokens: int = 4) -> None: - super().__init__() - - self.cross_attention_dim = cross_attention_dim - self.clip_extra_context_tokens = clip_extra_context_tokens - self.proj = nn.Linear( - clip_embeddings_dim, - self.clip_extra_context_tokens * cross_attention_dim) - self.norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: - """Forward pass.""" - embeds = image_embeds - clip_extra_context_tokens = self.proj(embeds).reshape( - -1, self.clip_extra_context_tokens, self.cross_attention_dim) - return self.norm(clip_extra_context_tokens) diff --git a/diffengine/models/editors/ip_adapter/ip_adapter_pipeline.py b/diffengine/models/editors/ip_adapter/ip_adapter_pipeline.py deleted file mode 100644 index b3e8384..0000000 --- a/diffengine/models/editors/ip_adapter/ip_adapter_pipeline.py +++ /dev/null @@ -1,267 +0,0 @@ -from typing import Optional, Union - -import numpy as np -import torch -from diffusers.pipelines.controlnet import MultiControlNetModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.utils import load_image -from mmengine.model import BaseModel -from PIL import Image -from transformers import CLIPVisionModelWithProjection - -from diffengine.models.archs import set_controlnet_ip_adapter, set_unet_ip_adapter -from diffengine.models.editors.ip_adapter.image_projection import ImageProjModel -from diffengine.models.editors.ip_adapter.resampler import Resampler -from diffengine.registry import MODELS - - -@MODELS.register_module() -class IPAdapterXLPipeline(BaseModel): - """IPAdapterXLPipeline. - - Args: - ---- - pipeline (DiffusionPipeline): diffusers pipeline - image_encoder (str, optional): Path to pretrained Image Encoder model. - Defaults to 'takuoko/IP-Adapter-XL'. - clip_extra_context_tokens (int): The number of expansion ratio of proj - network hidden layer channels Defaults to 4. - """ - - def __init__( - self, - pipeline: DiffusionPipeline, - image_encoder: str = "takuoko/IP-Adapter-XL-test", - clip_extra_context_tokens: int = 4, - ) -> None: - self.image_encoder_name = image_encoder - self.clip_extra_context_tokens = clip_extra_context_tokens - - super().__init__() - self.pipeline = pipeline - self.prepare_model() - self.set_ip_adapter() - - @property - def device(self) -> torch.device: - """Get device information. - - Returns - ------- - torch.device: device. - """ - return next(self.parameters()).device - - def prepare_model(self) -> None: - """Prepare model for training. - - Disable gradient for some models. 
- """ - self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( - self.image_encoder_name, subfolder="image_encoder") - self.image_projection = ImageProjModel( - cross_attention_dim=self.pipeline.unet.config.cross_attention_dim, - clip_embeddings_dim=self.image_encoder.config.projection_dim, - clip_extra_context_tokens=self.clip_extra_context_tokens, - ) - - def set_ip_adapter(self) -> None: - """Set IP-Adapter for model.""" - set_unet_ip_adapter(self.pipeline.unet) - - if hasattr(self.pipeline, "controlnet"): - if isinstance(self.pipeline.controlnet, MultiControlNetModel): - for controlnet in self.pipeline.controlnet.nets: - set_controlnet_ip_adapter(controlnet, - self.clip_extra_context_tokens) - else: - set_controlnet_ip_adapter(self.pipeline.controlnet, - self.clip_extra_context_tokens) - - def _encode_image(self, image, num_images_per_prompt): - if not isinstance(image, torch.Tensor): - from transformers import CLIPImageProcessor - image_processor = CLIPImageProcessor.from_pretrained( - self.image_encoder_name, subfolder="image_processor") - image = image_processor(image, return_tensors="pt").pixel_values - - image = image.to(device=self.device) - image_embeddings = self.image_encoder(image).image_embeds - image_prompt_embeds = self.image_projection(image_embeddings) - uncond_image_prompt_embeds = self.image_projection( - torch.zeros_like(image_embeddings)) - - # duplicate image embeddings for each generation per prompt, using mps - # friendly method - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat( - 1, num_images_per_prompt, 1) - image_prompt_embeds = image_prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat( - 1, num_images_per_prompt, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1) - - return image_prompt_embeds, uncond_image_prompt_embeds - - @torch.no_grad() - def infer(self, - prompt: list[str], - example_image: list[str | Image.Image], - negative_prompt: str | None = None, - height: int | None = None, - width: int | None = None, - num_inference_steps: int = 50, - output_type: str = "pil", - **kwargs) -> list[np.ndarray]: - """Inference function. - - Args: - ---- - prompt (`List[str]`): - The prompt or prompts to guide the image generation. - example_image (`List[Union[str, Image.Image]]`): - The image prompt or prompts to guide the image generation. - negative_prompt (`Optional[str]`): - The prompt or prompts to guide the image generation. - Defaults to None. - height (int, optional): - The height in pixels of the generated image. Defaults to None. - width (int, optional): - The width in pixels of the generated image. Defaults to None. - num_inference_steps (int): Number of inference steps. - Defaults to 50. - output_type (str): The output format of the generate image. - Choose between 'pil' and 'latent'. Defaults to 'pil'. - **kwargs: Other arguments. 
- """ - assert len(prompt) == len(example_image) - - self.pipeline.to(self.device) - self.pipeline.set_progress_bar_config(disable=True) - images = [] - for p, img in zip(prompt, example_image, strict=True): - pil_img = load_image(img) if isinstance(img, str) else img - pil_img = pil_img.convert("RGB") - - image_embeddings, uncond_image_embeddings = self._encode_image( - pil_img, num_images_per_prompt=1) - (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, - negative_pooled_prompt_embeds) = self.pipeline.encode_prompt( - p, - num_images_per_prompt=1, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt) - prompt_embeds = torch.cat([prompt_embeds, image_embeddings], dim=1) - negative_prompt_embeds = torch.cat( - [negative_prompt_embeds, uncond_image_embeddings], dim=1) - image = self.pipeline( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - num_inference_steps=num_inference_steps, - height=height, - width=width, - output_type=output_type, - **kwargs).images[0] - if output_type == "latent": - images.append(image) - else: - images.append(np.array(image)) - - return images - - def forward( - self, - inputs: dict, # noqa - data_samples: Optional[list] = None, # noqa - mode: str = "tensor", # noqa - ) -> dict[str, torch.Tensor] | list: - """Forward pass.""" - msg = "forward is not implemented now, please use infer." - raise NotImplementedError(msg) - - def train_step(self, data, optim_wrapper_dict): # noqa - """Train step.""" - msg = "train_step is not implemented now, please use infer." - raise NotImplementedError(msg) - - def val_step(self, data: Union[tuple, dict, list]) -> list: # noqa - """Val step.""" - msg = "val_step is not implemented now, please use infer." - raise NotImplementedError(msg) - - def test_step(self, data: Union[tuple, dict, list]) -> list: # noqa - """Test step.""" - msg = "test_step is not implemented now, please use infer." - raise NotImplementedError(msg) - - -@MODELS.register_module() -class IPAdapterXLPlusPipeline(IPAdapterXLPipeline): - """IPAdapterXLPlusPipeline. - - Args: - ---- - clip_extra_context_tokens (int): The number of expansion ratio of proj - network hidden layer channels Defaults to 16. - """ - - def __init__(self, - *args, - clip_extra_context_tokens: int = 16, - **kwargs) -> None: - super().__init__( - *args, - clip_extra_context_tokens=clip_extra_context_tokens, - **kwargs) # type: ignore[misc] - - def prepare_model(self) -> None: - """Prepare model for training. - - Disable gradient for some models. 
- """ - self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( - self.image_encoder_name, subfolder="image_encoder") - self.image_projection = Resampler( - embed_dims=self.image_encoder.config.hidden_size, - output_dims=self.pipeline.unet.config.cross_attention_dim, - hidden_dims=1280, - depth=4, - head_dims=64, - num_heads=20, - num_queries=self.clip_extra_context_tokens, - ffn_ratio=4) - - def _encode_image(self, image, num_images_per_prompt): - if not isinstance(image, torch.Tensor): - from transformers import CLIPImageProcessor - image_processor = CLIPImageProcessor.from_pretrained( - self.image_encoder_name, subfolder="image_processor") - image = image_processor(image, return_tensors="pt").pixel_values - - image = image.to(device=self.device) - image_embeddings = self.image_encoder( - image, output_hidden_states=True).hidden_states[-2] - image_prompt_embeds = self.image_projection(image_embeddings) - uncond_image_embeddings = self.image_encoder( - torch.zeros_like(image), - output_hidden_states=True).hidden_states[-2] - uncond_image_prompt_embeds = self.image_projection( - uncond_image_embeddings) - - # duplicate image embeddings for each generation per prompt, using mps - # friendly method - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat( - 1, num_images_per_prompt, 1) - image_prompt_embeds = image_prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat( - 1, num_images_per_prompt, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1) - - return image_prompt_embeds, uncond_image_prompt_embeds diff --git a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py index e4479fd..e8fe657 100644 --- a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py +++ b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py @@ -3,13 +3,13 @@ import numpy as np import torch from diffusers import DiffusionPipeline +from diffusers.models.embeddings import ImageProjection from diffusers.utils import load_image from PIL import Image from torch import nn from transformers import CLIPVisionModelWithProjection from diffengine.models.archs import set_unet_ip_adapter -from diffengine.models.editors.ip_adapter.image_projection import ImageProjModel from diffengine.models.editors.ip_adapter.resampler import Resampler from diffengine.models.editors.stable_diffusion_xl import StableDiffusionXL from diffengine.registry import MODELS @@ -23,7 +23,7 @@ class IPAdapterXL(StableDiffusionXL): ---- image_encoder (str, optional): Path to pretrained Image Encoder model. Defaults to 'takuoko/IP-Adapter-XL'. - clip_extra_context_tokens (int): The number of expansion ratio of proj + num_image_text_embeds (int): The number of expansion ratio of proj network hidden layer channels Defaults to 4. unet_lora_config (dict, optional): The LoRA config dict for Unet. example. dict(type="LoRA", r=4). 
`type` is chosen from `LoRA`, @@ -47,7 +47,7 @@ class IPAdapterXL(StableDiffusionXL): def __init__(self, *args, image_encoder: str = "takuoko/IP-Adapter-XL-test", - clip_extra_context_tokens: int = 4, + num_image_text_embeds: int = 4, unet_lora_config: dict | None = None, text_encoder_lora_config: dict | None = None, finetune_text_encoder: bool = False, @@ -64,7 +64,7 @@ def __init__(self, "`finetune_text_encoder` should be False when training IPAdapter" self.image_encoder_name = image_encoder - self.clip_extra_context_tokens = clip_extra_context_tokens + self.num_image_text_embeds = num_image_text_embeds self.zeros_image_embeddings_prob = zeros_image_embeddings_prob super().__init__( @@ -87,10 +87,10 @@ def prepare_model(self) -> None: """ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( self.image_encoder_name, subfolder="image_encoder") - self.image_projection = ImageProjModel( + self.image_projection = ImageProjection( cross_attention_dim=self.unet.config.cross_attention_dim, - clip_embeddings_dim=self.image_encoder.config.projection_dim, - clip_extra_context_tokens=self.clip_extra_context_tokens, + image_embed_dim=self.image_encoder.config.projection_dim, + num_image_text_embeds=self.num_image_text_embeds, ) self.image_encoder.requires_grad_(requires_grad=False) super().prepare_model() @@ -303,17 +303,17 @@ class IPAdapterXLPlus(IPAdapterXL): Args: ---- - clip_extra_context_tokens (int): The number of expansion ratio of proj + num_image_text_embeds (int): The number of expansion ratio of proj network hidden layer channels Defaults to 16. """ def __init__(self, *args, - clip_extra_context_tokens: int = 16, + num_image_text_embeds: int = 16, **kwargs) -> None: super().__init__( *args, - clip_extra_context_tokens=clip_extra_context_tokens, + num_image_text_embeds=num_image_text_embeds, **kwargs) def prepare_model(self) -> None: @@ -330,7 +330,7 @@ def prepare_model(self) -> None: depth=4, head_dims=64, num_heads=20, - num_queries=self.clip_extra_context_tokens, + num_queries=self.num_image_text_embeds, ffn_ratio=4) self.image_encoder.requires_grad_(requires_grad=False) super(IPAdapterXL, self).prepare_model() diff --git a/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py b/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py index 24c2396..6c5044f 100644 --- a/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py +++ b/tests/test_engine/test_hooks/test_ip_adapter_save_hook.py @@ -67,10 +67,7 @@ def test_before_save_checkpoint(self): assert Path( osp.join(runner.work_dir, f"step{runner.iter}", - "pytorch_lora_weights.safetensors")).exists() - assert Path( - osp.join(runner.work_dir, f"step{runner.iter}/image_projection", - "diffusion_pytorch_model.safetensors")).exists() + "ip_adapter.bin")).exists() shutil.rmtree( osp.join(runner.work_dir, f"step{runner.iter}")) diff --git a/tests/test_models/test_archs.py b/tests/test_models/test_archs.py index 8b772cd..3b30aaf 100644 --- a/tests/test_models/test_archs.py +++ b/tests/test_models/test_archs.py @@ -1,19 +1,12 @@ from typing import Any import pytest -from diffusers import ControlNetModel, UNet2DConditionModel +from diffusers import UNet2DConditionModel from peft import LoHaConfig, LoKrConfig, LoraConfig from diffengine.models.archs import ( create_peft_config, - set_controlnet_ip_adapter, set_unet_ip_adapter, - unet_attn_processors_state_dict, -) -from diffengine.models.archs.ip_adapter import CNAttnProcessor, CNAttnProcessor2_0 -from diffengine.models.editors import ( - IPAdapterXL, - 
IPAdapterXLDataPreprocessor, ) @@ -26,28 +19,6 @@ def test_set_unet_ip_adapter(): assert any("processor.to_v_ip" in k for k in unet.state_dict()) -def test_set_controlnet_ip_adapter(): - controlnet = ControlNetModel.from_pretrained( - "hf-internal-testing/tiny-controlnet-sdxl") - assert all(not isinstance(attn_processor, CNAttnProcessor) - and not isinstance(attn_processor, CNAttnProcessor2_0) - for attn_processor in (controlnet.attn_processors.values())) - set_controlnet_ip_adapter(controlnet) - assert any( - isinstance(attn_processor, CNAttnProcessor | CNAttnProcessor2_0) - for attn_processor in (controlnet.attn_processors.values())) - - -def test_unet_ip_adapter_layers_to_save(): - model = IPAdapterXL( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe", - image_encoder="hf-internal-testing/unidiffuser-diffusers-test", - data_preprocessor=IPAdapterXLDataPreprocessor()) - - unet_lora_layers_to_save = unet_attn_processors_state_dict(model.unet) - assert len(unet_lora_layers_to_save) > 0 - - def test_create_peft_config(): config: dict[str, Any] = dict( type="Dummy", diff --git a/tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_pipeline.py b/tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_pipeline.py deleted file mode 100644 index 76b01e6..0000000 --- a/tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_pipeline.py +++ /dev/null @@ -1,166 +0,0 @@ -from unittest import TestCase - -import pytest -import torch -from diffusers import ( - ControlNetModel, - DiffusionPipeline, - StableDiffusionXLControlNetPipeline, -) -from diffusers.utils import load_image -from mmengine.optim import OptimWrapper -from torch.optim import SGD - -from diffengine.models.archs.ip_adapter import ( - CNAttnProcessor, - CNAttnProcessor2_0, - IPAttnProcessor, - IPAttnProcessor2_0, -) -from diffengine.models.editors import IPAdapterXLPipeline - - -class TestIPAdapterXL(TestCase): - - def test_infer(self): - pipeline = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe") - StableDiffuser = IPAdapterXLPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - assert any( - isinstance(attn_processor, IPAttnProcessor | IPAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.unet.attn_processors.values())) - - # test infer - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - # test device - assert StableDiffuser.device.type == "cpu" - - # test infer with negative_prompt - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - negative_prompt="noise", - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - # output_type = 'latent' - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - output_type="latent", - height=64, - width=64) - assert len(result) == 1 - assert type(result[0]) == torch.Tensor - assert result[0].shape == (4, 32, 32) - - def test_infer_controlnet(self): - controlnet = ControlNetModel.from_pretrained( - "hf-internal-testing/tiny-controlnet-sdxl") - pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe", - controlnet=controlnet) - StableDiffuser = IPAdapterXLPipeline( - pipeline, - 
image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - assert any( - isinstance(attn_processor, IPAttnProcessor | IPAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.unet.attn_processors.values())) - - assert any( - isinstance(attn_processor, CNAttnProcessor | CNAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.controlnet.attn_processors.values())) - - # test infer - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - image=load_image("tests/testdata/color.jpg").resize((64, 64)), - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - def test_infer_multi_controlnet(self): - controlnet = ControlNetModel.from_pretrained( - "hf-internal-testing/tiny-controlnet-sdxl") - pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe", - controlnet=[controlnet, controlnet]) - StableDiffuser = IPAdapterXLPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - assert any( - isinstance(attn_processor, IPAttnProcessor | IPAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.unet.attn_processors.values())) - - for controlnet in StableDiffuser.pipeline.controlnet.nets: - assert any( - isinstance(attn_processor, CNAttnProcessor - | CNAttnProcessor2_0) - for attn_processor in (controlnet.attn_processors.values())) - - # test infer - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - image=[ - load_image("tests/testdata/color.jpg").resize((64, 64)), - load_image("tests/testdata/color.jpg").resize((64, 64)), - ], - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - def test_train_step(self): - pipeline = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe") - StableDiffuser = IPAdapterXLPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - # test train step - data = dict( - inputs=dict( - img=[torch.zeros((3, 64, 64))], - text=["a dog"], - clip_img=[torch.zeros((3, 32, 32))], - time_ids=[torch.zeros((1, 6))])) - optimizer = SGD(StableDiffuser.parameters(), lr=0.1) - optim_wrapper = OptimWrapper(optimizer) - with pytest.raises(NotImplementedError, match="train_step is not"): - StableDiffuser.train_step(data, optim_wrapper) - - def test_val_and_test_step(self): - pipeline = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe") - StableDiffuser = IPAdapterXLPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - # test val_step - with pytest.raises(NotImplementedError, match="val_step is not"): - StableDiffuser.val_step(torch.zeros((1, ))) - - # test test_step - with pytest.raises(NotImplementedError, match="test_step is not"): - StableDiffuser.test_step(torch.zeros((1, ))) diff --git a/tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_plus_pipeline.py b/tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_plus_pipeline.py deleted file mode 100644 index 195d854..0000000 --- a/tests/test_models/test_editors/test_ip_adapter/test_ip_adapter_plus_pipeline.py +++ /dev/null @@ -1,166 +0,0 @@ -from unittest import TestCase - -import pytest -import torch -from diffusers import ( - ControlNetModel, - DiffusionPipeline, - StableDiffusionXLControlNetPipeline, -) -from diffusers.utils import 
load_image -from mmengine.optim import OptimWrapper -from torch.optim import SGD - -from diffengine.models.archs.ip_adapter import ( - CNAttnProcessor, - CNAttnProcessor2_0, - IPAttnProcessor, - IPAttnProcessor2_0, -) -from diffengine.models.editors import IPAdapterXLPlusPipeline - - -class TestIPAdapterXL(TestCase): - - def test_infer(self): - pipeline = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe") - StableDiffuser = IPAdapterXLPlusPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - assert any( - isinstance(attn_processor, IPAttnProcessor | IPAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.unet.attn_processors.values())) - - # test infer - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - # test device - assert StableDiffuser.device.type == "cpu" - - # test infer with negative_prompt - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - negative_prompt="noise", - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - # output_type = 'latent' - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - output_type="latent", - height=64, - width=64) - assert len(result) == 1 - assert type(result[0]) == torch.Tensor - assert result[0].shape == (4, 32, 32) - - def test_infer_controlnet(self): - controlnet = ControlNetModel.from_pretrained( - "hf-internal-testing/tiny-controlnet-sdxl") - pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe", - controlnet=controlnet) - StableDiffuser = IPAdapterXLPlusPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - assert any( - isinstance(attn_processor, IPAttnProcessor | IPAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.unet.attn_processors.values())) - - assert any( - isinstance(attn_processor, CNAttnProcessor | CNAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.controlnet.attn_processors.values())) - - # test infer - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - image=load_image("tests/testdata/color.jpg").resize((64, 64)), - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - def test_infer_multi_controlnet(self): - controlnet = ControlNetModel.from_pretrained( - "hf-internal-testing/tiny-controlnet-sdxl") - pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe", - controlnet=[controlnet, controlnet]) - StableDiffuser = IPAdapterXLPlusPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - assert any( - isinstance(attn_processor, IPAttnProcessor | IPAttnProcessor2_0) - for attn_processor in ( - StableDiffuser.pipeline.unet.attn_processors.values())) - - for controlnet in StableDiffuser.pipeline.controlnet.nets: - assert any( - isinstance(attn_processor, CNAttnProcessor - | CNAttnProcessor2_0) - for attn_processor in (controlnet.attn_processors.values())) - - # test infer - result = StableDiffuser.infer( - ["an insect robot preparing a delicious meal"], - ["tests/testdata/color.jpg"], - 
image=[ - load_image("tests/testdata/color.jpg").resize((64, 64)), - load_image("tests/testdata/color.jpg").resize((64, 64)), - ], - height=64, - width=64) - assert len(result) == 1 - assert result[0].shape == (64, 64, 3) - - def test_train_step(self): - pipeline = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe") - StableDiffuser = IPAdapterXLPlusPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - # test train step - data = dict( - inputs=dict( - img=[torch.zeros((3, 64, 64))], - text=["a dog"], - clip_img=[torch.zeros((3, 32, 32))], - time_ids=[torch.zeros((1, 6))])) - optimizer = SGD(StableDiffuser.parameters(), lr=0.1) - optim_wrapper = OptimWrapper(optimizer) - with pytest.raises(NotImplementedError, match="train_step is not"): - StableDiffuser.train_step(data, optim_wrapper) - - def test_val_and_test_step(self): - pipeline = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-xl-pipe") - StableDiffuser = IPAdapterXLPlusPipeline( - pipeline, - image_encoder="hf-internal-testing/unidiffuser-diffusers-test") - - # test val_step - with pytest.raises(NotImplementedError, match="val_step is not"): - StableDiffuser.val_step(torch.zeros((1, ))) - - # test test_step - with pytest.raises(NotImplementedError, match="test_step is not"): - StableDiffuser.test_step(torch.zeros((1, ))) From a5b5579ee57d30a047b9054e376a92895b5e4dd2 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 23 Nov 2023 23:09:46 +0000 Subject: [PATCH 2/2] Support diffusers IP-Adapter --- docs/source/run_guides/run_ip_adapter.md | 66 ++++++++++++------------ tools/demo_diffengine.py | 52 ------------------- 2 files changed, 32 insertions(+), 86 deletions(-) delete mode 100644 tools/demo_diffengine.py diff --git a/docs/source/run_guides/run_ip_adapter.md b/docs/source/run_guides/run_ip_adapter.md index dcd824d..33c235c 100644 --- a/docs/source/run_guides/run_ip_adapter.md +++ b/docs/source/run_guides/run_ip_adapter.md @@ -40,40 +40,38 @@ $ mim train diffengine configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_ad Once you have trained a model, specify the path to the saved model and utilize it for inference using the `diffengine` module. 
```py -from PIL import Image -from mmengine import Config -from mmengine.registry import init_default_scope -from mmengine.runner.checkpoint import _load_checkpoint_to_model, _load_checkpoint - -from diffengine.registry import MODELS - -init_default_scope('diffengine') - -prompt = [''] -example_image = ['https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg'] -config = 'configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py' -checkpoint = 'work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/epoch_50.pth' -device = 'cuda' - -config = Config.fromfile(config).copy() - -StableDiffuser = MODELS.build(config.model) -StableDiffuser = StableDiffuser.to(device) - -checkpoint = _load_checkpoint(checkpoint, map_location='cpu') -_load_checkpoint_to_model(StableDiffuser, checkpoint['state_dict'], - strict=False) - -image = StableDiffuser.infer(prompt, example_image=example_image, width=1024, height=1024)[0] -Image.fromarray(image).save('demo.png') -``` - -We also provide inference demo scripts: - -``` -$ mim run diffengine demo_diffengine ${PROMPT} ${CONFIG} ${CHECKPOINT} --height 1024 --width 1024 --example-image ${EXAMPLE_IMAGE} -# Example -$ mim run diffengine demo_diffengine "" configs/ip_adapter/stable_diffusion_xl_pokemon_blip_ip_adapter.py work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/epoch_50.pth --height 1024 --width 1024 --example-image https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg +import torch +from diffusers import DiffusionPipeline, AutoencoderKL +from diffusers.utils import load_image +from transformers import CLIPVisionModelWithProjection + +prompt = '' + +image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", + subfolder="sdxl_models/image_encoder", + torch_dtype=torch.float16, +).to('cuda') +vae = AutoencoderKL.from_pretrained( + 'madebyollin/sdxl-vae-fp16-fix', + torch_dtype=torch.float16, +) +pipe = DiffusionPipeline.from_pretrained( + 'stabilityai/stable-diffusion-xl-base-1.0', + image_encoder=image_encoder, + vae=vae, torch_dtype=torch.float16) +pipe.to('cuda') +pipe.load_ip_adapter("work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650", subfolder="", weight_name="ip_adapter.bin") + +image = load_image("https://datasets-server.huggingface.co/assets/lambdalabs/pokemon-blip-captions/--/default/train/0/image/image.jpg") + +image = pipe( + prompt, + ip_adapter_image=image, + height=1024, + width=1024, +).images[0] +image.save('demo.png') ``` ## Results Example diff --git a/tools/demo_diffengine.py b/tools/demo_diffengine.py deleted file mode 100644 index c3d7e35..0000000 --- a/tools/demo_diffengine.py +++ /dev/null @@ -1,52 +0,0 @@ -from argparse import ArgumentParser - -from mmengine import Config -from mmengine.registry import init_default_scope -from mmengine.runner.checkpoint import _load_checkpoint, _load_checkpoint_to_model -from PIL import Image - -from diffengine.registry import MODELS - -init_default_scope("diffengine") - - -def main() -> None: - parser = ArgumentParser() - parser.add_argument("prompt", help="Prompt text.") - parser.add_argument("config", help="Path to config file.") - parser.add_argument("checkpoint", help="Path to weight file.") - parser.add_argument("--out", help="Output path", default="demo.png") - parser.add_argument( - "--height", - help="The height for output images.", - default=None, - type=int) - parser.add_argument( - "--width", help="The width for output images.", 
default=None, type=int) - parser.add_argument( - "--example-image", - help="Path to example image for generation.", - type=str, - default=None) - parser.add_argument( - "--device", help="Device used for inference", default="cuda") - args = parser.parse_args() - - config = Config.fromfile(args.config).copy() - - stable_diffuser = MODELS.build(config.model) - stable_diffuser = stable_diffuser.to(args.device) - - checkpoint = _load_checkpoint(args.checkpoint, map_location="cpu") - _load_checkpoint_to_model( - stable_diffuser, checkpoint["state_dict"], strict=False) - - kwargs = {} - if args.example_image is not None: - kwargs["example_image"] = [args.example_image] - image = stable_diffuser.infer([args.prompt], **kwargs)[0] - Image.fromarray(image).save(args.out) - - -if __name__ == "__main__": - main()
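
Note on the new checkpoint layout: the updated save hook (`ip_adapter_save_hook.py`) in this patch writes a single `ip_adapter.bin` containing an `image_proj` dict and an `ip_adapter` dict, which is the layout `pipe.load_ip_adapter` consumes in the README/docs examples above. Below is a minimal sketch for inspecting such a checkpoint before loading it; the `step41650` path is taken from the README example and is an assumption — adjust it to your own work directory.

```py
# Sketch: inspect the ip_adapter.bin written by the save hook in this patch.
# The checkpoint path below mirrors the README example and may differ locally.
import torch

ckpt = torch.load(
    "work_dirs/stable_diffusion_xl_pokemon_blip_ip_adapter/step41650/ip_adapter.bin",
    map_location="cpu",
)

# The hook stores two top-level dicts:
#   "image_proj" - image projection weights (keys renamed from
#                  "image_embeds." to "proj.")
#   "ip_adapter" - state dict of the ModuleList built from the unet's
#                  attention processors (to_k_ip / to_v_ip weights)
print(list(ckpt.keys()))
print(len(ckpt["image_proj"]), len(ckpt["ip_adapter"]))
```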