diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9022ca669f3f..af3922f5f0c5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -318,6 +318,8 @@ title: Semantic Guidance - local: api/pipelines/shap_e title: Shap-E + - local: api/pipelines/stable_cascade + title: Stable Cascade - sections: - local: api/pipelines/stable_diffusion/overview title: Overview diff --git a/docs/source/en/api/pipelines/stable_cascade.md b/docs/source/en/api/pipelines/stable_cascade.md new file mode 100644 index 000000000000..37df68bb03fd --- /dev/null +++ b/docs/source/en/api/pipelines/stable_cascade.md @@ -0,0 +1,88 @@ + + +# Stable Cascade + +This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main +difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this +important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes. +How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being +encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a +1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the +highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable +Diffusion 1.5. + +Therefore, this kind of model is well suited for usages where efficiency is important. Furthermore, all known extensions +like finetuning, LoRA, ControlNet, IP-Adapter, LCM etc. are possible with this method as well. + +The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade). + +## Model Overview +Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images, +hence the name "Stable Cascade". + +Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion. +However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a +spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves +a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the +image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible +for generating the small 24 x 24 latents given a text prompt. + +## Uses + +### Direct Use + +The model is intended for research purposes for now. Possible research areas and tasks include + +- Research on generative models. +- Safe deployment of models which have the potential to generate harmful content. +- Probing and understanding the limitations and biases of generative models. +- Generation of artworks and use in design and other artistic processes. +- Applications in educational or creative tools. + +Excluded uses are described below. + +### Out-of-Scope Use + +The model was not trained to be factual or true representations of people or events, +and therefore using the model to generate such content is out-of-scope for the abilities of this model. +The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy). + +## Limitations and Bias + +### Limitations +- Faces and people in general may not be generated properly. +- The autoencoding part of the model is lossy. + + +## StableCascadeCombinedPipeline + +[[autodoc]] StableCascadeCombinedPipeline + - all + - __call__ + +## StableCascadePriorPipeline + +[[autodoc]] StableCascadePriorPipeline + - all + - __call__ + +## StableCascadePriorPipelineOutput + +[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput + +## StableCascadeDecoderPipeline + +[[autodoc]] StableCascadeDecoderPipeline + - all + - __call__ + diff --git a/scripts/convert_stable_cascade.py b/scripts/convert_stable_cascade.py new file mode 100644 index 000000000000..704a1f4e31a6 --- /dev/null +++ b/scripts/convert_stable_cascade.py @@ -0,0 +1,215 @@ +# Run this script to convert the Stable Cascade model weights to a diffusers pipeline. +import argparse + +import accelerate +import torch +from safetensors.torch import load_file +from transformers import ( + AutoTokenizer, + CLIPConfig, + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPVisionModelWithProjection, +) + +from diffusers import ( + DDPMWuerstchenScheduler, + StableCascadeCombinedPipeline, + StableCascadeDecoderPipeline, + StableCascadePriorPipeline, +) +from diffusers.models import StableCascadeUNet +from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.pipelines.wuerstchen import PaellaVQModel + + +parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline") +parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights") +parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file") +parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file") +parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion") +parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to") +parser.add_argument("--push_to_hub", action="store_true", help="Push to hub") + +args = parser.parse_args() +model_path = args.model_path + +device = "cpu" + +# set paths to model weights +prior_checkpoint_path = f"{model_path}/{args.stage_c_name}" +decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}" + +# Clip Text encoder and tokenizer +config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") +config.text_config.projection_dim = config.projection_dim +text_encoder = CLIPTextModelWithProjection.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config +) +tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + +# image processor +feature_extractor = CLIPImageProcessor() +image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + +# Prior +if args.use_safetensors: + orig_state_dict = load_file(prior_checkpoint_path, device=device) +else: + orig_state_dict = torch.load(prior_checkpoint_path, map_location=device) + +state_dict = {} +for key in orig_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + else: + state_dict[key] = orig_state_dict[key] + + +with accelerate.init_empty_weights(): + prior_model = StableCascadeUNet( + in_channels=16, + out_channels=16, + timestep_ratio_embedding_dim=64, + patch_size=1, + conditioning_dim=2048, + block_out_channels=[2048, 2048], + num_attention_heads=[32, 32], + down_num_layers_per_block=[8, 24], + up_num_layers_per_block=[24, 8], + down_blocks_repeat_mappers=[1, 1], + up_blocks_repeat_mappers=[1, 1], + block_types_per_layer=[ + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ], + clip_text_in_channels=1280, + clip_text_pooled_in_channels=1280, + clip_image_in_channels=768, + clip_seq=4, + kernel_size=3, + dropout=[0.1, 0.1], + self_attn=True, + timestep_conditioning_type=["sca", "crp"], + switch_level=[False], + ) +load_model_dict_into_meta(prior_model, state_dict) + +# scheduler for prior and decoder +scheduler = DDPMWuerstchenScheduler() + +# Prior pipeline +prior_pipeline = StableCascadePriorPipeline( + prior=prior_model, + tokenizer=tokenizer, + text_encoder=text_encoder, + image_encoder=image_encoder, + scheduler=scheduler, + feature_extractor=feature_extractor, +) +prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub) + +# Decoder +if args.use_safetensors: + orig_state_dict = load_file(decoder_checkpoint_path, device=device) +else: + orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device) + +state_dict = {} +for key in orig_state_dict.keys(): + if key.endswith("in_proj_weight"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] + state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] + state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] + elif key.endswith("in_proj_bias"): + weights = orig_state_dict[key].chunk(3, 0) + state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] + state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] + state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] + elif key.endswith("out_proj.weight"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights + elif key.endswith("out_proj.bias"): + weights = orig_state_dict[key] + state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights + # rename clip_mapper to clip_txt_pooled_mapper + elif key.endswith("clip_mapper.weight"): + weights = orig_state_dict[key] + state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights + elif key.endswith("clip_mapper.bias"): + weights = orig_state_dict[key] + state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights + else: + state_dict[key] = orig_state_dict[key] + +with accelerate.init_empty_weights(): + decoder = StableCascadeUNet( + in_channels=4, + out_channels=4, + timestep_ratio_embedding_dim=64, + patch_size=2, + conditioning_dim=1280, + block_out_channels=[320, 640, 1280, 1280], + down_num_layers_per_block=[2, 6, 28, 6], + up_num_layers_per_block=[6, 28, 6, 2], + down_blocks_repeat_mappers=[1, 1, 1, 1], + up_blocks_repeat_mappers=[3, 3, 2, 2], + num_attention_heads=[0, 0, 20, 20], + block_types_per_layer=[ + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ], + clip_text_pooled_in_channels=1280, + clip_seq=4, + effnet_in_channels=16, + pixel_mapper_in_channels=3, + kernel_size=3, + dropout=[0, 0, 0.1, 0.1], + self_attn=True, + timestep_conditioning_type=["sca"], + ) +load_model_dict_into_meta(decoder, state_dict) + +# VQGAN from Wuerstchen-V2 +vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan") + +# Decoder pipeline +decoder_pipeline = StableCascadeDecoderPipeline( + decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler +) +decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub) + +# Stable Cascade combined pipeline +stable_cascade_pipeline = StableCascadeCombinedPipeline( + # Decoder + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqmodel, + # Prior + prior_text_encoder=text_encoder, + prior_tokenizer=tokenizer, + prior_prior=prior_model, + prior_scheduler=scheduler, + prior_image_encoder=image_encoder, + prior_feature_extractor=feature_extractor, +) +stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 165bbf263b29..05dae1d38364 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -86,6 +86,7 @@ "MotionAdapter", "MultiAdapter", "PriorTransformer", + "StableCascadeUNet", "T2IAdapter", "T5FilmDecoder", "Transformer2DModel", @@ -258,6 +259,9 @@ "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", + "StableCascadeCombinedPipeline", + "StableCascadeDecoderPipeline", + "StableCascadePriorPipeline", "StableDiffusionAdapterPipeline", "StableDiffusionAttendAndExcitePipeline", "StableDiffusionControlNetImg2ImgPipeline", @@ -624,6 +628,9 @@ SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, + StableCascadeCombinedPipeline, + StableCascadeDecoderPipeline, + StableCascadePriorPipeline, StableDiffusionAdapterPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4513d23adf52..da77e4450c86 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -47,6 +47,7 @@ _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] + _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"] _import_structure["unets.uvit_2d"] = ["UVit2DModel"] _import_structure["vq_model"] = ["VQModel"] @@ -80,6 +81,7 @@ I2VGenXLUNet, Kandinsky3UNet, MotionAdapter, + StableCascadeUNet, UNet1DModel, UNet2DConditionModel, UNet2DModel, diff --git a/src/diffusers/models/unets/__init__.py b/src/diffusers/models/unets/__init__.py index d505fc11a4ff..9ef04fb62606 100644 --- a/src/diffusers/models/unets/__init__.py +++ b/src/diffusers/models/unets/__init__.py @@ -10,6 +10,7 @@ from .unet_kandinsky3 import Kandinsky3UNet from .unet_motion_model import MotionAdapter, UNetMotionModel from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + from .unet_stable_cascade import StableCascadeUNet from .uvit_2d import UVit2DModel diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py new file mode 100644 index 000000000000..b8833779ba9f --- /dev/null +++ b/src/diffusers/models/unets/unet_stable_cascade.py @@ -0,0 +1,609 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..attention_processor import Attention +from ..modeling_utils import ModelMixin + + +# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm +class SDCascadeLayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + return x.permute(0, 3, 1, 2) + + +class SDCascadeTimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=[]): + super().__init__() + linear_cls = nn.Linear + self.mapper = linear_cls(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b + + +class SDCascadeResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class SDCascadeAttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + linear_cls = nn.Linear + + self.self_attn = self_attn + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) + self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) + return x + + +class UpDownBlock2d(nn.Module): + def __init__(self, in_channels, out_channels, mode, enabled=True): + super().__init__() + if mode not in ["up", "down"]: + raise ValueError(f"{mode} not supported") + interpolation = ( + nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) + if enabled + else nn.Identity() + ) + mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +@dataclass +class StableCascadeUNetOutput(BaseOutput): + sample: torch.FloatTensor = None + + +class StableCascadeUNet(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + timestep_ratio_embedding_dim: int = 64, + patch_size: int = 1, + conditioning_dim: int = 2048, + block_out_channels: Tuple[int] = (2048, 2048), + num_attention_heads: Tuple[int] = (32, 32), + down_num_layers_per_block: Tuple[int] = (8, 24), + up_num_layers_per_block: Tuple[int] = (24, 8), + down_blocks_repeat_mappers: Optional[Tuple[int]] = ( + 1, + 1, + ), + up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1), + block_types_per_layer: Tuple[Tuple[str]] = ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ), + clip_text_in_channels: Optional[int] = None, + clip_text_pooled_in_channels=1280, + clip_image_in_channels: Optional[int] = None, + clip_seq=4, + effnet_in_channels: Optional[int] = None, + pixel_mapper_in_channels: Optional[int] = None, + kernel_size=3, + dropout: Union[float, Tuple[float]] = (0.1, 0.1), + self_attn: Union[bool, Tuple[bool]] = True, + timestep_conditioning_type: Tuple[str] = ("sca", "crp"), + switch_level: Optional[Tuple[bool]] = None, + ): + """ + + Parameters: + in_channels (`int`, defaults to 16): + Number of channels in the input sample. + out_channels (`int`, defaults to 16): + Number of channels in the output sample. + timestep_ratio_embedding_dim (`int`, defaults to 64): + Dimension of the projected time embedding. + patch_size (`int`, defaults to 1): + Patch size to use for pixel unshuffling layer + conditioning_dim (`int`, defaults to 2048): + Dimension of the image and text conditional embedding. + block_out_channels (Tuple[int], defaults to (2048, 2048)): + Tuple of output channels for each block. + num_attention_heads (Tuple[int], defaults to (32, 32)): + Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention. + down_num_layers_per_block (Tuple[int], defaults to [8, 24]): + Number of layers in each down block. + up_num_layers_per_block (Tuple[int], defaults to [24, 8]): + Number of layers in each up block. + down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each down block. + up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each up block. + block_types_per_layer (Tuple[Tuple[str]], optional, + defaults to ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock") + ): + Block types used in each layer of the up/down blocks. + clip_text_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for CLIP based text conditioning. + clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280): + Number of input channels for pooled CLIP text embeddings. + clip_image_in_channels (`int`, *optional*): + Number of input channels for CLIP based image conditioning. + clip_seq (`int`, *optional*, defaults to 4): + effnet_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for effnet conditioning. + pixel_mapper_in_channels (`int`, defaults to `None`): + Number of input channels for pixel mapper conditioning. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size to use in the block convolutional layers. + dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)): + Dropout to use per block. + self_attn (Union[bool, Tuple[bool]]): + Tuple of booleans that determine whether to use self attention in a block or not. + timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")): + Timestep conditioning type. + switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`): + Tuple that indicates whether upsampling or downsampling should be applied in a block + """ + + super().__init__() + + if len(block_out_channels) != len(down_num_layers_per_block): + raise ValueError( + f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_num_layers_per_block): + raise ValueError( + f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(down_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(block_types_per_layer): + raise ValueError( + f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + if isinstance(dropout, float): + dropout = (dropout,) * len(block_out_channels) + if isinstance(self_attn, bool): + self_attn = (self_attn,) * len(block_out_channels) + + # CONDITIONING + if effnet_in_channels is not None: + self.effnet_mapper = nn.Sequential( + nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + if pixel_mapper_in_channels is not None: + self.pixels_mapper = nn.Sequential( + nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + + self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq) + if clip_text_in_channels is not None: + self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim) + if clip_image_in_channels is not None: + self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq) + self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == "SDCascadeResBlock": + return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "SDCascadeAttnBlock": + return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "SDCascadeTimestepBlock": + return SDCascadeTimestepBlock( + in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type + ) + else: + raise ValueError(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(block_out_channels)): + if i > 0: + self.down_downscalers.append( + nn.Sequential( + SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2), + ) + ) + else: + self.down_downscalers.append(nn.Identity()) + + down_block = nn.ModuleList() + for _ in range(down_num_layers_per_block[i]): + for block_type in block_types_per_layer[i]: + block = get_block( + block_type, + block_out_channels[i], + num_attention_heads[i], + dropout=dropout[i], + self_attn=self_attn[i], + ) + down_block.append(block) + self.down_blocks.append(down_block) + + if down_blocks_repeat_mappers is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(down_blocks_repeat_mappers[i] - 1): + block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(block_out_channels))): + if i > 0: + self.up_upscalers.append( + nn.Sequential( + SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.ConvTranspose2d( + block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2 + ), + ) + ) + else: + self.up_upscalers.append(nn.Identity()) + + up_block = nn.ModuleList() + for j in range(up_num_layers_per_block[::-1][i]): + for k, block_type in enumerate(block_types_per_layer[i]): + c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0 + block = get_block( + block_type, + block_out_channels[i], + num_attention_heads[i], + c_skip=c_skip, + dropout=dropout[i], + self_attn=self_attn[i], + ) + up_block.append(block) + self.up_blocks.append(up_block) + + if up_blocks_repeat_mappers is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(up_blocks_repeat_mappers[::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(block_out_channels[0], out_channels * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, value=False): + self.gradient_checkpointing = value + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None + + if hasattr(self, "effnet_mapper"): + nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + + if hasattr(self, "pixels_mapper"): + nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, SDCascadeResBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0])) + elif isinstance(block, SDCascadeTimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000): + r = timestep_ratio * max_positions + half_dim = self.config.timestep_ratio_embedding_dim // 2 + + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + + if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + + return emb.to(dtype=r.dtype) + + def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None): + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( + clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1 + ) + if clip_txt is not None and clip_img is not None: + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_img = self.clip_img_mapper(clip_img).view( + clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1 + ) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + else: + clip = clip_txt_pool + return self.clip_norm(clip) + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, SDCascadeResBlock): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + elif isinstance(block, SDCascadeAttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, clip, use_reentrant=False + ) + elif isinstance(block, SDCascadeTimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed, use_reentrant=False + ) + else: + x = x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), use_reentrant=False + ) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + else: + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, SDCascadeResBlock): + x = block(x) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, SDCascadeResBlock): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate( + x.float(), skip.shape[-2:], mode="bilinear", align_corners=True + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, skip, use_reentrant=False + ) + elif isinstance(block, SDCascadeAttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, clip, use_reentrant=False + ) + elif isinstance(block, SDCascadeTimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed, use_reentrant=False + ) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + else: + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, SDCascadeResBlock): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate( + x.float(), skip.shape[-2:], mode="bilinear", align_corners=True + ) + x = block(x, skip) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward( + self, + sample, + timestep_ratio, + clip_text_pooled, + clip_text=None, + clip_img=None, + effnet=None, + pixels=None, + sca=None, + crp=None, + return_dict=True, + ): + if pixels is None: + pixels = sample.new_zeros(sample.size(0), 3, 8, 8) + + # Process the conditioning embeddings + timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio) + for c in self.config.timestep_conditioning_type: + if c == "sca": + cond = sca + elif c == "crp": + cond = crp + else: + cond = None + t_cond = cond or torch.zeros_like(timestep_ratio) + timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1) + clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img) + + # Model Blocks + x = self.embedding(sample) + if hasattr(self, "effnet_mapper") and effnet is not None: + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True) + ) + if hasattr(self, "pixels_mapper"): + x = x + nn.functional.interpolate( + self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True + ) + level_outputs = self._down_encode(x, timestep_ratio_embed, clip) + x = self._up_decode(level_outputs, timestep_ratio_embed, clip) + sample = self.clf(x) + + if not return_dict: + return (sample,) + return StableCascadeUNetOutput(sample=sample) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a1840201f8ba..d7aacc4d6714 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -176,6 +176,11 @@ _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] + _import_structure["stable_cascade"] = [ + "StableCascadeCombinedPipeline", + "StableCascadeDecoderPipeline", + "StableCascadePriorPipeline", + ] _import_structure["stable_diffusion"].extend( [ "CLIPImageProjection", @@ -424,6 +429,11 @@ from .pixart_alpha import PixArtAlphaPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline + from .stable_cascade import ( + StableCascadeCombinedPipeline, + StableCascadeDecoderPipeline, + StableCascadePriorPipeline, + ) from .stable_diffusion import ( CLIPImageProjection, StableDiffusionDepth2ImgPipeline, diff --git a/src/diffusers/pipelines/stable_cascade/__init__.py b/src/diffusers/pipelines/stable_cascade/__init__.py new file mode 100644 index 000000000000..5270cb94af01 --- /dev/null +++ b/src/diffusers/pipelines/stable_cascade/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"] + _import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"] + _import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_cascade import StableCascadeDecoderPipeline + from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline + from .pipeline_stable_cascade_prior import StableCascadePriorPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py new file mode 100644 index 000000000000..583ef3ee02c2 --- /dev/null +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -0,0 +1,465 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline + + >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( + ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 + ... ).to("cuda") + >>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain( + ... "stabilityai/stable-cascade", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) + ``` +""" + + +class StableCascadeDecoderPipeline(DiffusionPipeline): + """ + Pipeline for generating images from the Stable Cascade model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The CLIP tokenizer. + text_encoder (`CLIPTextModel`): + The CLIP text encoder. + decoder ([`StableCascadeUNet`]): + The Stable Cascade decoder unet. + vqgan ([`PaellaVQModel`]): + The VQGAN model. + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_dim_scale (float, `optional`, defaults to 10.67): + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. + """ + + unet_name = "decoder" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "text_encoder->decoder->vqgan" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_pooled", + "negative_prompt_embeds", + "image_embeddings", + ] + + def __init__( + self, + decoder: StableCascadeUNet, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + latent_dim_scale: float = 10.67, + ) -> None: + super().__init__() + self.register_modules( + decoder=decoder, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=scheduler, + vqgan=vqgan, + ) + self.register_to_config(latent_dim_scale=latent_dim_scale) + + def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler): + batch_size, channels, height, width = image_embeddings.shape + latents_shape = ( + batch_size * num_images_per_prompt, + 4, + int(height * self.config.latent_dim_scale), + int(width * self.config.latent_dim_scale), + ) + + if latents is None: + latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_pooled: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None, + ): + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True + ) + prompt_embeds = text_encoder_output.hidden_states[-1] + if prompt_embeds_pooled is None: + prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_input.attention_mask.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] + negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + seq_len = negative_prompt_embeds_pooled.shape[1] + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( + dtype=self.text_encoder.dtype, device=device + ) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled + + def check_inputs( + self, + prompt, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]], + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 10, + guidance_scale: float = 0.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embedding (`torch.FloatTensor` or `List[torch.FloatTensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + # 0. Define commonly used variables + device = self._execution_device + dtype = self.decoder.dtype + self._guidance_scale = guidance_scale + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + if isinstance(image_embeddings, list): + image_embeddings = torch.cat(image_embeddings, dim=0) + batch_size = image_embeddings.shape[0] + + # 2. Encode caption + if prompt_embeds is None and negative_prompt_embeds is None: + prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt( + prompt=prompt, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + prompt_embeds_pooled = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds + ) + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents( + image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler + ) + + # 6. Run denoising loop + self._num_timesteps = len(timesteps[:-1]) + for i, t in enumerate(self.progress_bar(timesteps[:-1])): + timestep_ratio = t.expand(latents.size(0)).to(dtype) + + # 7. Denoise latents + predicted_latents = self.decoder( + sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, + clip_text_pooled=prompt_embeds_pooled, + effnet=effnet, + return_dict=False, + )[0] + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) + predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) + + # 9. Renoise latents to next timestep + latents = self.scheduler.step( + model_output=predicted_latents, + timestep=timestep_ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + # 10. Scale and decode the image latents with vq-vae + latents = self.vqgan.config.scale_factor * latents + images = self.vqgan.decode(latents).sample.clamp(0, 1) + if output_type == "np": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + elif output_type == "pil": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + images = self.numpy_to_pil(images) + else: + images = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return images + return ImagePipelineOutput(images) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py new file mode 100644 index 000000000000..cd3592b49ac0 --- /dev/null +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -0,0 +1,294 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, List, Optional, Union + +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel +from .pipeline_stable_cascade import StableCascadeDecoderPipeline +from .pipeline_stable_cascade_prior import StableCascadePriorPipeline + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusions import StableCascadeCombinedPipeline + + >>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to( + ... "cuda" + ... ) + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> images = pipe(prompt=prompt) + ``` +""" + + +class StableCascadeCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Stable Cascade. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The decoder tokenizer to be used for text inputs. + text_encoder (`CLIPTextModel`): + The decoder text encoder to be used for text inputs. + decoder (`StableCascadeUNet`): + The decoder model to be used for decoder image generation pipeline. + scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for decoder image generation pipeline. + vqgan (`PaellaVQModel`): + The VQGAN model to be used for decoder image generation pipeline. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + prior_prior (`StableCascadeUNet`): + The prior model to be used for prior pipeline. + prior_scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for prior pipeline. + """ + + _load_connected_pipes = True + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: StableCascadeUNet, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + prior_prior: StableCascadeUNet, + prior_text_encoder: CLIPTextModel, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: DDPMWuerstchenScheduler, + prior_feature_extractor: Optional[CLIPImageProcessor] = None, + prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_prior=prior_prior, + prior_scheduler=prior_scheduler, + prior_feature_extractor=prior_feature_extractor, + prior_image_encoder=prior_image_encoder, + ) + self.prior_pipe = StableCascadePriorPipeline( + prior=prior_prior, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_encoder=prior_image_encoder, + feature_extractor=prior_feature_extractor, + ) + self.decoder_pipe = StableCascadeDecoderPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, + height: int = 512, + width: int = 512, + prior_num_inference_steps: int = 60, + prior_timesteps: Optional[List[float]] = None, + prior_guidance_scale: float = 4.0, + num_inference_steps: int = 12, + decoder_timesteps: Optional[List[float]] = None, + decoder_guidance_scale: float = 0.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation for the prior and decoder. + images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*): + The images to guide the image generation for the prior. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked + to the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60): + The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. For more specific timestep spacing, you can pass customized + `prior_timesteps` + num_inference_steps (`int`, *optional*, defaults to 12): + The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. For more specific timestep spacing, you can pass customized + `timesteps` + decoder_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeine class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + prior_outputs = self.prior_pipe( + prompt=prompt if prompt_embeds is None else None, + images=images, + height=height, + width=width, + num_inference_steps=prior_num_inference_steps, + guidance_scale=prior_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + output_type="pt", + return_dict=True, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + ) + image_embeddings = prior_outputs.image_embeddings + prompt_embeds = prior_outputs.get("prompt_embeds", None) + negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None) + + outputs = self.decoder_pipe( + image_embeddings=image_embeddings, + prompt=prompt if prompt_embeds is None else None, + num_inference_steps=num_inference_steps, + guidance_scale=decoder_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + return outputs diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py new file mode 100644 index 000000000000..8df0d2a991b5 --- /dev/null +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -0,0 +1,614 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import ceil +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadePriorPipeline + + >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( + ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + ``` +""" + + +@dataclass +class StableCascadePriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeddings (`torch.FloatTensor` or `np.ndarray`) + Prior image embeddings for text prompt + prompt_embeds (`torch.FloatTensor`): + Text embeddings for the prompt. + negative_prompt_embeds (`torch.FloatTensor`): + Text embeddings for the negative prompt. + """ + + image_embeddings: Union[torch.FloatTensor, np.ndarray] + prompt_embeds: Union[torch.FloatTensor, np.ndarray] + negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray] + + +class StableCascadePriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Stable Cascade. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`StableCascadeUNet`]): + The Stable Cascade prior to approximate the image embedding from the text and/or image embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + resolution_multiple ('float', *optional*, defaults to 42.67): + Default resolution for multiple images generated. + """ + + unet_name = "prior" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "image_encoder->text_encoder->prior" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + prior: StableCascadeUNet, + scheduler: DDPMWuerstchenScheduler, + resolution_multiple: float = 42.67, + feature_extractor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + prior=prior, + scheduler=scheduler, + ) + self.register_to_config(resolution_multiple=resolution_multiple) + + def prepare_latents( + self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler + ): + latent_shape = ( + num_images_per_prompt * batch_size, + self.prior.config.in_channels, + ceil(height / self.config.resolution_multiple), + ceil(width / self.config.resolution_multiple), + ) + + if latents is None: + latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != latent_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_pooled: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None, + ): + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True + ) + prompt_embeds = text_encoder_output.hidden_states[-1] + if prompt_embeds_pooled is None: + prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_input.attention_mask.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] + negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + seq_len = negative_prompt_embeds_pooled.shape[1] + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( + dtype=self.text_encoder.dtype, device=device + ) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled + + def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt): + image_embeds = [] + for image in images: + image = self.feature_extractor(image, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=dtype) + image_embed = self.image_encoder(image).image_embeds.unsqueeze(1) + image_embeds.append(image_embed) + image_embeds = torch.cat(image_embeds, dim=1) + + image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) + negative_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, negative_image_embeds + + def check_inputs( + self, + prompt, + images=None, + image_embeds=None, + negative_prompt=None, + prompt_embeds=None, + prompt_embeds_pooled=None, + negative_prompt_embeds=None, + negative_prompt_embeds_pooled=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None: + if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape: + raise ValueError( + "`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed" + f"directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !=" + f"`negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}." + ) + + if image_embeds is not None and images is not None: + raise ValueError( + f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + + if images: + for i, image in enumerate(images): + if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise TypeError( + f"'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image, but got" + f"{type(image)} for image number {i}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + def get_t_condioning(self, t, alphas_cumprod): + s = torch.tensor([0.003]) + clamp_range = [0, 1] + min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 + var = alphas_cumprod[t] + var = var.clamp(*clamp_range) + s, min_var = s.to(var.device), min_var.to(var.device) + ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return ratio + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 20, + timesteps: List[float] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_pooled: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pt", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 60): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input + argument. + image_embeds (`torch.FloatTensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting. + If not provided, image embeddings will be generated from `image` input argument if existing. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated image embeddings. + """ + + # 0. Define commonly used variables + device = self._execution_device + dtype = next(self.prior.parameters()).dtype + self._guidance_scale = guidance_scale + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + images=images, + image_embeds=image_embeds, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Encode caption + images + ( + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) = self.encode_prompt( + prompt=prompt, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) + + if images is not None: + image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image( + images=images, + device=device, + dtype=dtype, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + elif image_embeds is not None: + image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) + uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled) + else: + image_embeds_pooled = torch.zeros( + batch_size * num_images_per_prompt, + 1, + self.prior.config.clip_image_in_channels, + device=device, + dtype=dtype, + ) + uncond_image_embeds_pooled = torch.zeros( + batch_size * num_images_per_prompt, + 1, + self.prior.config.clip_image_in_channels, + device=device, + dtype=dtype, + ) + + if self.do_classifier_free_guidance: + image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0) + else: + image_embeds = image_embeds_pooled + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + text_encoder_pooled = ( + torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) + if negative_prompt_embeds is not None + else prompt_embeds_pooled + ) + + # 4. Prepare and set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents( + batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler + ) + + if isinstance(self.scheduler, DDPMWuerstchenScheduler): + timesteps = timesteps[:-1] + else: + if self.scheduler.config.clip_sample: + self.scheduler.config.clip_sample = False # disample sample clipping + logger.warning(" set `clip_sample` to be False") + # 6. Run denoising loop + if hasattr(self.scheduler, "betas"): + alphas = 1.0 - self.scheduler.betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + else: + alphas_cumprod = [] + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + if len(alphas_cumprod) > 0: + timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod) + timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) + else: + timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) + else: + timestep_ratio = t.expand(latents.size(0)).to(dtype) + # 7. Denoise image embeddings + predicted_image_embedding = self.prior( + sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, + clip_text_pooled=text_encoder_pooled, + clip_text=text_encoder_hidden_states, + clip_img=image_embeds, + return_dict=False, + )[0] + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale + ) + + # 9. Renoise latents to next timestep + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + timestep_ratio = t + latents = self.scheduler.step( + model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "np": + latents = latents.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + prompt_embeds = prompt_embeds.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work + negative_prompt_embeds = ( + negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None + ) # float() as bfloat16-> numpy doesnt work + + if not return_dict: + return (latents, prompt_embeds, negative_prompt_embeds) + + return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py index 7ad4e2f54e7c..a4c6f9c6a8b9 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -1,18 +1,3 @@ -# Copyright (c) 2023 Dominic Rampas MIT License -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import torch import torch.nn as nn diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py index 95707fed0524..6c06cc0e7303 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py @@ -233,7 +233,7 @@ def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=Tr class ResBlockStageB(nn.Module): - def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): super().__init__() self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index 84ba458daf19..e4277d58a006 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -349,6 +349,11 @@ def __call__( text_encoder_hidden_states = ( torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds ) + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) # 3. Determine latent shape of latents latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) @@ -371,11 +376,6 @@ def __call__( self._num_timesteps = len(timesteps[:-1]) for i, t in enumerate(self.progress_bar(timesteps[:-1])): ratio = t.expand(latents.size(0)).to(dtype) - effnet = ( - torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) - if self.do_classifier_free_guidance - else image_embeddings - ) # 7. Denoise latents predicted_latents = self.decoder( torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, @@ -423,9 +423,9 @@ def __call__( latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) if output_type == "np": - images = images.permute(0, 2, 3, 1).cpu().numpy() + images = images.permute(0, 2, 3, 1).cpu().float().numpy() elif output_type == "pil": - images = images.permute(0, 2, 3, 1).cpu().numpy() + images = images.permute(0, 2, 3, 1).cpu().float().numpy() images = self.numpy_to_pil(images) else: images = latents diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 5905bed21866..4640f7696766 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -508,7 +508,7 @@ def __call__( self.maybe_free_model_hooks() if output_type == "np": - latents = latents.cpu().numpy() + latents = latents.cpu().float().numpy() if not return_dict: return (latents,) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 0e3f4b294daa..3d5439878648 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -752,6 +752,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableCascadeCombinedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableCascadeDecoderPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableCascadePriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionAdapterPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/stable_cascade/__init__.py b/tests/pipelines/stable_cascade/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py new file mode 100644 index 000000000000..e717c7733b0a --- /dev/null +++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline +from diffusers.models import StableCascadeUNet +from diffusers.pipelines.wuerstchen import PaellaVQModel +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableCascadeCombinedPipeline + params = ["prompt"] + batch_params = ["prompt", "negative_prompt"] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "prior_guidance_scale", + "decoder_guidance_scale", + "negative_prompt", + "num_inference_steps", + "return_dict", + "prior_num_inference_steps", + "output_type", + ] + test_xformers_attention = True + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def dummy_prior(self): + torch.manual_seed(0) + + model_kwargs = { + "conditioning_dim": 128, + "block_out_channels": (128, 128), + "num_attention_heads": (2, 2), + "down_num_layers_per_block": (1, 1), + "up_num_layers_per_block": (1, 1), + "clip_image_in_channels": 768, + "switch_level": (False,), + "clip_text_in_channels": self.text_embedder_hidden_size, + "clip_text_pooled_in_channels": self.text_embedder_hidden_size, + } + + model = StableCascadeUNet(**model_kwargs) + return model.eval() + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + projection_dim=self.text_embedder_hidden_size, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config).eval() + + @property + def dummy_vqgan(self): + torch.manual_seed(0) + + model_kwargs = { + "bottleneck_blocks": 1, + "num_vq_embeddings": 2, + } + model = PaellaVQModel(**model_kwargs) + return model.eval() + + @property + def dummy_decoder(self): + torch.manual_seed(0) + model_kwargs = { + "in_channels": 4, + "out_channels": 4, + "conditioning_dim": 128, + "block_out_channels": (16, 32, 64, 128), + "num_attention_heads": (-1, -1, 1, 2), + "down_num_layers_per_block": (1, 1, 1, 1), + "up_num_layers_per_block": (1, 1, 1, 1), + "down_blocks_repeat_mappers": (1, 1, 1, 1), + "up_blocks_repeat_mappers": (3, 3, 2, 2), + "block_types_per_layer": ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ), + "switch_level": None, + "clip_text_pooled_in_channels": 32, + "dropout": (0.1, 0.1, 0.1, 0.1), + } + + model = StableCascadeUNet(**model_kwargs) + return model.eval() + + def get_dummy_components(self): + prior = self.dummy_prior + + scheduler = DDPMWuerstchenScheduler() + tokenizer = self.dummy_tokenizer + text_encoder = self.dummy_text_encoder + decoder = self.dummy_decoder + vqgan = self.dummy_vqgan + prior_text_encoder = self.dummy_text_encoder + prior_tokenizer = self.dummy_tokenizer + + components = { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "decoder": decoder, + "scheduler": scheduler, + "vqgan": vqgan, + "prior_text_encoder": prior_text_encoder, + "prior_tokenizer": prior_tokenizer, + "prior_prior": prior, + "prior_scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "generator": generator, + "prior_guidance_scale": 4.0, + "decoder_guidance_scale": 4.0, + "num_inference_steps": 2, + "prior_num_inference_steps": 2, + "output_type": "np", + "height": 128, + "width": 128, + } + return inputs + + def test_stable_cascade(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[-3:, -3:, -1] + + assert image.shape == (1, 128, 128, 3) + + expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert ( + np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + + @require_torch_gpu + def test_offloads(self): + pipes = [] + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components).to(torch_device) + pipes.append(sd_pipe) + + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe.enable_sequential_cpu_offload() + pipes.append(sd_pipe) + + components = self.get_dummy_components() + sd_pipe = self.pipeline_class(**components) + sd_pipe.enable_model_cpu_offload() + pipes.append(sd_pipe) + + image_slices = [] + for pipe in pipes: + inputs = self.get_dummy_inputs(torch_device) + image = pipe(**inputs).images + + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=2e-2) + + @unittest.skip(reason="fp16 not supported") + def test_float16_inference(self): + super().test_float16_inference() + + @unittest.skip(reason="no callback test for combined pipeline") + def test_callback_inputs(self): + super().test_callback_inputs() + + # def test_callback_cfg(self): + # pass + # pass diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py new file mode 100644 index 000000000000..7656744b49c7 --- /dev/null +++ b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py @@ -0,0 +1,249 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import StableCascadeUNet +from diffusers.pipelines.wuerstchen import PaellaVQModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_image, + load_pt, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableCascadeDecoderPipeline + params = ["prompt"] + batch_params = ["image_embeddings", "prompt", "negative_prompt"] + required_optional_params = [ + "num_images_per_prompt", + "num_inference_steps", + "latents", + "negative_prompt", + "guidance_scale", + "output_type", + "return_dict", + ] + test_xformers_attention = False + callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"] + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + projection_dim=self.text_embedder_hidden_size, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config).eval() + + @property + def dummy_vqgan(self): + torch.manual_seed(0) + + model_kwargs = { + "bottleneck_blocks": 1, + "num_vq_embeddings": 2, + } + model = PaellaVQModel(**model_kwargs) + return model.eval() + + @property + def dummy_decoder(self): + torch.manual_seed(0) + model_kwargs = { + "in_channels": 4, + "out_channels": 4, + "conditioning_dim": 128, + "block_out_channels": [16, 32, 64, 128], + "num_attention_heads": [-1, -1, 1, 2], + "down_num_layers_per_block": [1, 1, 1, 1], + "up_num_layers_per_block": [1, 1, 1, 1], + "down_blocks_repeat_mappers": [1, 1, 1, 1], + "up_blocks_repeat_mappers": [3, 3, 2, 2], + "block_types_per_layer": [ + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"], + ], + "switch_level": None, + "clip_text_pooled_in_channels": 32, + "dropout": [0.1, 0.1, 0.1, 0.1], + } + model = StableCascadeUNet(**model_kwargs) + return model.eval() + + def get_dummy_components(self): + decoder = self.dummy_decoder + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + vqgan = self.dummy_vqgan + + scheduler = DDPMWuerstchenScheduler() + + components = { + "decoder": decoder, + "vqgan": vqgan, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "scheduler": scheduler, + "latent_dim_scale": 4.0, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "image_embeddings": torch.ones((1, 4, 4, 4), device=device), + "prompt": "horse", + "generator": generator, + "guidance_scale": 2.0, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_wuerstchen_decoder(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False) + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @skip_mps + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=1e-2) + + @skip_mps + def test_attention_slicing_forward_pass(self): + test_max_difference = torch_device == "cpu" + test_mean_pixel_difference = False + + self._test_attention_slicing_forward_pass( + test_max_difference=test_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) + + @unittest.skip(reason="fp16 not supported") + def test_float16_inference(self): + super().test_float16_inference() + + +@slow +@require_torch_gpu +class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_cascade_decoder(self): + pipe = StableCascadeDecoderPipeline.from_pretrained( + "diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background." + + generator = torch.Generator(device="cpu").manual_seed(0) + image_embedding = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt" + ) + + image = pipe( + prompt=prompt, image_embeddings=image_embedding, num_inference_steps=10, generator=generator + ).images[0] + + assert image.size == (1024, 1024) + + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/t2i.png" + ) + + image_processor = VaeImageProcessor() + + image_np = image_processor.pil_to_numpy(image) + expected_image_np = image_processor.pil_to_numpy(expected_image) + + self.assertTrue(np.allclose(image_np, expected_image_np, atol=53e-2)) diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py new file mode 100644 index 000000000000..c0ee8cc75963 --- /dev/null +++ b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline +from diffusers.loaders import AttnProcsLayers +from diffusers.models import StableCascadeUNet +from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0 +from diffusers.utils.import_utils import is_peft_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_pt, + require_peft_backend, + require_torch_gpu, + skip_mps, + slow, + torch_device, +) + + +if is_peft_available(): + from peft import LoraConfig + from peft.tuners.tuners_utils import BaseTunerLayer + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +def create_prior_lora_layers(unet: nn.Module): + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=unet.config.c, + ) + unet_lora_layers = AttnProcsLayers(lora_attn_procs) + return lora_attn_procs, unet_lora_layers + + +class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableCascadePriorPipeline + params = ["prompt"] + batch_params = ["prompt", "negative_prompt"] + required_optional_params = [ + "num_images_per_prompt", + "generator", + "num_inference_steps", + "latents", + "negative_prompt", + "guidance_scale", + "output_type", + "return_dict", + ] + test_xformers_attention = False + callback_cfg_params = ["text_encoder_hidden_states"] + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config).eval() + + @property + def dummy_prior(self): + torch.manual_seed(0) + + model_kwargs = { + "conditioning_dim": 128, + "block_out_channels": (128, 128), + "num_attention_heads": (2, 2), + "down_num_layers_per_block": (1, 1), + "up_num_layers_per_block": (1, 1), + "switch_level": (False,), + "clip_image_in_channels": 768, + "clip_text_in_channels": self.text_embedder_hidden_size, + "clip_text_pooled_in_channels": self.text_embedder_hidden_size, + "dropout": (0.1, 0.1), + } + + model = StableCascadeUNet(**model_kwargs) + return model.eval() + + def get_dummy_components(self): + prior = self.dummy_prior + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + + scheduler = DDPMWuerstchenScheduler() + + components = { + "prior": prior, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "scheduler": scheduler, + "feature_extractor": None, + "image_encoder": None, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "generator": generator, + "guidance_scale": 4.0, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_wuerstchen_prior(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.image_embeddings + + image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0] + + image_slice = image[0, 0, 0, -10:] + image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:] + assert image.shape == (1, 16, 24, 24) + + expected_slice = np.array( + [ + 96.139565, + -20.213179, + -116.40341, + -191.57129, + 39.350136, + 74.80767, + 39.782352, + -184.67352, + -46.426907, + 168.41783, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2 + + @skip_mps + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-1) + + @skip_mps + def test_attention_slicing_forward_pass(self): + test_max_difference = torch_device == "cpu" + test_mean_pixel_difference = False + + self._test_attention_slicing_forward_pass( + test_max_difference=test_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) + + @unittest.skip(reason="fp16 not supported") + def test_float16_inference(self): + super().test_float16_inference() + + def check_if_lora_correctly_set(self, model) -> bool: + """ + Checks if the LoRA layers are correctly set with peft + """ + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return True + return False + + def get_lora_components(self): + prior = self.dummy_prior + + prior_lora_config = LoraConfig( + r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False + ) + + prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior) + + lora_components = { + "prior_lora_layers": prior_lora_layers, + "prior_lora_attn_procs": prior_lora_attn_procs, + } + + return prior, prior_lora_config, lora_components + + @require_peft_backend + @unittest.skip(reason="no lora support for now") + def test_inference_with_prior_lora(self): + _, prior_lora_config, _ = self.get_lora_components() + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output_no_lora = pipe(**self.get_dummy_inputs(device)) + image_embed = output_no_lora.image_embeddings + self.assertTrue(image_embed.shape == (1, 16, 24, 24)) + + pipe.prior.add_adapter(prior_lora_config) + self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior") + + output_lora = pipe(**self.get_dummy_inputs(device)) + lora_image_embed = output_lora.image_embeddings + + self.assertTrue(image_embed.shape == lora_image_embed.shape) + + +@slow +@require_torch_gpu +class StableCascadePriorPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_cascade_prior(self): + pipe = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background." + + generator = torch.Generator(device="cpu").manual_seed(0) + + output = pipe(prompt, num_inference_steps=10, generator=generator) + image_embedding = output.image_embeddings + + expected_image_embedding = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt" + ) + + assert image_embedding.shape == (1, 16, 24, 24) + + self.assertTrue( + np.allclose( + image_embedding.cpu().float().numpy(), expected_image_embedding.cpu().float().numpy(), atol=5e-2 + ) + ) diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 7c85af63bffe..0caed159100a 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -45,7 +45,6 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase "return_dict", "prior_num_inference_steps", "output_type", - "return_dict", ] test_xformers_attention = True