support for torch2.0, update license
ResearcherXman committed Jan 24, 2024
1 parent fee9f5e commit abbb364
Showing 5 changed files with 144 additions and 27 deletions.
12 changes: 7 additions & 5 deletions README.md
@@ -2,8 +2,7 @@
<a href='https://instantid.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/abs/2401.07519'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
<a href='https://huggingface.co/papers/2401.07519'><img src='https://img.shields.io/static/v1?label=Paper&message=Huggingface&color=orange'></a>
<a href='https://huggingface.co/spaces/InstantX/InstantID'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
[![Replicate](https://replicate.com/zsxkib/instant-id/badge)](https://replicate.com/zsxkib/instant-id)
<a href='https://huggingface.co/spaces/InstantX/InstantID'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>

**InstantID : Zero-shot Identity-Preserving Generation in Seconds**

@@ -18,6 +17,9 @@ InstantID is a new state-of-the-art tuning-free method to achieve ID-Preserving
- [2023/12/11] 🔥 We launch the [project page](https://instantid.github.io/).

## Demos
<a href='https://huggingface.co/spaces/InstantX/InstantID'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
[![Replicate](https://replicate.com/zsxkib/instant-id/badge)](https://replicate.com/zsxkib/instant-id)
[![ModelScope](https://img.shields.io/badge/modelscope-InstantID-blue)](https://modelscope.cn/studios/instantx/InstantID/summary)

### Stylized Synthesis

@@ -154,10 +156,9 @@ python gradio_demo/app.py
- For specific styles, choosing a corresponding base model makes a difference.
- Multi-person generation is not supported yet; only the largest face is used as the reference pose.

## Resources
## Community Resources

### Gradio Demo
- [Huggingface Space](https://huggingface.co/spaces/InstantX/InstantID)
- [instantid.org](https://instantid.org/)

### Replicate Demo
@@ -174,9 +175,10 @@ python gradio_demo/app.py
- Our work is highly inspired by [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [ControlNet](https://github.com/lllyasviel/ControlNet). Thanks for their great works!
- Thanks [ZHO-ZHO-ZHO](https://github.com/ZHO-ZHO-ZHO), [huxiuhan](https://github.com/huxiuhan), [sdbds](https://github.com/sdbds), [zsxkib](https://replicate.com/zsxkib) for their generous contributions.
- Thanks to the [HuggingFace](https://github.com/huggingface) gradio team for their free GPU support!
- Thanks to the [ModelScope](https://github.com/modelscope/modelscope) team for their free GPU support!

## Disclaimer
This project is released under [Apache License](https://github.com/InstantID/InstantID?tab=Apache-2.0-1-ov-file#readme) and aims to positively impact the field of AI-driven image generation. Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and utilize it responsibly. The developers will not assume any responsibility for potential misuse by users.
The code of InstantID is released under the [Apache License](https://github.com/InstantID/InstantID?tab=Apache-2.0-1-ov-file#readme) for both academic and commercial usage. **However, the face models from insightface, whether downloaded manually or automatically, are for non-commercial research purposes only** according to their [license](https://github.com/deepinsight/insightface?tab=readme-ov-file#license). Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and utilize it responsibly. The developers will not assume any responsibility for potential misuse by users.

## Cite
If you find InstantID useful for your research and applications, please cite us using this BibTeX:
4 changes: 2 additions & 2 deletions gradio_demo/app.py
@@ -336,7 +336,7 @@ def generate_image(face_image, pose_image, prompt, negative_prompt, style_name,

# prompt
prompt = gr.Textbox(label="Prompt",
    info="Give simple prompt is enough to achieve good face fedility",
    info="Give simple prompt is enough to achieve good face fidelity",
    placeholder="A photo of a person",
    value="")

@@ -346,7 +346,7 @@ def generate_image(face_image, pose_image, prompt, negative_prompt, style_name,

# strength
identitynet_strength_ratio = gr.Slider(
label="IdentityNet strength (for fedility)",
label="IdentityNet strength (for fidelity)",
minimum=0,
maximum=1.5,
step=0.05,
2 changes: 1 addition & 1 deletion gradio_demo/requirements.txt
@@ -1,7 +1,7 @@
diffusers==0.25.0
torch==2.0.0
torchvision==0.15.1
transformers==4.36.2
transformers==4.30.1
accelerate
safetensors
einops
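The pin to `torch==2.0.0` matters because the new `IPAttnProcessor2_0` below relies on `F.scaled_dot_product_attention`, which first shipped in PyTorch 2.0. A minimal sketch (not part of this commit) to verify the installed environment actually exposes that kernel:

```python
# Environment check (assumption: run inside the pinned environment above).
import torch
import torch.nn.functional as F

print(torch.__version__)  # expected: 2.0.0 per requirements.txt
assert hasattr(F, "scaled_dot_product_attention"), (
    "PyTorch >= 2.0 is required for AttnProcessor2_0 / IPAttnProcessor2_0"
)

# Quick smoke test of the fused kernel with the layout the processors use:
# (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 16, 64)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```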
148 changes: 130 additions & 18 deletions ip_adapter/attention_processor.py
@@ -11,13 +11,6 @@
xformers_available = False



class RegionControler(object):
    def __init__(self) -> None:
        self.prompt_image_conditioning = []
region_control = RegionControler()


class AttnProcessor(nn.Module):
r"""
Default processor for performing attention-related computations.
@@ -180,17 +173,6 @@ def __call__(
        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        # region control
        if len(region_control.prompt_image_conditioning) == 1:
            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
            if region_mask is not None:
                h, w = region_mask.shape[:2]
                ratio = (h * w / query.shape[1]) ** 0.5
                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
            else:
                mask = torch.ones_like(ip_hidden_states)
            ip_hidden_states = ip_hidden_states * mask

        hidden_states = hidden_states + self.scale * ip_hidden_states

@@ -305,4 +287,134 @@ def __call__(

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter with PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        with torch.no_grad():
            # store the query/image-token attention map for inspection;
            # the softmax must apply to the similarity matrix, not to ip_key
            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
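For context, these processors are not used standalone: in diffusers they are registered per attention layer on the UNet. Below is a hypothetical wiring sketch, not part of this commit, assuming a loaded SDXL `UNet2DConditionModel` named `unet` and the diffusers `attn_processors` naming convention, where self-attention layers end in `attn1.processor`:

```python
# Hypothetical wiring sketch (assumption): self-attention keeps the plain
# processor; cross-attention gets the IP-Adapter processor, since image
# tokens only enter through cross-attention (decoupled cross-attention).
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

attn_procs = {}
for name in unet.attn_processors.keys():
    is_self_attn = name.endswith("attn1.processor")
    cross_attention_dim = None if is_self_attn else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    else:
        attn_procs[name] = IPAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            scale=1.0,
            num_tokens=4,  # 16 for ip_adapter_plus-style resamplers
        ).to(unet.device, dtype=unet.dtype)
unet.set_attn_processor(attn_procs)
```

On PyTorch 2.0 the `*2_0` variants would be substituted in the same loop; the class names above are the pre-2.0 fallbacks from this file.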
5 changes: 4 additions & 1 deletion pipeline_stable_diffusion_xl_instantid.py
@@ -42,7 +42,10 @@
from ip_adapter.resampler import Resampler
from ip_adapter.utils import is_torch2_available

from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
if is_torch2_available():
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

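`is_torch2_available` is imported from `ip_adapter.utils` above; it presumably reduces to a capability check rather than a version parse, along the lines of:

```python
# Sketch of ip_adapter/utils.py (assumption; the actual helper may differ).
import torch.nn.functional as F

def is_torch2_available() -> bool:
    # PyTorch 2.0 introduced the fused SDPA kernel; its presence is the
    # signal that the 2.0-style attention processors can be used.
    return hasattr(F, "scaled_dot_product_attention")
```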
