From db71a80ff63e2fac275978dcc8782fffb86a3c23 Mon Sep 17 00:00:00 2001
From: Aryan V S
Date: Tue, 21 Nov 2023 19:52:20 +0530
Subject: [PATCH] Addition of new callbacks to controlnets (#5812)

* add new callbacks to src/diffusers/pipelines/controlnet/pipeline_controlnet.py

* update callbacks

* fix repeated kwarg

* update

---------

Co-authored-by: Patrick von Platen
---
 .../controlnet/pipeline_controlnet.py         |  78 +++++++++--
 .../controlnet/pipeline_controlnet_img2img.py | 115 ++++++++++++----
 .../controlnet/pipeline_controlnet_inpaint.py | 119 +++++++++++++----
 .../pipeline_controlnet_inpaint_sd_xl.py      | 123 +++++++++++++-----
 .../controlnet/pipeline_controlnet_sd_xl.py   |  91 ++++++++++---
 .../pipeline_controlnet_sd_xl_img2img.py      | 117 +++++++++++++----
 6 files changed, 510 insertions(+), 133 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index db17e2b7c181e..4f625304fdf90 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -130,6 +130,7 @@ class StableDiffusionControlNetPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -485,15 +486,21 @@ def check_inputs(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -760,6 +767,10 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
@@ -767,6 +778,14 @@ def guidance_scale(self):
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
 
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -786,14 +805,15 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -868,6 +888,15 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -878,6 +907,23 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
@@ -903,9 +949,12 @@ def __call__(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )
 
         self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -929,7 +978,7 @@ def __call__(
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
@@ -940,7 +989,7 @@ def __call__(
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
@@ -988,6 +1037,7 @@ def __call__(
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
@@ -1078,7 +1128,7 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
@@ -1087,11 +1137,21 @@ def __call__(
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
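To make the new interface concrete, here is a minimal sketch of a new-style callback against `StableDiffusionControlNetPipeline` as changed above. The checkpoint names and image URL are illustrative stand-ins rather than part of the patch; the callback signature, the `num_timesteps` property, and the need to return `callback_kwargs` (the denoising loop calls `.pop(...)` on the returned dict) all follow from the diff.

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Illustrative checkpoints; any SD 1.5 base plus a matching ControlNet works the same way.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")


def log_latents(pipeline, step, timestep, callback_kwargs):
    # Only tensors named in `callback_on_step_end_tensor_inputs` are present here.
    latents = callback_kwargs["latents"]
    # `num_timesteps` is one of the new properties introduced by this patch.
    print(f"step {step + 1}/{pipeline.num_timesteps}: latents std = {latents.std().item():.4f}")
    # The loop pops tensors off the returned dict, so the dict must be returned.
    return callback_kwargs


canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
image = pipe(
    "a bird perched on a branch, best quality",
    image=canny_image,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```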
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 24e4050939c8f..8945bd3d9c81a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -164,6 +164,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -519,15 +520,21 @@ def check_inputs(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -808,6 +815,29 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -829,14 +859,15 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -892,12 +923,6 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -915,6 +940,15 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -925,6 +959,23 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
@@ -950,8 +1001,13 @@ def __call__(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -961,10 +1017,6 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
 
         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -978,23 +1030,23 @@ def __call__(
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare image
@@ -1010,7 +1062,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
         elif isinstance(controlnet, MultiControlNetModel):
@@ -1025,7 +1077,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
 
@@ -1039,6 +1091,7 @@ def __call__(
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables
         latents = self.prepare_latents(
@@ -1068,11 +1121,11 @@ def __call__(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1099,7 +1152,7 @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1111,20 +1164,30 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
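The img2img variant wires the callback into the same place at the bottom of the denoising loop, so per-step hooks such as intermediate previews work identically. Below is a sketch under these assumptions: `pipe` is an already loaded `StableDiffusionControlNetImg2ImgPipeline`, `init_image` and `control_image` are prepared PIL images, and `save_every` is a hypothetical knob (decoding on every step would be slow).

```python
import torch

save_every = 5  # hypothetical knob: decode a preview every N steps
previews = []


def save_preview(pipeline, step, timestep, callback_kwargs):
    if step % save_every == 0:
        latents = callback_kwargs["latents"]
        with torch.no_grad():
            # Undo the VAE scaling before decoding, as the pipeline itself does.
            decoded = pipeline.vae.decode(
                latents / pipeline.vae.config.scaling_factor, return_dict=False
            )[0]
        previews.append(pipeline.image_processor.postprocess(decoded, output_type="pil")[0])
    return callback_kwargs


result = pipe(  # `pipe`, `init_image`, `control_image` assumed prepared beforehand
    prompt="a watercolor landscape",
    image=init_image,
    control_image=control_image,
    callback_on_step_end=save_preview,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```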
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 99c72d21e22ef..9e2e428eaf91a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -286,6 +286,7 @@ class StableDiffusionControlNetInpaintPipeline(
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -656,18 +657,24 @@ def check_inputs(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -999,6 +1006,29 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1021,14 +1051,15 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -1101,12 +1132,6 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -1124,6 +1149,15 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -1134,6 +1168,23 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
@@ -1161,8 +1212,13 @@ def __call__(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -1172,10 +1228,6 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
 
         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1189,23 +1241,23 @@ def __call__(
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare image
@@ -1218,7 +1270,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
         elif isinstance(controlnet, MultiControlNetModel):
@@ -1233,7 +1285,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
 
@@ -1261,6 +1313,7 @@ def __call__(
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
         is_strength_max = strength == 1.0
+        self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
@@ -1297,7 +1350,7 @@ def __call__(
             prompt_embeds.dtype,
             device,
             generator,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
         )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -1317,11 +1370,11 @@ def __call__(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1348,7 +1401,7 @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1363,14 +1416,14 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     return_dict=False,
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
@@ -1379,7 +1432,7 @@ def __call__(
                 if num_channels_unet == 4:
                     init_latents_proper = image_latents
-                    if do_classifier_free_guidance:
+                    if self.do_classifier_free_guidance:
                         init_mask, _ = mask.chunk(2)
                     else:
                         init_mask = mask
@@ -1392,6 +1445,16 @@ def __call__(
 
                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
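The `check_inputs` extension above is what guards the tensor names. A quick sketch of the failure mode, assuming `pipe` is a loaded `StableDiffusionControlNetInpaintPipeline` and the image arguments were prepared elsewhere:

```python
# The allowed names come from the new class attribute added by this patch.
print(pipe._callback_tensor_inputs)
# -> ['latents', 'prompt_embeds', 'negative_prompt_embeds']

try:
    pipe(
        prompt="restore the damaged area",
        image=init_image,            # assumed: prepared PIL image
        mask_image=mask_image,       # assumed: prepared PIL mask
        control_image=control_image, # assumed: prepared conditioning image
        callback_on_step_end=lambda p, i, t, kw: kw,
        callback_on_step_end_tensor_inputs=["noise_pred"],  # not an allowed name
    )
except ValueError as err:
    # `check_inputs` rejects any name outside `_callback_tensor_inputs`.
    print(err)
```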
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 856ebbe6bbb51..3e5cba79f50b3 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -34,6 +34,7 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    deprecate,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
@@ -167,6 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
 
     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
     _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -555,6 +557,7 @@ def check_inputs(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -565,14 +568,20 @@ def check_inputs(
             f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
             f" {type(num_inference_steps)}."
         )
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -1008,6 +1017,29 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1039,8 +1071,6 @@ def __call__(
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
         guess_mode: bool = False,
@@ -1053,6 +1083,9 @@ def __call__(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1147,12 +1180,6 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
@@ -1182,6 +1209,15 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -1190,6 +1226,23 @@ def __call__(
             [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
             `tuple. When returning a tuple, the first element is a list with the generated images.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
         # align format for control guidance
@@ -1237,8 +1290,13 @@ def __call__(
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
+            callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -1248,17 +1306,13 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
 
         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
 
         (
             prompt_embeds,
@@ -1271,7 +1325,7 @@ def __call__(
             prompt_2=prompt_2,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             negative_prompt=negative_prompt,
             negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
@@ -1279,7 +1333,7 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
 
         # 4. set timesteps
@@ -1300,6 +1354,7 @@ def denoising_value_valid(dnv):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
         is_strength_max = strength == 1.0
+        self._num_timesteps = len(timesteps)
 
         # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
         # 5.1 Prepare init image
@@ -1316,7 +1371,7 @@ def denoising_value_valid(dnv):
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
         elif isinstance(controlnet, MultiControlNetModel):
@@ -1331,7 +1386,7 @@ def denoising_value_valid(dnv):
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
 
@@ -1385,7 +1440,7 @@ def denoising_value_valid(dnv):
             prompt_embeds.dtype,
             device,
             generator,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
         )
 
         # 8. Check that sizes of mask, masked image and latents match
@@ -1446,7 +1501,7 @@ def denoising_value_valid(dnv):
             )
             add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1483,7 +1538,7 @@ def denoising_value_valid(dnv):
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
                 # concat latents, mask, masked_image_latents in the channel dimension
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1491,7 +1546,7 @@ def denoising_value_valid(dnv):
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1528,7 +1583,7 @@ def denoising_value_valid(dnv):
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1543,7 +1598,7 @@ def denoising_value_valid(dnv):
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     added_cond_kwargs=added_cond_kwargs,
@@ -1551,11 +1606,11 @@ def denoising_value_valid(dnv):
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
@@ -1564,7 +1619,7 @@ def denoising_value_valid(dnv):
                 if num_channels_unet == 4:
                     init_latents_proper = image_latents
-                    if do_classifier_free_guidance:
+                    if self.do_classifier_free_guidance:
                         init_mask, _ = mask.chunk(2)
                     else:
                         init_mask = mask
@@ -1577,6 +1632,16 @@ def denoising_value_valid(dnv):
 
                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
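Because `callback` and `callback_steps` were moved into `**kwargs` and routed through `deprecate`, old call sites keep working until 1.0.0 but emit a warning. A side-by-side sketch, with `pipe`, the prompt, and the image arguments assumed to be prepared for the SDXL inpaint pipeline above:

```python
# Old style: still invoked at the end of each step, but deprecated by this patch.
def old_style(step, timestep, latents):
    print("old:", step, timestep, latents.shape)


image = pipe(
    prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image,
    callback=old_style, callback_steps=1,  # triggers the deprecation warning
).images[0]


# New style: the pipeline itself is passed in, and tensors arrive in a dict.
def new_style(pipeline, step, timestep, callback_kwargs):
    print("new:", step, timestep, callback_kwargs["latents"].shape)
    return callback_kwargs


image = pipe(
    prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image,
    callback_on_step_end=new_style,
).images[0]
```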
"tokenizer_2", "text_encoder", "text_encoder_2"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -487,15 +495,21 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -825,6 +839,10 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32 def guidance_scale(self): return self._guidance_scale + @property + def clip_skip(self): + return self._clip_skip + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -832,6 +850,14 @@ def guidance_scale(self): def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -855,8 +881,6 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, @@ -869,6 +893,9 @@ def __call__( negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -937,12 +964,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -989,6 +1010,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -997,6 +1027,23 @@ def __call__( If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1026,9 +1073,12 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1052,7 +1102,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, @@ -1072,7 +1122,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # 4. Prepare image @@ -1115,6 +1165,7 @@ def __call__( # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) # 6. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -1254,7 +1305,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, - cross_attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, @@ -1269,6 +1320,16 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 3926eba33024d..4fccd6a91b0f9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -37,6 +37,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, + deprecate, logging, replace_example_docstring, scale_lora_layers, @@ -195,6 +196,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -543,6 +545,7 @@ def check_inputs( controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -553,14 +556,20 @@ def check_inputs( f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f" {type(num_inference_steps)}." ) - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
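Returned tensors overwrite the loop's local variables, which is what makes the callback useful for more than logging. A speculative sketch against the SDXL text-to-image pipeline above, perturbing the latents for the first few steps purely to demonstrate the write path; the step threshold and noise scale are arbitrary illustration values:

```python
import torch


def perturb_early_steps(pipeline, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    if step < 5:  # arbitrary cutoff, for illustration only
        latents = latents + 0.01 * torch.randn_like(latents)
    # Whatever is placed back into the dict replaces the pipeline's `latents`.
    callback_kwargs["latents"] = latents
    return callback_kwargs
```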
Please make sure to" @@ -951,6 +960,29 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -976,8 +1008,6 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.8, guess_mode: bool = False, @@ -992,6 +1022,9 @@ def __call__( aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -1077,12 +1110,6 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -1138,6 +1165,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. Examples: @@ -1146,6 +1182,23 @@ def __call__( [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` containing the output images. 
""" + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance @@ -1177,8 +1230,13 @@ def __call__( controlnet_conditioning_scale, control_guidance_start, control_guidance_end, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -1188,10 +1246,6 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) @@ -1205,7 +1259,7 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, @@ -1217,7 +1271,7 @@ def __call__( prompt_2, device, num_images_per_prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, @@ -1225,7 +1279,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, - clip_skip=clip_skip, + clip_skip=self.clip_skip, ) # 4. Prepare image and controlnet_conditioning_image @@ -1240,7 +1294,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = control_image.shape[-2:] @@ -1256,7 +1310,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -1271,6 +1325,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) # 6. 
         # 6. Prepare latent variables
         latents = self.prepare_latents(
@@ -1328,7 +1383,7 @@ def __call__(
             )
             add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1343,13 +1398,13 @@ def __call__(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1382,7 +1437,7 @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Inferred ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1394,7 +1449,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     added_cond_kwargs=added_cond_kwargs,
@@ -1402,13 +1457,23 @@ def __call__(
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
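The `StableDiffusionXLControlNetImg2ImgPipeline` changes follow the same pattern, with the extra wrinkle that `check_inputs` now rejects any key in `callback_on_step_end_tensor_inputs` that is not listed in `_callback_tensor_inputs`. A sketch, again with assumed model IDs, image URLs, and an arbitrary clamping threshold:

    import torch

    from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
    from diffusers.utils import load_image

    controlnet = ControlNetModel.from_pretrained(
        "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
    )
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
    ).to("cuda")

    init_image = load_image("https://example.com/init.png")      # assumed source image
    control_image = load_image("https://example.com/canny.png")  # assumed Canny control map


    def clamp_latents(pipeline, step, timestep, callback_kwargs):
        # Only keys validated against `_callback_tensor_inputs` may be requested;
        # the value returned under "latents" is popped back into the loop.
        callback_kwargs["latents"] = callback_kwargs["latents"].clamp(-4.0, 4.0)
        return callback_kwargs


    image = pipe(
        "a watercolor painting of a bird",
        image=init_image,
        control_image=control_image,
        strength=0.7,
        callback_on_step_end=clamp_latents,
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]

Because the loop pops `latents` (and `prompt_embeds` / `negative_prompt_embeds`) back out of the returned dict, whatever the callback returns under those keys is what the remaining denoising steps see.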
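Finally, the legacy `callback` / `callback_steps` arguments are still accepted via `**kwargs`, but each one now goes through `deprecate` before the run starts. Reusing `pipe`, `init_image`, and `control_image` from the sketch above, callers should observe something like the following (the warning category is my reading of diffusers' `deprecate` helper, not something stated in this patch):

    import warnings


    def legacy_callback(step, timestep, latents):
        pass  # old-style callback; still accepted for now


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        pipe(
            "a watercolor painting of a bird",
            image=init_image,
            control_image=control_image,
            strength=0.7,
            callback=legacy_callback,
            callback_steps=1,
        )

    # One deprecation warning per legacy argument is expected here.
    print([type(w.message).__name__ for w in caught])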