Commit ebd946f: add multiple attention masks

matt3o committed Nov 24, 2023 (1 parent: d6cfa11)
Showing 2 changed files with 40 additions and 4 deletions.
IPAdapterPlus.py (35 additions, 3 deletions)

```diff
@@ -241,6 +241,7 @@ def set_new_condition(self, weight, ipadapter, cond, uncond, dtype, number, weig
     def __call__(self, n, context_attn2, value_attn2, extra_options):
         org_dtype = n.dtype
+        cond_or_uncond = extra_options["cond_or_uncond"]
 
         with torch.autocast(device_type=self.device, dtype=self.dtype):
             q = n
             k = context_attn2
```
```diff
@@ -287,9 +288,40 @@ def __call__(self, n, context_attn2, value_attn2, extra_options):
 
                 if mask_h*mask_w == qs:
                     break
+
+            # check if using AnimateDiff and sliding context window
+            if (mask.shape[0] > 1 and hasattr(cond_or_uncond, "params") and cond_or_uncond.params["sub_idxs"] is not None):
+                # if mask length matches or exceeds full_length, just get sub_idx masks, resize, and continue
+                if mask.shape[0] >= cond_or_uncond.params["full_length"]:
+                    mask_downsample = torch.Tensor(mask[cond_or_uncond.params["sub_idxs"]])
+                    mask_downsample = F.interpolate(mask_downsample.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
+                # otherwise, need to do more to get proper sub_idxs masks
+                else:
+                    # first, resize to needed attention size (to save on needed memory for other operations)
+                    mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
+                    # check if mask length matches full_length - if not, make it match
+                    if mask_downsample.shape[0] < cond_or_uncond.params["full_length"]:
+                        mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:].repeat((cond_or_uncond.params["full_length"]-mask_downsample.shape[0], 1, 1))), dim=0)
+                    # if we have too many remove the excess (should not happen, but just in case)
+                    if mask_downsample.shape[0] > cond_or_uncond.params["full_length"]:
+                        mask_downsample = mask_downsample[:cond_or_uncond.params["full_length"]]
+                    # now, select sub_idxs masks
+                    mask_downsample = mask_downsample[cond_or_uncond.params["sub_idxs"]]
+            # otherwise, perform usual mask interpolation
+            else:
+                mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
+
+            # if we don't have enough masks repeat the last one until we reach the right size
+            if mask_downsample.shape[0] < batch_prompt:
+                mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:, :, :].repeat((batch_prompt-mask_downsample.shape[0], 1, 1))), dim=0)
+            # if we have too many remove the exceeding
+            elif mask_downsample.shape[0] > batch_prompt:
+                mask_downsample = mask_downsample[:batch_prompt, :, :]
+
-            mask_downsample = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(mask_h, mask_w), mode="bilinear").squeeze(0)
-            mask_downsample = mask_downsample.view(1, -1, 1).repeat(out.shape[0], 1, out.shape[2])
+            # repeat the masks
+            mask_downsample = mask_downsample.repeat(len(cond_or_uncond), 1, 1)
+            mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, out.shape[2])
 
             out_ip = out_ip * mask_downsample
 
             out = out + out_ip
```
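The resize-then-pad-or-trim behavior this hunk introduces can be sketched standalone. This is a minimal sketch, not the committed code: `normalize_mask_batch` is a hypothetical helper, the shapes are illustrative, and the real code additionally handles the AnimateDiff `sub_idxs` window.

```python
import torch
import torch.nn.functional as F

def normalize_mask_batch(mask: torch.Tensor, length: int, size: tuple) -> torch.Tensor:
    """Resize a stack of (N, H, W) masks to the attention resolution, then
    pad (repeating the last mask) or trim the stack to `length` frames."""
    # resize every mask to the attention resolution (bicubic needs 4D input)
    mask = F.interpolate(mask.unsqueeze(1), size=size, mode="bicubic").squeeze(1)
    if mask.shape[0] < length:
        # not enough masks: repeat the last one until we reach the right size
        pad = mask[-1:].repeat(length - mask.shape[0], 1, 1)
        mask = torch.cat((mask, pad), dim=0)
    elif mask.shape[0] > length:
        # too many masks: drop the excess
        mask = mask[:length]
    return mask

masks = torch.rand(3, 64, 64)                 # 3 masks for a 5-frame batch
out = normalize_mask_batch(masks, 5, (8, 8))
print(out.shape)                              # torch.Size([5, 8, 8])
```

Frames 3 and 4 are copies of the resized last mask, which is what lets a short mask batch drive a longer latent batch without shape errors.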
```diff
@@ -410,7 +442,7 @@ def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None
         work_model = model.clone()
 
         if attn_mask is not None:
-            attn_mask = attn_mask.squeeze().to(self.device)
+            attn_mask = attn_mask.to(self.device)
 
         patch_kwargs = {
             "number": 0,
```
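The one-line change in `apply_ipadapter` drops the `squeeze()` call. A plausible reading, shown with illustrative shapes (the `(1, H, W)` / `(N, H, W)` shapes are assumptions for illustration, not taken from the commit): `squeeze()` removes every size-1 dimension, so a single mask and a batch of masks would reach the attention patch with different ranks, which breaks the new multi-mask path.

```python
import torch

# a single mask stored as (1, H, W): squeeze() silently drops the batch dim
single = torch.rand(1, 64, 64)
print(single.squeeze().shape)   # torch.Size([64, 64])

# a batch of per-frame masks is left untouched, so downstream code would
# see a different rank depending on how many masks were supplied
batch = torch.rand(4, 64, 64)
print(batch.squeeze().shape)    # torch.Size([4, 64, 64])
```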
README.md (5 additions, 1 deletion)

```diff
@@ -5,6 +5,8 @@ IPAdapter implementation that follows the ComfyUI way of doing things. The code
 
 ## Important updates
 
+**2023/11/24**: Support for multiple attention masks.
+
 **2023/11/23**: Small but important update: the new default location for the IPAdapter models is `ComfyUI/models/ipadapter`. **No panic**: the legacy `ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/models` location still works and nothing will break.
 
 **2023/11/08**: Added [attention masking](#attention-masking).
```
```diff
@@ -119,7 +121,7 @@ IPAdapter offers an interesting model for a kind of "face swap" effect. [The wor
 
 **Note:** there's a new `full-face` model available that's arguably better.
 
-### Masking
+### Masking (Inpainting)
 
 The most effective way to apply the IPAdapter to a region is by an [inpainting workflow](./examples/IPAdapter_inpaint.json). Remeber to use a specific checkpoint for inpainting otherwise it won't work. Even if you are inpainting a face I find that the *IPAdapter-Plus* (not the *face* one), works best.
 
```
```diff
@@ -167,6 +169,8 @@ In the picture below I use two reference images masked one on the left and the o
 
 <img src="./examples/masking.jpg" width="512" alt="masking" />
 
+It is also possible to send a batch of masks that will be applied to a batch of latents, one per frame. The size should be the same but if needed some normalization will be performed to avoid errors. This feature also supports (experimentally) AnimateDiff including context sliding.
+
 In the examples directory you'll find a couple of masking workflows: [simple](examples/IPAdapter_mask.json) and [two masks](examples/IPAdapter_2_masks.json).
 
 ## Troubleshooting
```
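The "AnimateDiff including context sliding" note in the README maps to the `sub_idxs`/`full_length` handling in the diff: masks are padded out to the animation's full length, then only the frames of the current sliding context window are selected. A plain-Python sketch of that selection (`select_window_masks` is a hypothetical helper; the committed code works on torch tensors and reads `sub_idxs` and `full_length` from the patch params):

```python
def select_window_masks(masks, full_length, sub_idxs):
    """Pad masks to full_length by repeating the last one (or trim the
    excess), then pick the frames of the current context window."""
    if len(masks) < full_length:
        masks = masks + [masks[-1]] * (full_length - len(masks))
    else:
        masks = masks[:full_length]
    return [masks[i] for i in sub_idxs]

# 2 masks for a 6-frame animation, window covering frames 2..5:
# padding yields [m0, m1, m1, m1, m1, m1], so the window is all m1
print(select_window_masks(["m0", "m1"], 6, [2, 3, 4, 5]))
# ['m1', 'm1', 'm1', 'm1']
```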
