diff --git a/examples/community/README.md b/examples/community/README.md
index 6dbac2e16d7a..97b1b037113a 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -2287,9 +2287,9 @@ Here's a full example for `ReplaceEdit``:
 import torch
 import numpy as np
 import matplotlib.pyplot as plt
-from diffusers.pipelines import Prompt2PromptPipeline
+from diffusers import DiffusionPipeline

-pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")

 prompts = ["A turtle playing with a ball",
            "A monkey playing with a ball"]
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 200b5571ef70..541d93b69b68 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -21,8 +21,11 @@
 import torch
 import torch.nn.functional as F

-from ...src.diffusers.models.attention import Attention
-from ...src.diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput
+from diffusers.models.attention import Attention
+from diffusers.pipelines.stable_diffusion import (
+    StableDiffusionPipeline,
+    StableDiffusionPipelineOutput,
+)


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
@@ -165,7 +168,11 @@ def __call__(
         """

         self.controller = create_controller(
-            prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=self.device
+            prompt,
+            cross_attention_kwargs,
+            num_inference_steps,
+            tokenizer=self.tokenizer,
+            device=self.device,
         )
         self.register_attention_control(self.controller)  # add attention controller

@@ -287,7 +294,7 @@ def register_attention_control(self, controller):
         attn_procs = {}
         cross_att_count = 0
         for name in self.unet.attn_processors.keys():
-            None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+            (None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim)
             if name.startswith("mid_block"):
                 self.unet.config.block_out_channels[-1]
                 place_in_unet = "mid"
@@ -314,7 +321,13 @@ def __init__(self, controller, place_in_unet):
         self.controller = controller
         self.place_in_unet = place_in_unet

-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+    ):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -346,7 +359,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a


 def create_controller(
-    prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device
+    prompts: List[str],
+    cross_attention_kwargs: Dict,
+    num_inference_steps: int,
+    tokenizer,
+    device,
 ) -> AttentionControl:
     edit_type = cross_attention_kwargs.get("edit_type", None)
     local_blend_words = cross_attention_kwargs.get("local_blend_words", None)
@@ -358,27 +375,49 @@ def create_controller(
     # only replace
     if edit_type == "replace" and local_blend_words is None:
         return AttentionReplace(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            tokenizer=tokenizer,
+            device=device,
         )

     # replace + localblend
     if edit_type == "replace" and local_blend_words is not None:
         lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
         return AttentionReplace(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            lb,
+            tokenizer=tokenizer,
+            device=device,
         )

     # only refine
     if edit_type == "refine" and local_blend_words is None:
         return AttentionRefine(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            tokenizer=tokenizer,
+            device=device,
         )

     # refine + localblend
     if edit_type == "refine" and local_blend_words is not None:
         lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
         return AttentionRefine(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            lb,
+            tokenizer=tokenizer,
+            device=device,
         )

     # reweight
@@ -447,7 +486,14 @@ def forward(self, attn, is_cross: bool, place_in_unet: str):
 class AttentionStore(AttentionControl):
     @staticmethod
     def get_empty_store():
-        return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
+        return {
+            "down_cross": [],
+            "mid_cross": [],
+            "up_cross": [],
+            "down_self": [],
+            "mid_self": [],
+            "up_self": [],
+        }

     def forward(self, attn, is_cross: bool, place_in_unet: str):
         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
@@ -497,7 +543,13 @@ def __call__(self, x_t, attention_store):
         return x_t

     def __init__(
-        self, prompts: List[str], words: [List[List[str]]], tokenizer, device, threshold=0.3, max_num_words=77
+        self,
+        prompts: List[str],
+        words: [List[List[str]]],
+        tokenizer,
+        device,
+        threshold=0.3,
+        max_num_words=77,
     ):
         self.max_num_words = 77

@@ -588,7 +640,13 @@ def __init__(
         device=None,
     ):
         super(AttentionReplace, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+            prompts,
+            num_steps,
+            cross_replace_steps,
+            self_replace_steps,
+            local_blend,
+            tokenizer,
+            device,
         )
         self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)

@@ -610,7 +668,13 @@ def __init__(
         device=None,
     ):
         super(AttentionRefine, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+            prompts,
+            num_steps,
+            cross_replace_steps,
+            self_replace_steps,
+            local_blend,
+            tokenizer,
+            device,
         )
         self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
         self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
@@ -637,7 +701,13 @@ def __init__(
         device=None,
     ):
         super(AttentionReweight, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+            prompts,
+            num_steps,
+            cross_replace_steps,
+            self_replace_steps,
+            local_blend,
+            tokenizer,
+            device,
         )
         self.equalizer = equalizer.to(self.device)
         self.prev_controller = controller
@@ -645,7 +715,10 @@ def __init__(

 ### util functions for all Edits
 def update_alpha_time_word(
-    alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
+    alpha,
+    bounds: Union[float, Tuple[float, float]],
+    prompt_ind: int,
+    word_inds: Optional[torch.Tensor] = None,
 ):
     if isinstance(bounds, float):
         bounds = 0, bounds
@@ -659,7 +732,11 @@ def update_alpha_time_word(


 def get_time_words_attention_alpha(
-    prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
+    prompts,
+    num_steps,
+    cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+    tokenizer,
+    max_num_words=77,
 ):
     if not isinstance(cross_replace_steps, dict):
         cross_replace_steps = {"default_": cross_replace_steps}
@@ -750,7 +827,10 @@ def get_replacement_mapper(prompts, tokenizer, max_len=77):

 ### util functions for ReweightEdit
 def get_equalizer(
-    text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
+    text: str,
+    word_select: Union[int, Tuple[int, ...]],
+    values: Union[List[float], Tuple[float, ...]],
+    tokenizer,
 ):
     if isinstance(word_select, (int, str)):
         word_select = (word_select,)
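
The README hunk above stops at the `prompts` list, so here is a minimal sketch of how the updated loading path could be exercised end to end. The `cross_attention_kwargs` keys (`edit_type`, `n_self_replace`, `n_cross_replace`) are the ones `create_controller` reads in the patched `pipeline_prompt2prompt.py`; the concrete values (0.4 ratios, 50 steps) and output filenames are illustrative assumptions, not part of this diff.

```python
from diffusers import DiffusionPipeline

# Load the community pipeline through the custom_pipeline mechanism, as in the updated README.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="pipeline_prompt2prompt",
).to("cuda")

prompts = ["A turtle playing with a ball",
           "A monkey playing with a ball"]

# create_controller() reads these keys; a float n_cross_replace is expanded to
# {"default_": value} by get_time_words_attention_alpha. Values here are illustrative.
cross_attention_kwargs = {
    "edit_type": "replace",   # "replace", "refine", or "reweight" (plus local-blend variants)
    "n_self_replace": 0.4,
    "n_cross_replace": 0.4,
}

# The pipeline returns a StableDiffusionPipelineOutput, one image per prompt.
outputs = pipe(
    prompt=prompts,
    num_inference_steps=50,
    cross_attention_kwargs=cross_attention_kwargs,
)
outputs.images[0].save("source.png")
outputs.images[1].save("edited.png")
```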