IP-Adapter attention masking (#6847)
* Add attention masking to attn processors

* Update tensor conversion

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
3 people authored on Feb 19, 2024
1 parent 31de879 · commit eba7e7a
Showing 4 changed files with 322 additions and 17 deletions.
80 changes: 80 additions & 0 deletions docs/source/en/using-diffusers/ip_adapter.md
@@ -468,3 +468,83 @@ image
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />
</div>

### IP-Adapter masking

Binary masks specify which portion of the output image should be assigned to an IP-Adapter. For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided.

Before passing the masks to the pipeline, it's essential to preprocess them using [`IPAdapterMaskProcessor.preprocess()`].

> [!TIP]
> For optimal results, provide the output height and width to [`IPAdapterMaskProcessor.preprocess()`]. This ensures that masks with differing aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, the height and width can be omitted.

Here is an example with two masks:

```py
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")

output_height = 1024
output_width = 1024

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
```
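
`preprocess()` stacks the masks into a single 4-D tensor, which is the layout the attention processors expect. With the 1024x1024 output size above, the result should look like this (a sanity check, not part of the original example):

```py
print(masks.shape)
# torch.Size([2, 1, 1024, 1024]) -> [num_ip_adapter, 1, height, width]
```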

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">mask one</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">mask two</figcaption>
</div>
</div>

If you have more than one IP-Adapter image, load them into a list, ensuring each image is assigned to a different IP-Adapter.

```py
face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")

ip_images = [[face_image1], [face_image2]]
```

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image one</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image two</figcaption>
</div>
</div>

Pass the preprocessed masks to the pipeline in `cross_attention_kwargs` as shown below:

```py
import torch

# `pipeline` is assumed to be an SDXL pipeline created earlier, e.g. with
# AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2,
)
pipeline.set_ip_adapter_scale([0.7] * 2)

generator = torch.Generator(device="cpu").manual_seed(0)
num_images = 1

image = pipeline(
    prompt="2 girls",
    ip_adapter_image=ip_images,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20,
    num_images_per_prompt=num_images,
    generator=generator,
    cross_attention_kwargs={"ip_adapter_masks": masks},
).images[0]
```
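
The updated attention processors validate the masks and raise a `ValueError` when the number of masks does not match the number of loaded IP-Adapters (see the `attention_processor.py` diff below). A minimal pre-flight check, reusing the `masks` and `ip_images` defined above:

```py
# one mask per IP-Adapter; a mismatch raises a ValueError in the attention processor
assert masks.shape[0] == len(ip_images) == 2
```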

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_attention_mask_result_seed_0.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
</div>
106 changes: 106 additions & 0 deletions src/diffusers/image_processor.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from PIL import Image, ImageFilter, ImageOps

from .configuration_utils import ConfigMixin, register_to_config
@@ -882,3 +884,107 @@ def preprocess(
depth = self.binarize(depth)

return rgb, depth


class IPAdapterMaskProcessor(VaeImageProcessor):
"""
Image processor for IP Adapter image masks.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `False`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the image to 0/1.
do_convert_grayscale (`bool`, *optional*, defaults to `True`):
Whether to convert the images to grayscale format.
"""

config_name = CONFIG_NAME

@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = False,
do_binarize: bool = True,
do_convert_grayscale: bool = True,
):
super().__init__(
do_resize=do_resize,
vae_scale_factor=vae_scale_factor,
resample=resample,
do_normalize=do_normalize,
do_binarize=do_binarize,
do_convert_grayscale=do_convert_grayscale,
)

@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
"""
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
Args:
mask (`torch.FloatTensor`):
The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
batch_size (`int`):
The batch size.
num_queries (`int`):
The number of queries.
value_embed_dim (`int`):
The dimensionality of the value embeddings.
Returns:
`torch.FloatTensor`:
The downsampled mask tensor.
"""
o_h = mask.shape[1]
o_w = mask.shape[2]
ratio = o_w / o_h
mask_h = int(math.sqrt(num_queries / ratio))
mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
mask_w = num_queries // mask_h

mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)

# Repeat batch_size times
if mask_downsample.shape[0] < batch_size:
mask_downsample = mask_downsample.repeat(batch_size, 1, 1)

mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)

downsampled_area = mask_h * mask_w
# If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
# Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
if downsampled_area < num_queries:
warnings.warn(
"The aspect ratio of the mask does not match the aspect ratio of the output image. "
"Please update your masks or adjust the output size for optimal performance.",
UserWarning,
)
mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
# Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
if downsampled_area > num_queries:
warnings.warn(
"The aspect ratio of the mask does not match the aspect ratio of the output image. "
"Please update your masks or adjust the output size for optimal performance.",
UserWarning,
)
mask_downsample = mask_downsample[:, :num_queries]

# Repeat last dimension to match SDPA output shape
mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
1, 1, value_embed_dim
)

return mask_downsample
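
For reference, here is a minimal sketch of what `downsample` produces. The numbers are illustrative assumptions (a 1024x1024 mask, a 64x64 latent giving 4096 queries, and 640-dim value embeddings), not values from the commit:

```py
import torch

from diffusers.image_processor import IPAdapterMaskProcessor

# one preprocessed mask for one IP-Adapter: [1, height, width]
mask = (torch.rand(1, 1024, 1024) > 0.5).float()

# result is broadcastable over the SDPA output: [batch_size, num_queries, value_embed_dim]
mask_downsample = IPAdapterMaskProcessor.downsample(
    mask, batch_size=2, num_queries=4096, value_embed_dim=640
)
print(mask_downsample.shape)  # torch.Size([2, 4096, 640])
```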
79 changes: 63 additions & 16 deletions src/diffusers/models/attention_processor.py
@@ -19,6 +19,7 @@
import torch.nn.functional as F
from torch import nn

from ..image_processor import IPAdapterMaskProcessor
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
@@ -2107,12 +2108,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale

def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

@@ -2167,9 +2169,22 @@ def __call__(
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
)
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
@@ -2181,6 +2196,15 @@ def __call__(
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

current_ip_hidden_states = current_ip_hidden_states * mask_downsample

hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj
@@ -2244,12 +2268,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale

def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

@@ -2318,9 +2343,22 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
)
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
@@ -2339,6 +2377,15 @@
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

current_ip_hidden_states = current_ip_hidden_states * mask_downsample

hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj