Merge branch 'main' into ip-adapter-test-mixin
a-r-r-o-w authored Feb 16, 2024
2 parents fdcbaff + 777063e commit d618540
Showing 66 changed files with 249 additions and 166 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
@@ -32,7 +32,7 @@ jobs:
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
- python -m pip install pandas
+ python -m pip install pandas peft
- name: Environment
run: |
python utils/print_env.py
4 changes: 2 additions & 2 deletions examples/community/README.md
@@ -2287,9 +2287,9 @@ Here's a full example for `ReplaceEdit`:
import torch
import numpy as np
import matplotlib.pyplot as plt
- from diffusers.pipelines import Prompt2PromptPipeline
+ from diffusers import DiffusionPipeline

- pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
+ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")

prompts = ["A turtle playing with a ball",
"A monkey playing with a ball"]
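For context, a minimal end-to-end sketch of the updated loading style. The `cross_attention_kwargs` keys mirror those read by `create_controller()` in `pipeline_prompt2prompt.py`, shown later in this diff; the concrete fractions and generation settings are illustrative, not part of the README change:

```python
import torch
from diffusers import DiffusionPipeline

# Load the community pipeline by name instead of importing Prompt2PromptPipeline.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="pipeline_prompt2prompt",
).to("cuda")

prompts = ["A turtle playing with a ball",
           "A monkey playing with a ball"]

# "replace" edit: generate the second prompt while injecting the first
# prompt's cross-attention maps for part of the denoising schedule.
outputs = pipe(
    prompt=prompts,
    num_inference_steps=50,
    cross_attention_kwargs={
        "edit_type": "replace",
        "n_self_replace": 0.4,   # illustrative fractions
        "n_cross_replace": 0.4,
    },
)
images = outputs.images
```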
4 changes: 2 additions & 2 deletions examples/community/ip_adapter_face_id.py
@@ -848,7 +848,7 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

@@ -930,7 +930,7 @@ def encode_prompt(
else:
uncond_tokens = negative_prompt

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

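Aside: the corrected comment recurs throughout this commit, so a short sketch of what "process multi-vector tokens" refers to may help. If a textual-inversion embedding was trained with several vectors, loading it registers numbered companion tokens, and `maybe_convert_prompt()` expands the placeholder so every vector participates in the prompt. The embedding path and token below are hypothetical placeholders:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Hypothetical multi-vector embedding; loading registers "<concept>",
# "<concept>_1", ... in the tokenizer.
pipe.load_textual_inversion("path/to/learned_embeds.bin", token="<concept>")

prompt = "a photo of <concept> in the snow"
expanded = pipe.maybe_convert_prompt(prompt, pipe.tokenizer)
# For a 3-vector embedding this expands roughly to:
# "a photo of <concept> <concept>_1 <concept>_2 in the snow"
```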
4 changes: 2 additions & 2 deletions examples/community/latent_consistency_interpolate.py
@@ -395,7 +395,7 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

@@ -477,7 +477,7 @@ def encode_prompt(
else:
uncond_tokens = negative_prompt

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

4 changes: 2 additions & 2 deletions examples/community/llm_grounded_diffusion.py
@@ -1307,7 +1307,7 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

@@ -1391,7 +1391,7 @@ def encode_prompt(
else:
uncond_tokens = negative_prompt

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

2 changes: 1 addition & 1 deletion examples/community/lpw_stable_diffusion_xl.py
@@ -789,7 +789,7 @@ def encode_prompt(

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
4 changes: 2 additions & 2 deletions examples/community/pipeline_animatediff_controlnet.py
@@ -247,7 +247,7 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

@@ -329,7 +329,7 @@ def encode_prompt(
else:
uncond_tokens = negative_prompt

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

2 changes: 1 addition & 1 deletion examples/community/pipeline_demofusion_sdxl.py
@@ -289,7 +289,7 @@ def encode_prompt(

if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
4 changes: 2 additions & 2 deletions examples/community/pipeline_fabric.py
@@ -233,7 +233,7 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

@@ -304,7 +304,7 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

116 changes: 98 additions & 18 deletions examples/community/pipeline_prompt2prompt.py
@@ -21,8 +21,11 @@
import torch
import torch.nn.functional as F

- from ...src.diffusers.models.attention import Attention
- from ...src.diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput
+ from diffusers.models.attention import Attention
+ from diffusers.pipelines.stable_diffusion import (
+     StableDiffusionPipeline,
+     StableDiffusionPipelineOutput,
+ )


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
@@ -165,7 +168,11 @@ def __call__(
"""

self.controller = create_controller(
- prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=self.device
+ prompt,
+ cross_attention_kwargs,
+ num_inference_steps,
+ tokenizer=self.tokenizer,
+ device=self.device,
)
self.register_attention_control(self.controller) # add attention controller

@@ -287,7 +294,7 @@ def register_attention_control(self, controller):
attn_procs = {}
cross_att_count = 0
for name in self.unet.attn_processors.keys():
- None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+ (None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim)
if name.startswith("mid_block"):
self.unet.config.block_out_channels[-1]
place_in_unet = "mid"
@@ -314,7 +321,13 @@ def __init__(self, controller, place_in_unet):
self.controller = controller
self.place_in_unet = place_in_unet

- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ def __call__(
+     self,
+     attn: Attention,
+     hidden_states,
+     encoder_hidden_states=None,
+     attention_mask=None,
+ ):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -346,7 +359,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a


def create_controller(
- prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device
+ prompts: List[str],
+ cross_attention_kwargs: Dict,
+ num_inference_steps: int,
+ tokenizer,
+ device,
) -> AttentionControl:
edit_type = cross_attention_kwargs.get("edit_type", None)
local_blend_words = cross_attention_kwargs.get("local_blend_words", None)
@@ -358,27 +375,49 @@ def create_controller(
# only replace
if edit_type == "replace" and local_blend_words is None:
return AttentionReplace(
- prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ tokenizer=tokenizer,
+ device=device,
)

# replace + localblend
if edit_type == "replace" and local_blend_words is not None:
lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
return AttentionReplace(
- prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ lb,
+ tokenizer=tokenizer,
+ device=device,
)

# only refine
if edit_type == "refine" and local_blend_words is None:
return AttentionRefine(
- prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ tokenizer=tokenizer,
+ device=device,
)

# refine + localblend
if edit_type == "refine" and local_blend_words is not None:
lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
return AttentionRefine(
- prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ lb,
+ tokenizer=tokenizer,
+ device=device,
)

# reweight
@@ -447,7 +486,14 @@ def forward(self, attn, is_cross: bool, place_in_unet: str):
class AttentionStore(AttentionControl):
@staticmethod
def get_empty_store():
return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
return {
"down_cross": [],
"mid_cross": [],
"up_cross": [],
"down_self": [],
"mid_self": [],
"up_self": [],
}

def forward(self, attn, is_cross: bool, place_in_unet: str):
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
@@ -497,7 +543,13 @@ def __call__(self, x_t, attention_store):
return x_t

def __init__(
- self, prompts: List[str], words: [List[List[str]]], tokenizer, device, threshold=0.3, max_num_words=77
+ self,
+ prompts: List[str],
+ words: [List[List[str]]],
+ tokenizer,
+ device,
+ threshold=0.3,
+ max_num_words=77,
):
self.max_num_words = 77

@@ -588,7 +640,13 @@ def __init__(
device=None,
):
super(AttentionReplace, self).__init__(
- prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+ prompts,
+ num_steps,
+ cross_replace_steps,
+ self_replace_steps,
+ local_blend,
+ tokenizer,
+ device,
)
self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)

@@ -610,7 +668,13 @@ def __init__(
device=None,
):
super(AttentionRefine, self).__init__(
- prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+ prompts,
+ num_steps,
+ cross_replace_steps,
+ self_replace_steps,
+ local_blend,
+ tokenizer,
+ device,
)
self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
@@ -637,15 +701,24 @@ def __init__(
device=None,
):
super(AttentionReweight, self).__init__(
- prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+ prompts,
+ num_steps,
+ cross_replace_steps,
+ self_replace_steps,
+ local_blend,
+ tokenizer,
+ device,
)
self.equalizer = equalizer.to(self.device)
self.prev_controller = controller


### util functions for all Edits
def update_alpha_time_word(
- alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
+ alpha,
+ bounds: Union[float, Tuple[float, float]],
+ prompt_ind: int,
+ word_inds: Optional[torch.Tensor] = None,
):
if isinstance(bounds, float):
bounds = 0, bounds
@@ -659,7 +732,11 @@ def update_alpha_time_word(


def get_time_words_attention_alpha(
- prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
+ prompts,
+ num_steps,
+ cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+ tokenizer,
+ max_num_words=77,
):
if not isinstance(cross_replace_steps, dict):
cross_replace_steps = {"default_": cross_replace_steps}
@@ -750,7 +827,10 @@ def get_replacement_mapper(prompts, tokenizer, max_len=77):

### util functions for ReweightEdit
def get_equalizer(
- text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
+ text: str,
+ word_select: Union[int, Tuple[int, ...]],
+ values: Union[List[float], Tuple[float, ...]],
+ tokenizer,
):
if isinstance(word_select, (int, str)):
word_select = (word_select,)
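One behavior worth noting from the reformatted signatures above: `cross_replace_steps` accepts either a bare float or a per-word dict. From `get_time_words_attention_alpha()` (a bare value is promoted to `{"default_": value}`) and `update_alpha_time_word()` (a float bound `b` becomes the interval `(0, b)`), the accepted shapes look roughly as follows; the assumption that the value arrives via an `n_cross_replace` key of `cross_attention_kwargs` is inferred from the collapsed portion of `create_controller()`:

```python
# All words: inject the source prompt's cross-attention for the first 80% of steps.
cross_attention_kwargs = {
    "edit_type": "replace",
    "n_cross_replace": 0.8,
}

# Per-word schedule: "default_" covers unlisted words; a (start, end) tuple
# bounds the replacement window for one word as fractions of num_inference_steps.
cross_attention_kwargs = {
    "edit_type": "replace",
    "n_cross_replace": {"default_": 1.0, "ball": (0.2, 0.6)},
}
```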
2 changes: 1 addition & 1 deletion examples/community/pipeline_sdxl_style_aligned.py
@@ -632,7 +632,7 @@ def encode_prompt(
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
4 changes: 2 additions & 2 deletions examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -250,7 +250,7 @@ def encode_prompt(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

@@ -332,7 +332,7 @@ def encode_prompt(
else:
uncond_tokens = negative_prompt

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

@@ -363,7 +363,7 @@ def encode_prompt(
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
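The SDXL-style hunks above all share the same dual-encoder loop over `zip(prompts, tokenizers, text_encoders)`. As a rough sketch of what that loop computes, assuming the usual SDXL convention of taking the penultimate hidden state from each CLIP encoder and concatenating the features (the helper name and shapes here are illustrative, not the pipelines' full logic):

```python
import torch

def encode_dual(prompts, tokenizers, text_encoders):
    # prompts = [prompt, prompt_2], one entry per text encoder.
    prompt_embeds_list = []
    for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        out = text_encoder(text_inputs.input_ids, output_hidden_states=True)
        prompt_embeds_list.append(out.hidden_states[-2])  # penultimate layer
    # Concatenate per-encoder features along the embedding dimension.
    return torch.cat(prompt_embeds_list, dim=-1)
```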
(The remaining changed files in this commit are not shown.)
