From 738c9869572c45619056556673cc3b732455fac7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Tolga=20Cang=C3=B6z?= <46008593+standardAI@users.noreply.github.com> Date: Sat, 24 Feb 2024 11:04:31 +0300 Subject: [PATCH] [`Refactor`] `StableDiffusionReferencePipeline` inheriting from `DiffusionPipeline` (#7071) Refactor StableDiffusionReferencePipeline to inherit from DiffusionPipeline rather than StableDiffusionPipeline --- .../community/stable_diffusion_reference.py | 694 +++++++++++++++++- 1 file changed, 675 insertions(+), 19 deletions(-) diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index 1b02831ad19c..7af404c25b41 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -1,16 +1,31 @@ # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280 +import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from diffusers import StableDiffusionPipeline +from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel +from diffusers.configuration_utils import FrozenDict, deprecate +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg -from diffusers.utils import PIL_INTERPOLATION, logging +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) from diffusers.utils.torch_utils import randn_tensor @@ -45,14 +60,182 @@ def torch_dfs(model: torch.nn.Module): + r""" + Performs a depth-first search on the given PyTorch model and returns a list of all its child modules. + + Args: + model (torch.nn.Module): The PyTorch model to perform the depth-first search on. + + Returns: + list: A list of all child modules of the given model. + """ result = [model] for child in model.children(): result += torch_dfs(child) return result -class StableDiffusionReferencePipeline(StableDiffusionPipeline): - def _default_height_width(self, height, width, image): +class StableDiffusionReferencePipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): + r""" " + Pipeline for Stable Diffusion Reference. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
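As a quick illustration of how `torch_dfs` above behaves (the pipeline uses it to collect every `BasicTransformerBlock` inside the UNet for the reference hooks), here is a minimal, self-contained sketch; the toy module tree is assumed purely for demonstration:

import torch

def torch_dfs(model: torch.nn.Module):
    # Mirrors the helper in this file: the module itself plus all descendants, depth-first.
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result

toy = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(8, 4)),
)
print(len(torch_dfs(toy)))  # 5: outer Sequential, Linear, inner Sequential, ReLU, Linear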
+ + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. 
If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate( + "skip_prk_steps not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 4: + logger.warning( + f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," + f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`," + ". If you did not intend to modify" + " this behavior, please check whether you have loaded the right checkpoint." 
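The constructor normalizes outdated scheduler configs in place (`steps_offset`, `skip_prk_steps`) by rebuilding the frozen config dict; a condensed sketch of that pattern, using a hypothetical scheduler checkpoint, might look like:

from diffusers import DDIMScheduler
from diffusers.configuration_utils import FrozenDict

# Hypothetical checkpoint; any scheduler with an old-style config would do.
scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

if getattr(scheduler.config, "steps_offset", 1) != 1:
    new_config = dict(scheduler.config)
    new_config["steps_offset"] = 1
    # Same mechanism as above: replace the frozen config wholesale.
    scheduler._internal_dict = FrozenDict(new_config)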
+ ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _default_height_width( + self, + height: Optional[int], + width: Optional[int], + image: Union[PIL.Image.Image, torch.Tensor, List[PIL.Image.Image]], + ) -> Tuple[int, int]: + r""" + Calculate the default height and width for the given image. + + Args: + height (int or None): The desired height of the image. If None, the height will be determined based on the input image. + width (int or None): The desired width of the image. If None, the width will be determined based on the input image. + image (PIL.Image.Image or torch.Tensor or list[PIL.Image.Image]): The input image or a list of images. + + Returns: + Tuple[int, int]: A tuple containing the calculated height and width. + + """ # NOTE: It is possible that a list of images have different # dimensions for each image, so just checking the first image # is not _exactly_ correct, but it is simple. @@ -77,18 +260,430 @@ def _default_height_width(self, height, width, image): return height, width + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt: Optional[Union[str, List[str]]], + height: int, + width: int, + callback_steps: Optional[int], + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[torch.Tensor] = None, + ip_adapter_image_embeds: Optional[torch.FloatTensor] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + ) -> None: + """ + Check the validity of the input arguments for the diffusion model. + + Args: + prompt (Optional[Union[str, List[str]]]): The prompt text or list of prompt texts. + height (int): The height of the input image. + width (int): The width of the input image. + callback_steps (Optional[int]): The number of steps to perform the callback on. + negative_prompt (Optional[str]): The negative prompt text. + prompt_embeds (Optional[torch.FloatTensor]): The prompt embeddings. + negative_prompt_embeds (Optional[torch.FloatTensor]): The negative prompt embeddings. + ip_adapter_image (Optional[torch.Tensor]): The input adapter image. + ip_adapter_image_embeds (Optional[torch.FloatTensor]): The input adapter image embeddings. + callback_on_step_end_tensor_inputs (Optional[List[str]]): The list of tensor inputs to perform the callback on. + + Raises: + ValueError: If `height` or `width` is not divisible by 8. + ValueError: If `callback_steps` is not a positive integer. + ValueError: If `callback_on_step_end_tensor_inputs` contains invalid tensor inputs. + ValueError: If both `prompt` and `prompt_embeds` are provided. + ValueError: If neither `prompt` nor `prompt_embeds` are provided. + ValueError: If `prompt` is not of type `str` or `list`. + ValueError: If both `negative_prompt` and `negative_prompt_embeds` are provided. + ValueError: If both `prompt_embeds` and `negative_prompt_embeds` are provided and have different shapes. + ValueError: If both `ip_adapter_image` and `ip_adapter_image_embeds` are provided. 
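For a standard SD 1.x VAE the scale factor registered above works out to 8, and `_default_height_width` rounds the reference image dimensions down to a multiple of 8, which `check_inputs` then enforces; a small arithmetic sketch with made-up dimensions:

# Typical SD 1.x VAE config: 4 entries in block_out_channels -> 2 ** 3 = 8.
vae_scale_factor = 2 ** (len([128, 256, 512, 512]) - 1)

# A hypothetical 770x513 reference image, rounded down to multiples of 8.
width, height = 770, 513
height = (height // 8) * 8   # 512
width = (width // 8) * 8     # 768

assert height % 8 == 0 and width % 8 == 0
print(vae_scale_factor, height, width)  # 8 512 768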
+ + Returns: + None + """ + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ) -> torch.FloatTensor: + r""" + Encodes the prompt into embeddings. + + Args: + prompt (Union[str, List[str]]): The prompt text or a list of prompt texts. + device (torch.device): The device to use for encoding. + num_images_per_prompt (int): The number of images per prompt. + do_classifier_free_guidance (bool): Whether to use classifier-free guidance. + negative_prompt (Optional[Union[str, List[str]]], optional): The negative prompt text or a list of negative prompt texts. Defaults to None. + prompt_embeds (Optional[torch.FloatTensor], optional): The prompt embeddings. Defaults to None. + negative_prompt_embeds (Optional[torch.FloatTensor], optional): The negative prompt embeddings. Defaults to None. + lora_scale (Optional[float], optional): The LoRA scale. Defaults to None. 
+ **kwargs: Additional keyword arguments. + + Returns: + torch.FloatTensor: The encoded prompt embeddings. + """ + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt: Optional[str], + device: torch.device, + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ) -> torch.FloatTensor: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
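`encode_prompt` returns a `(prompt_embeds, negative_prompt_embeds)` tuple rather than the concatenated tensor the deprecated `_encode_prompt` produced; a hedged sketch of how a caller typically assembles the classifier-free-guidance batch from it (`pipe` is assumed to be an already-loaded pipeline on CUDA):

import torch

prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a photo of an astronaut riding a horse",
    device=torch.device("cuda"),
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)

# Unconditional embeddings go first, matching the concat order used above
# for backwards compatibility and in __call__.
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])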
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Union[torch.Generator, List[torch.Generator]], + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Prepare the latent vectors for diffusion. + + Args: + batch_size (int): The number of samples in the batch. + num_channels_latents (int): The number of channels in the latent vectors. 
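The mps-friendly duplication above tiles the embeddings with `repeat` and folds the copies back into the batch dimension with `view`, instead of calling `repeat_interleave`; a standalone shape check with dummy sizes:

import torch

batch_size, seq_len, dim = 2, 77, 768
num_images_per_prompt = 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# Same pattern as in encode_prompt: each prompt's embedding ends up repeated
# num_images_per_prompt times, consecutively along the batch axis.
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, dim)

print(prompt_embeds.shape)  # torch.Size([6, 77, 768])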
+ height (int): The height of the latent vectors. + width (int): The width of the latent vectors. + dtype (torch.dtype): The data type of the latent vectors. + device (torch.device): The device to place the latent vectors on. + generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation. + latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated. + + Returns: + torch.Tensor: The prepared latent vectors. + """ + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs( + self, generator: Union[torch.Generator, List[torch.Generator]], eta: float + ) -> Dict[str, Any]: + r""" + Prepare extra keyword arguments for the scheduler step. + + Args: + generator (Union[torch.Generator, List[torch.Generator]]): The generator used for sampling. + eta (float): The value of eta (η) used with the DDIMScheduler. Should be between 0 and 1. + + Returns: + Dict[str, Any]: A dictionary containing the extra keyword arguments for the scheduler step. + """ + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + def prepare_image( self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): + image: Union[torch.Tensor, PIL.Image.Image, List[Union[torch.Tensor, PIL.Image.Image]]], + width: int, + height: int, + batch_size: int, + num_images_per_prompt: int, + device: torch.device, + dtype: torch.dtype, + do_classifier_free_guidance: bool = False, + guess_mode: bool = False, + ) -> torch.Tensor: + r""" + Prepares the input image for processing. + + Args: + image (torch.Tensor or PIL.Image.Image or list): The input image(s). + width (int): The desired width of the image. + height (int): The desired height of the image. + batch_size (int): The batch size for processing. + num_images_per_prompt (int): The number of images per prompt. + device (torch.device): The device to use for processing. + dtype (torch.dtype): The data type of the image. 
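`prepare_extra_step_kwargs` only forwards `eta` and `generator` when the scheduler's `step` signature actually accepts them; the same introspection can be exercised in isolation with a stand-in function:

import inspect

def step(model_output, timestep, sample, generator=None):
    # Stand-in for a scheduler.step that accepts `generator` but not `eta`.
    return sample

accepted = set(inspect.signature(step).parameters.keys())
extra_step_kwargs = {}
if "eta" in accepted:
    extra_step_kwargs["eta"] = 0.0
if "generator" in accepted:
    extra_step_kwargs["generator"] = None

print(extra_step_kwargs)  # {'generator': None}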
+ do_classifier_free_guidance (bool, optional): Whether to perform classifier-free guidance. Defaults to False. + guess_mode (bool, optional): Whether to use guess mode. Defaults to False. + + Returns: + torch.Tensor: The prepared image for processing. + """ if not isinstance(image, torch.Tensor): if isinstance(image, PIL.Image.Image): image = [image] @@ -130,7 +725,29 @@ def prepare_image( return image - def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): + def prepare_ref_latents( + self, + refimage: torch.Tensor, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + generator: Union[int, List[int]], + do_classifier_free_guidance: bool, + ) -> torch.Tensor: + r""" + Prepares reference latents for generating images. + + Args: + refimage (torch.Tensor): The reference image. + batch_size (int): The desired batch size. + dtype (torch.dtype): The data type of the tensors. + device (torch.device): The device to perform computations on. + generator (int or list): The generator index or a list of generator indices. + do_classifier_free_guidance (bool): Whether to use classifier-free guidance. + + Returns: + torch.Tensor: The prepared reference latents. + """ refimage = refimage.to(device=device, dtype=dtype) # encode the mask image into latents space so we can concatenate it to the latents @@ -158,6 +775,35 @@ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) return ref_image_latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker( + self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype + ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: + r""" + Runs the safety checker on the given image. + + Args: + image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked. + device (torch.device): The device to run the safety checker on. + dtype (torch.dtype): The data type of the input image. + + Returns: + (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and + a boolean indicating whether the image has a NSFW (Not Safe for Work) concept. 
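`prepare_image` tiles the reference image to one copy per generated sample and, when classifier-free guidance is on and `guess_mode` is off, doubles the batch so the unconditional pass also sees the reference; a shape-only sketch with assumed sizes:

import torch

effective_batch = 4                   # batch_size * num_images_per_prompt
image = torch.randn(1, 3, 512, 512)   # one preprocessed reference image

# One copy of the reference per generated sample.
image = image.repeat_interleave(effective_batch, dim=0)

# Classifier-free guidance doubles the batch (skipped in guess mode).
image = torch.cat([image] * 2)
print(image.shape)  # torch.Size([8, 3, 512, 512])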
+ """ + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + @torch.no_grad() def __call__( self, @@ -538,7 +1184,12 @@ def hack_CrossAttnDownBlock2D_forward( return hidden_states, output_states - def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs): + def hacked_DownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + **kwargs: Any, + ) -> Tuple[torch.FloatTensor, ...]: eps = 1e-6 output_states = () @@ -588,7 +1239,7 @@ def hacked_CrossAttnUpBlock2D_forward( upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ): + ) -> torch.FloatTensor: eps = 1e-6 # TODO(Patrick, William) - attention mask is not used for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): @@ -635,8 +1286,13 @@ def hacked_CrossAttnUpBlock2D_forward( return hidden_states def hacked_UpBlock2D_forward( - self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs - ): + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + **kwargs: Any, + ) -> torch.FloatTensor: eps = 1e-6 for i, resnet in enumerate(self.resnets): # pop res hidden states