
IP-Adapter attention masking #6847

Merged: 27 commits into huggingface:main on Feb 19, 2024

Conversation

fabiorigano (Contributor)

What does this PR do?

Fixes #6802

Who can review?

@yiyixuxu @asomoza

@fabiorigano (Contributor, Author) commented Feb 4, 2024

This is a work in progress; I am not satisfied with the results yet (maybe I am doing something wrong).

Mask preprocessing is done outside of this PR. I extract the masks from an RGB image by selecting its unique colors and discarding the background (black). Here is a code snippet to get the list of masks from the following image:

import numpy as np
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    feature_extractor=None,
    safety_checker=None
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)


# Load the color-coded mask image
mask = load_image("./mask.png")
# Use the image processor registered in the pipeline
iproc = pipeline.image_processor
mask = iproc.pil_to_numpy(mask)[0]
# Find the unique colors
colors = np.unique(mask.reshape(-1, 3), axis=0)
# Discard the background (pure black)
unique = [colors[i] for i in range(colors.shape[0]) if np.any(colors[i] != 0)]
# Extract one binary mask per unique color, shape (1, H, W)
masks = [np.expand_dims(np.where(np.all(mask == u, axis=-1), 1, 0), axis=0) for u in unique]
masks = [iproc.numpy_to_pt(m)[0] for m in masks]

mask
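As a quick sanity check (an optional sketch, not part of the PR), each entry of masks should end up as a (1, H, W) tensor with values in {0, 1}, one mask per unique foreground color:

# Optional check of the extracted masks (illustrative only)
for i, m in enumerate(masks):
    print(f"mask {i}: shape={tuple(m.shape)}, min={m.min().item()}, max={m.max().item()}")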

Input images are:

https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png
image1

https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png
image2

Then I called the pipeline as follows:

# Load the two input images from the URLs above
image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")

generator = torch.Generator(device="cpu").manual_seed(33)
num_images = 1

images = pipeline(
    prompt="A photo of two girls wearing black dresses, holding red roses in hand, upper body, behind is the Eiffel Tower",
    ip_adapter_image=[[image1, image2]],
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20, num_images_per_prompt=num_images, width=704, height=512,
    generator=generator, cross_attention_kwargs={"masks": masks},
    # output_type="np"
).images

Result without masks:
without

Result with masks:
masked

@asomoza (Member) commented Feb 5, 2024

Nice work. You're doing the one use case that I didn't code, which is IP-Adapters with multiple images and multiple masks, but it is the same as two IP-Adapters with one image and one mask each, with the added benefit that you can manage the weight of each one separately. In my tests it would be like this:

Result 1 Result 2
20240204223825 20240204223854

I use SDXL only, but the results should be comparable. I really recommend not using multiple masks for multiple images, and instead using one mask per IP-Adapter. I haven't seen anyone using the former, but I could be wrong.

The problem you see in your example is more noticeable with SDXL:

Result 1 Result 2
20240205030442 20240205030523

What's happening is that you're matching the batch with the masks, but the batch is doubled or not depending on classifier-free guidance. So what you're really doing is applying only one mask if the negative prompt is empty, or dropping one if the CFG scale is less than 1. You're also applying the mask to the ip_hidden_states of multiple images, so you can see that the faces are combined into one where the mask is applied.

There are some more minor issues, but I'll wait and see which approach you use.
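To illustrate the classifier-free guidance point with a minimal sketch (illustrative names, not the actual diffusers variables): when CFG is active the attention processor sees the negative and positive prompts stacked along the batch dimension, so a per-image mask has to be repeated to cover the doubled batch.

import torch

num_images_per_prompt = 1
do_classifier_free_guidance = True  # active when guidance_scale > 1

# The UNet receives [negative, positive] stacked along the batch dimension
batch_size = num_images_per_prompt * (2 if do_classifier_free_guidance else 1)

mask = torch.ones(1, 64, 64)  # one mask for one prompt image
if mask.shape[0] < batch_size:
    mask = mask.repeat(batch_size, 1, 1)  # one copy per batch entry

print(mask.shape)  # torch.Size([2, 64, 64]) when CFG doubles the batch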

@fabiorigano (Contributor, Author) commented Feb 5, 2024

Hi @asomoza, thanks for the suggestion. I updated the for loop and now the results look pretty good.

> You're also applying the mask to the ip_hidden_states of multiple images, so you can see that the faces are combined into one where the mask is applied.

I am not sure what you mean here. The image after the mask in the first comment is the result of generation without applying masks, so it is correct to have a combination of the two faces.

I changed the base SD model and loaded two IP-Adapters into the pipeline:

pipeline = AutoPipelineForText2Image.from_pretrained(
    "frankjoshua/realisticVisionV51_v51VAE",
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    feature_extractor=None,
    safety_checker=None
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=["ip-adapter-plus-face_sd15.bin", "ip-adapter-full-face_sd15.bin"])

pipeline.set_ip_adapter_scale([0.7, 0.7])

generator = torch.Generator(device="cpu").manual_seed(33)
num_images=4

images = pipeline(
      prompt="A photo of two girls wearing black dresses, holding red roses in hand, upper body, behind is the Eiffel Tower",
      ip_adapter_image=[[image1], [image2]],
      negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      num_inference_steps=20, num_images_per_prompt=num_images, width=704, height=512,
      generator=generator, cross_attention_kwargs={"masks": masks}
  ).images

Output
res1
res0
res3

@yiyixuxu

@asomoza (Member) commented Feb 5, 2024

Yeah, now it's working OK, nice work.

> I am not sure what you mean here. The image after the mask in the first comment is the result of generation without applying masks, so it is correct to have a combination of the two faces.

I meant the one after that, where there was supposed to be one face for each woman; you can also see it in my results. That happened because there were multiple images for one IP-Adapter and you were applying one mask to all of them.

You don't have that problem now, and it doesn't matter anymore since you're using two IP-Adapters, but the equivalent would be if you did this:

images = pipeline(
      prompt="A photo of two girls wearing black dresses, holding red roses in hand, upper body, behind is the Eiffel Tower",
      ip_adapter_image=[[image1, image2], [image2]],
      negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      num_inference_steps=20, num_images_per_prompt=num_images, width=704, height=512,
      generator=generator, cross_attention_kwargs={"masks": masks}
  ).images

The results are like this:

[[image1], [image2]] [[image1, image2], [image2]]
20240205114207 20240205115321

I know they're similar, but I can see the difference instantly since I've done a million tests with IP-Adapters.

(Review thread on the following code:)

        batch_size, -1, attn.heads * head_dim
    )
    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
    current_ip_hidden_states = current_ip_hidden_states * mask_downsample
(Member)

This throws an error if you use a mask with a different width and height from the generated image. For example, if I use your mask with SDXL and generate a 1024x1024 image, I get this error:

The size of tensor a (4096) must match the size of tensor b (4070) at non-singleton dimension 1
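For context, the mismatch comes from the way the mask is downsampled to the attention's query length (the formula is quoted later in this thread): the computed grid doesn't always multiply back to the sequence length. A rough sketch of the arithmetic, assuming a 1152x896 (W x H) mask and a 1024x1024 SDXL generation; the exact numbers depend on the mask:

import math

num_queries = 4096           # 64 * 64 tokens at one of the SDXL attention resolutions for 1024x1024
o_h, o_w = 896, 1152         # hypothetical mask height and width
ratio = o_w / o_h

mask_h = int(math.sqrt(num_queries / ratio))
mask_h = mask_h + int((num_queries % mask_h) != 0)  # round up if it doesn't divide evenly
mask_w = num_queries // mask_h

print(mask_h, mask_w, mask_h * mask_w)  # 57 71 4047 -> 4047 != 4096, hence the shape mismatch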

(Contributor, Author)

I know, I haven't added checks on the mask size yet. I think the ComfyUI implementation has the same issue, but I haven't tested it:
https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/90d3451cd970d5aa9cac55224e24a7c7fd98d253/IPAdapterPlus.py#L537

(Member)

I think it works with masks that don't have the same ratio as the generation, it's just not recommended. Maybe @cubiq can provide his insights here. I use the same code and it doesn't use the ratio; I think the "needs checking" remark means he wasn't completely sure of the formula used.

(Contributor, Author)

OK, I will check without the ratio as in the other implementation! Thanks

(Reply)

The attention mask is resized and stretched at each iteration; the aspect ratio doesn't matter, but of course it's better if you provide the right size.

Due to rounding errors you might occasionally get the wrong size, but it's not very common and I think I already have a solution for that.

(Contributor, Author)

I can confirm the issue is also present in the other implementation.

(Member)

In that case I really don't know what the best method of doing this would be that is also consistent with diffusers.

In my case I prepare the mask latents outside the attention processor using the VAE scale factor and the width and height of the generated image, but it could be as simple as throwing an error telling the user that the masks must have the same aspect ratio as the generated image.

@yiyixuxu (Collaborator) commented Feb 5, 2024

Great work! Thanks everyone here ❤️ the results look super cool to me!
Can we confirm that it works correctly as long as we only pass one image and one mask for each ip-adapter? @asomoza @fabiorigano

So the remaining items are:

  1. the resizing issue: IP-Adapter attention masking  #6847 (comment)
  2. refactoring the code

@asomoza (Member) commented Feb 6, 2024

Yes, it works correctly with one or multiple prompt images and one mask per IP-Adapter, which IMO is the correct implementation.

There's one other issue that maybe should be addressed, though I don't know if it comes from this PR or from before: if you don't pass the same number of scales as IP-Adapters, the adapters without a scale are silently ignored, without any message or error.
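A hypothetical guard for this, just to illustrate the point (the actual fix is referenced just below and may look different):

# Hypothetical check (illustrative; not the actual fix)
scales = [0.7]              # what the user passed
num_ip_adapters = 2         # how many IP-Adapter weights were loaded
if len(scales) != num_ip_adapters:
    raise ValueError(
        f"Expected {num_ip_adapters} IP-Adapter scales, got {len(scales)}."
    )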

@yiyixuxu (Collaborator) commented Feb 7, 2024

@asomoza I fixed it here: #6884

@yiyixuxu (Collaborator) left a comment

super cool!

src/diffusers/models/attention_processor.py (outdated review comments, resolved)
if len(masks) != len(ip_hidden_states):
    raise ValueError(
        f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})"
    )
(Collaborator)

From what I understand, it only works when we pass 1 image / 1 mask / 1 IP-Adapter? If so, let's check the number of images here and throw an error if multiple images are passed:

    if ip_hidden_states[0].shape[1] > 1:
        raise ValueError("...")

(Member)

Why do you think that? If you perform that check, you will remove all the Instant LoRA functionality.

(Collaborator)

It's only when the mask is not None though - you can still use multiple images without a mask.
And it's only based on the understanding that we can use just one image / one mask / one IP-Adapter when we use a mask, no?

(Collaborator)

If it works with multiple images, for sure we don't need this!

@asomoza (Member) commented Feb 7, 2024

It works with multiple images, I tested it, so the only check should be that the number of masks matches the number of IP-Adapters.

(Member)

Are we covering these cases in the tests?

Comment on lines 2217 to 2240
seq_len = current_ip_hidden_states.shape[1]
o_h = masks[0].shape[1]
o_w = masks[0].shape[2]
ratio = o_w / o_h
mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio)))
mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
mask_w = seq_len // mask_h

if len(mask.shape) == 2:
    mask = mask.unsqueeze(0)
mask_downsample = F.interpolate(
    torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic"
).squeeze(0)

if mask_downsample.shape[0] < batch_size:
    mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
if mask_downsample.shape[0] > batch_size:
    mask_downsample = mask_downsample[:batch_size, :, :]

mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(
    1, 1, current_ip_hidden_states.shape[-1]
)

mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)
(Collaborator)

Let's move this code to VaeImageProcessor: https://github.com/huggingface/diffusers/blob/main/src/diffusers/image_processor.py

Maybe we can create an IPAdapterMaskProcessor(VaeImageProcessor) and add a downsample method.

Suggested change:

-seq_len = current_ip_hidden_states.shape[1]
-o_h = masks[0].shape[1]
-o_w = masks[0].shape[2]
-ratio = o_w / o_h
-mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio)))
-mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
-mask_w = seq_len // mask_h
-if len(mask.shape) == 2:
-    mask = mask.unsqueeze(0)
-mask_downsample = F.interpolate(
-    torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic"
-).squeeze(0)
-if mask_downsample.shape[0] < batch_size:
-    mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
-if mask_downsample.shape[0] > batch_size:
-    mask_downsample = mask_downsample[:batch_size, :, :]
-mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(
-    1, 1, current_ip_hidden_states.shape[-1]
-)
-mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)
+mask_downsample = IPAdapterMaskProcessor.downsample(mask, seq_length, batch_size)
+mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device)
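For illustration, a minimal sketch of what such a downsample helper could look like, based on the inline code quoted above (the class name, signature, and extra value_embed_dim argument are assumptions; the merged diffusers implementation may differ):

import torch
import torch.nn.functional as F
from diffusers.image_processor import VaeImageProcessor


class IPAdapterMaskProcessor(VaeImageProcessor):
    # Sketch only: mirrors the inline logic quoted above.

    @staticmethod
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int) -> torch.Tensor:
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)  # (H, W) -> (1, H, W)

        # Choose a (mask_h, mask_w) grid whose flattened length approximates num_queries,
        # preserving the mask's aspect ratio
        o_h, o_w = mask.shape[1], mask.shape[2]
        ratio = o_w / o_h
        mask_h = int(torch.sqrt(torch.tensor(num_queries / ratio)))
        mask_h = mask_h + int((num_queries % mask_h) != 0)
        mask_w = num_queries // mask_h

        mask_downsample = F.interpolate(
            mask.unsqueeze(0).float(), size=(mask_h, mask_w), mode="bicubic"
        ).squeeze(0)

        # Repeat or trim so the mask batch matches the (possibly CFG-doubled) batch size
        if mask_downsample.shape[0] < batch_size:
            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
        elif mask_downsample.shape[0] > batch_size:
            mask_downsample = mask_downsample[:batch_size, :, :]

        # Flatten and broadcast over the value embedding dimension
        return mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, value_embed_dim)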

@fabiorigano (Contributor, Author) commented Feb 7, 2024

> So the remaining items are:
>
>   1. the resizing issue: [WIP] IP-Adapter attention masking  #6847 (comment)
>   2. refactoring the code

I have just added padding to fix the resizing bug; the output still looks good.
Maybe it is better to recommend using masks with an aspect ratio equal or very close to that of the output images, while avoiding raising errors if there is a mismatch.
I will finish refactoring as suggested after work :)
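For reference, the padding fix boils down to reconciling the flattened mask length with the attention sequence length; a minimal sketch of the idea, with illustrative numbers:

import torch
import torch.nn.functional as F

seq_len = 4096                          # number of attention queries
mask_downsample = torch.ones(2, 4047)   # flattened masks, e.g. mask_h * mask_w = 57 * 71

if mask_downsample.shape[1] < seq_len:
    # zero-pad the tail so the mask can be multiplied with the hidden states
    mask_downsample = F.pad(mask_downsample, (0, seq_len - mask_downsample.shape[1]), value=0.0)
elif mask_downsample.shape[1] > seq_len:
    # or trim if the grid overshoots
    mask_downsample = mask_downsample[:, :seq_len]

print(mask_downsample.shape)  # torch.Size([2, 4096])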

fabiorigano and others added 2 commits February 7, 2024 21:01
Co-authored-by: YiYi Xu <yixu310@gmail.com>
- Move downsampling code to downsample method
- Add process method that internally calls preprocess
@fabiorigano (Contributor, Author) commented Feb 7, 2024

Updated snippet to run inference:

from diffusers import AutoPipelineForText2Image, DDIMScheduler
import torch
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
from diffusers.image_processor import IPAdapterMaskProcessor

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", 
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    image_encoder=image_encoder,
).to("cuda")
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2])

ip_images = [[face_image1], [face_image2]]

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
pipeline.set_ip_adapter_scale([0.7, 0.7])
generator = torch.Generator(device="cpu").manual_seed(1)
num_images=1

images = pipeline(
    prompt="2 girls",
    ip_adapter_image=ip_images,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", 
    num_inference_steps=20, num_images_per_prompt=num_images, 
    generator=generator, cross_attention_kwargs={"ip_adapter_masks": masks}
).images

Output:
p1_0

@fabiorigano (Contributor, Author)

@yiyixuxu can you take a look when you have time, please?

I added a test, but I didn't touch the documentation because I saw there is a big refactoring going on right now.

thanks :)

@fabiorigano changed the title from "[WIP] IP-Adapter attention masking" to "IP-Adapter attention masking" on Feb 8, 2024
@yiyixuxu (Collaborator) left a comment

Ohh, looking great!
Left a few nits.

src/diffusers/image_processor.py (outdated review comments, resolved)
@yiyixuxu (Collaborator) commented Feb 9, 2024

Also, I think if we merge this #6915 (comment), we won't need to add the additional ip_adapter_mask argument to the default attention processors.

This test PR is relevant too; we will wait for it to merge and then update the test: #6888

The doc can be added later.

@yiyixuxu (Collaborator) commented Feb 9, 2024

cc @asomoza
can you do a final review too?

@yiyixuxu requested review from DN6 and sayakpaul on February 9, 2024 02:46
@sayakpaul (Member)

That was because the HF Hub was down. I just rebased your PR with the latest main. Let's see :)

@fabiorigano (Contributor, Author)

I want to leave a comment on mask preprocessing for future documentation (maybe Sayak was asking here #6847 (comment)).
We have several options:

  1. Masks and output image have the same aspect ratio: preprocessing can be done with IPAdapterMaskProcessor.preprocess as in this example IP-Adapter attention masking  #6847 (comment), without further changes.

  2. Masks and output image don't have the same aspect ratio:

    a. (recommended) Preprocessing can be done with IPAdapterMaskProcessor.preprocess, but the height and width of the output image must be passed as arguments, like this: processor.preprocess([mask1, mask2], height=output_height, width=output_width). Masks will be stretched to fit the target shape.

    b. If the aspect ratios are not very different, preprocessing can be done as in 1. Masks will preserve their original aspect ratio during downsampling, but some extra padding will be added if the downsampled size doesn't match the number of queries in the attention. When the aspect ratios of the masks and the output image are very different, this option is not recommended.

@asomoza for completeness, I tested your example in #6847 (comment). Here is the change to the code and the resulting image:

# masks have both shape: (1152, 896) W,H
output_height = 1024
output_width = 1024
processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
# masks have now shape: [2, 1, 1024, 1024] Num_Images, C, H, W 

p1_0

thanks everyone who contributed here!

@sayakpaul (Member)

I think we should go with the simplest reasonable alternative in our code for the default setting and document the rest of the gotchas very clearly so that users can avail themselves of all the options. That goes well with our philosophy of being "simple over easy" as well.

What do y'all think?

@fabiorigano (Contributor, Author)

That sounds good to me.

Should we wait for PR #6897 to be merged?

@yiyixuxu (Collaborator) left a comment

looks great to me!

@yiyixuxu (Collaborator)

@sayakpaul feel free to merge this if you're happy about it!

@sayakpaul (Member)

WDYT about adding a section about attention masking in the https://huggingface.co/docs/diffusers/main/en/using-diffusers/ip_adapter doc?

@sayakpaul (Member)

Once that's done and reviewed, let's ship this 🚀

@DN6 (Collaborator) left a comment

Looks good to me. Just one small request related to testing.

@fabiorigano (Contributor, Author)

Thanks @DN6, I added one more test.

@yiyixuxu could you upload this #6847 (comment) output image to your HF testing-images repository as "ip_adapter_masking_output.png", please? Thank you.

I can also upload it to the documentation-images repository if that is faster.

@sayakpaul (Member)

> Thanks @DN6, I added one more test.
>
> @yiyixuxu could you upload this #6847 (comment) output image to your HF testing-images repository as "ip_adapter_masking_output.png", please? Thank you.
>
> I can also upload it to the documentation-images repository if that is faster.

Could you let me know which images you want to see uploaded on the Hub? I can do that quickly :)

@fabiorigano (Contributor, Author) commented Feb 16, 2024

@sayakpaul this one, thank you very much! It is the one obtained with seed = 0 (see docs):

ip_adapter_masking_output

@sayakpaul (Member)

Here you go: https://huggingface.co/datasets/huggingface/documentation-images/blob/main/diffusers/ip_adapter_attention_mask_result_seed_0.png :)

@dhealy05

Hey folks, I tried out the branch. It works for me except when I call pipe.unload_ip_adapter() prior to loading the weights.

If I load them initially -- success.

If I load other weights, unload, and reload, I get: RuntimeError: mat1 and mat2 shapes cannot be multiplied (514x1664 and 1280x1280)

Not sure if this is in scope here, but I just thought I would mention it before it's merged!

@sayakpaul (Member)

We are going to merge it. I welcome you to open a new issue with a fully reproducible code snippet afterward.

@yiyixuxu feel free to merge if this looks like a go to you.

@fabiorigano (Contributor, Author)

@dhealy05 this happens because you are using an SDXL pipeline and not reloading the correct image encoder.

When using IP-Adapters for SDXL, you must first load the CLIPVisionModelWithProjection image encoder from the "models/image_encoder" folder of "h94/IP-Adapter".

Calling pipeline.unload_ip_adapter() removes both the IP-Adapter weights and the image encoder from the pipeline.

This leads to the issue: by default, if you don't load an image encoder into the pipeline, it is searched for in the IP-Adapter folder. In the case of the IP-Adapters for SDXL, this folder is "sdxl_models/image_encoder" and not "models/image_encoder".

To solve the problem, you need to reload the image encoder as follows:

import torch
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection

# define the image encoder
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

# define your pipeline (base_model_path is your SDXL base model)
pipeline = AutoPipelineForText2Image.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    image_encoder=image_encoder
)
pipeline.to("cuda")

# load your IP-Adapters for SDXL (first time)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)

# do your inference

# unload the IP-Adapters
pipeline.unload_ip_adapter()

# **reload the image encoder in the pipeline (very important)**
pipeline.image_encoder = image_encoder

# load your IP-Adapters for SDXL (second time)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)

# do your inference

@dhealy05

@fabiorigano that was it, thank you !!

@yiyixuxu merged commit eba7e7a into huggingface:main on Feb 19, 2024 (13 checks passed)
@sayakpaul (Member)

Excellent work, @fabiorigano. Also, hat-tip to @asomoza for all the helpful suggestions and testing!

@cubiq commented Mar 19, 2024

I had a quick look at the code, sorry if this is a bit late.

                mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio)))
                mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0)
                mask_w = seq_len // mask_h

Not sure why you're using torch.sqrt instead of math.sqrt; it feels like a waste. Also, rounding mask_h in the first line might, I believe, introduce rounding errors.

torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)

I would switch to "bilinear", which is faster. I don't think the mask needs bicubic anyway.

                if mask_downsample.shape[0] < batch_size:
                    mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
                if mask_downsample.shape[0] > batch_size:
                    mask_downsample = mask_downsample[:batch_size, :, :]

use if...elif

From mask_downsample.repeat(batch_size, 1, 1) I assume you allow only 1 mask? If that is the case, you should trim the tensor before downsampling, otherwise you are wasting resources. If you allow only 1 mask, it's also unlikely that the second statement is ever true.

In ComfyUI I allow sending multiple masks that are applied one per latent in the batch, but there's no such logic here.

                if mask_h * mask_w < seq_len:
                    mask_downsample = F.pad(mask_downsample, (0, seq_len-mask_downsample.shape[1]), value=0.0)
                if mask_h * mask_w > seq_len:
                    mask_downsample = mask_downsample[:, :seq_len]

use if...elif

is elif somewhat discouraged in diffusers?

@fabiorigano (Contributor, Author)

Hi @cubiq, I think you are looking at an old implementation; here is the merged version:

class IPAdapterMaskProcessor(VaeImageProcessor):

@cubiq commented Mar 19, 2024

@fabiorigano oh okay, sorry 😄 Some of the remarks still stand:

mask_h = int(math.sqrt(num_queries / ratio))
mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
mask_w = num_queries // mask_h

Don't int() the first mask_h (I think it might introduce more rounding errors), or don't int() it again in the second line.

I would use bilinear instead of bicubic.

Use if/elif.

Sorry it took me so long to reply.
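Putting those remarks together, a sketch of the suggested cleanup (illustrative only, not the code in the repository): math.sqrt instead of torch.sqrt, no int() on the first mask_h, bilinear interpolation, and if/elif branches.

import math
import torch
import torch.nn.functional as F


def downsample_mask(mask: torch.Tensor, batch_size: int, num_queries: int) -> torch.Tensor:
    # Sketch of the suggested cleanup; name and signature are illustrative
    o_h, o_w = mask.shape[1], mask.shape[2]
    ratio = o_w / o_h

    mask_h = math.sqrt(num_queries / ratio)               # keep the float here
    mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
    mask_w = num_queries // mask_h

    out = F.interpolate(mask.unsqueeze(0).float(), size=(mask_h, mask_w), mode="bilinear").squeeze(0)

    if out.shape[0] < batch_size:
        out = out.repeat(batch_size, 1, 1)
    elif out.shape[0] > batch_size:
        out = out[:batch_size, :, :]
    return out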


Successfully merging this pull request may close these issues.

[IP-Adapter] Adding IP-adapter masking feature
8 participants