IP adapter support for most pipelines (huggingface#5900)
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
3 people authored Dec 10, 2023
1 parent 815ef75 commit 5c9bffc
Showing 7 changed files with 374 additions and 43 deletions.
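
As a quick orientation before the diffs below, here is a minimal usage sketch of what this commit enables for the latent consistency text-to-image pipeline. The checkpoint id, adapter weight names, and image URL are illustrative assumptions, not taken from this commit.

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Load an LCM checkpoint (illustrative model id); it resolves to LatentConsistencyModelPipeline.
pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter comes from the IPAdapterMixin this commit adds to the pipeline's bases.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# Any PipelineImageInput works as the IP-Adapter conditioning image (placeholder URL).
style_image = load_image("https://example.com/reference.png")

image = pipe(
    prompt="a photo of a cat, best quality",
    ip_adapter_image=style_image,
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]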
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -20,11 +20,11 @@

import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LCMScheduler
from ...utils import (
@@ -129,7 +129,7 @@ def retrieve_timesteps(


class LatentConsistencyModelImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for image-to-image generation using a latent consistency model.
@@ -142,6 +142,7 @@ class LatentConsistencyModelImg2ImgPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -166,7 +167,7 @@ class LatentConsistencyModelImg2ImgPipeline(
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]

@@ -179,6 +180,7 @@ def __init__(
scheduler: LCMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -191,6 +193,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)

if safety_checker is None and requires_safety_checker:
@@ -449,6 +452,31 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)

return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -647,6 +675,7 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -695,6 +724,8 @@ def __call__(
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -758,6 +789,12 @@ def __call__(
device = self._execution_device
# do_classifier_free_guidance = guidance_scale > 1.0

if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)

# 3. Encode input prompt
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -815,6 +852,9 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

# 8. LCM Multistep Sampling Loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -829,6 +869,7 @@ def __call__(
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]

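The image-to-image pipeline above gains the same capability; a corresponding sketch follows, again with illustrative model ids and placeholder URLs.

import torch
from diffusers import LatentConsistencyModelImg2ImgPipeline
from diffusers.utils import load_image

pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("https://example.com/init.png")        # image to transform (placeholder URL)
style_image = load_image("https://example.com/reference.png")  # IP-Adapter conditioning (placeholder URL)

image = pipe(
    prompt="a watercolor painting",
    image=init_image,
    ip_adapter_image=style_image,
    strength=0.6,
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]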
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -19,11 +19,11 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LCMScheduler
from ...utils import (
@@ -107,7 +107,7 @@ def retrieve_timesteps(


class LatentConsistencyModelPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using a latent consistency model.
@@ -120,6 +120,7 @@ class LatentConsistencyModelPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -144,7 +145,7 @@ class LatentConsistencyModelPipeline(
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]

@@ -157,6 +158,7 @@ def __init__(
scheduler: LCMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -185,6 +187,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -433,6 +436,31 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)

return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -581,6 +609,7 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -629,6 +658,8 @@ def __call__(
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -697,6 +728,12 @@ def __call__(
device = self._execution_device
# do_classifier_free_guidance = guidance_scale > 1.0

if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)

# 3. Encode input prompt
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -748,6 +785,9 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

# 8. LCM MultiStep Sampling Loop:
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -762,6 +802,7 @@ def __call__(
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]

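One detail shared by both pipelines: encode_image returns either pooled CLIP image embeds or the penultimate hidden states, and __call__ picks between them by inspecting the UNet's encoder_hid_proj. Plain IP-Adapter weights install an ImageProjection head that consumes pooled embeds, while the "plus" variants use a resampler that consumes hidden states. The helper below is a hypothetical sketch that only mirrors the check added in __call__.

from diffusers import LatentConsistencyModelPipeline
from diffusers.models import ImageProjection


def ip_adapter_wants_hidden_states(pipe: LatentConsistencyModelPipeline) -> bool:
    # ImageProjection -> feed pooled CLIP image embeds; anything else (e.g. an
    # IP-Adapter "plus" resampler) -> feed the penultimate CLIP hidden states.
    return not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)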