From f6abe4b02a34104b5971534828d78ea4c8639c57 Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Sat, 14 Sep 2024 11:12:40 +0800 Subject: [PATCH] FLUX.1 text-to-image pipeline for Gaudi * Enabled and tested FLUX pipeline on Gaudi for FLUX.1 class of models * Enabled HPU graphs mode * Enabled batching in inference * Added support for quantization (fp8 and hybrid) * Incorporated Gaudi profiler and HPU Synchronization for performance analysis * Boosted performance with Fused SDPA * Added Fused RoPE * Documented FLUX.1 samples * Upgraded and pinned diffusers in Optimum-Habana to official release 0.31.0 * Resolved issues in other pipelines due to diffusers upgrade * Added CI tests (2 unit tests, 1 slow test for perf and quality) Signed-off-by: Daniel Socek Co-authored-by: Baochen Yang Co-authored-by: Huijuan Zhou Co-authored-by: Sergey Plotnikov Co-authored-by: Deepak Narayana --- examples/stable-diffusion/README.md | 123 ++- .../quantization/flux/measure_config.json | 5 + .../quantization/flux/quantize_config.json | 6 + examples/stable-diffusion/requirements.txt | 3 +- .../text_to_image_generation.py | 60 +- .../unconditional_image_generation.py | 0 optimum/habana/diffusers/__init__.py | 8 +- .../habana/diffusers/models/controlnet_sdv.py | 123 +-- .../diffusers/pipelines/auto_pipeline.py | 4 + .../diffusers/pipelines/flux/pipeline_flux.py | 787 ++++++++++++++++++ .../diffusers/pipelines/pipeline_utils.py | 1 + .../habana/diffusers/schedulers/__init__.py | 1 + .../schedulers/scheduling_euler_discrete.py | 9 + .../scheduling_flow_mactch_euler_discrete.py | 25 + setup.py | 2 +- tests/test_diffusers.py | 183 ++++ 16 files changed, 1204 insertions(+), 136 deletions(-) create mode 100644 examples/stable-diffusion/quantization/flux/measure_config.json create mode 100644 examples/stable-diffusion/quantization/flux/quantize_config.json mode change 100644 => 100755 examples/stable-diffusion/unconditional_image_generation.py create mode 100644 optimum/habana/diffusers/pipelines/flux/pipeline_flux.py create mode 100644 optimum/habana/diffusers/schedulers/scheduling_flow_mactch_euler_discrete.py diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index b922b92498..7c99219e01 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -28,12 +28,12 @@ First, you should install the requirements: pip install -r requirements.txt ``` - ## Text-to-image Generation ### Single Prompt Here is how to generate images with one prompt: + ```bash python text_to_image_generation.py \ --model_name_or_path CompVis/stable-diffusion-v1-4 \ @@ -51,10 +51,10 @@ python text_to_image_generation.py \ > The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. > You can enable this mode with `--use_hpu_graphs`. 
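+
+The same generation can also be scripted directly in Python with the `GaudiStableDiffusionPipeline` class. The snippet below is a minimal sketch: the `use_habana`, `use_hpu_graphs`, and `gaudi_config` keyword arguments mirror the command-line flags above, while the prompt and batch settings are illustrative values, not fixed defaults:
+
+```python
+from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
+
+model_name = "CompVis/stable-diffusion-v1-4"
+
+# Load the scheduler from the model repository, then build the Gaudi pipeline
+scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+pipeline = GaudiStableDiffusionPipeline.from_pretrained(
+    model_name,
+    scheduler=scheduler,
+    use_habana=True,
+    use_hpu_graphs=True,
+    gaudi_config="Habana/stable-diffusion",
+)
+
+# Generate a batch of images for a single prompt (illustrative settings)
+outputs = pipeline(
+    prompt="High quality photo of an astronaut riding a horse in space",
+    num_images_per_prompt=4,
+    batch_size=4,
+)
+outputs.images[0].save("image.png")
+```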
- ### Multiple Prompts Here is how to generate images with several prompts: + ```bash python text_to_image_generation.py \ --model_name_or_path CompVis/stable-diffusion-v1-4 \ @@ -69,7 +69,9 @@ python text_to_image_generation.py \ ``` ### Distributed inference with multiple HPUs + Here is how to generate images with two prompts on two HPUs: + ```bash python ../gaudi_spawn.py \ --world_size 2 text_to_image_generation.py \ @@ -109,10 +111,10 @@ python text_to_image_generation.py \ ``` > There are two different checkpoints for Stable Diffusion 2: +> > - use [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) for generating 768x768 images > - use [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) for generating 512x512 images - ### Latent Diffusion Model for 3D (LDM3D) [LDM3D](https://arxiv.org/abs/2305.10853) generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. @@ -135,7 +137,9 @@ python text_to_image_generation.py \ --ldm3d \ --bf16 ``` + Here is how to generate images and depth maps with two prompts on two HPUs: + ```bash python ../gaudi_spawn.py \ --world_size 2 text_to_image_generation.py \ @@ -154,6 +158,7 @@ python ../gaudi_spawn.py \ ``` > There are three different checkpoints for LDM3D: +> > - use [original checkpoint](https://huggingface.co/Intel/ldm3d) to generate outputs from the paper > - use [the latest checkpoint](https://huggingface.co/Intel/ldm3d-4c) for generating improved results > - use [the pano checkpoint](https://huggingface.co/Intel/ldm3d-pano) to generate panoramic view @@ -163,6 +168,7 @@ python ../gaudi_spawn.py \ Stable Diffusion XL was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/pdf/2307.01952.pdf) by the Stability AI team. Here is how to generate SDXL images with a single prompt: + ```bash python text_to_image_generation.py \ --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ @@ -182,6 +188,7 @@ python text_to_image_generation.py \ > You can enable this mode with `--use_hpu_graphs`. Here is how to generate SDXL images with several prompts: + ```bash python text_to_image_generation.py \ --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ @@ -199,6 +206,7 @@ python text_to_image_generation.py \ SDXL combines a second text encoder (OpenCLIP ViT-bigG/14) with the original text encoder to significantly increase the number of parameters. Here is how to generate images with several prompts for both `prompt` and `prompt_2` (2nd text encoder), as well as their negative prompts: + ```bash python text_to_image_generation.py \ --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ @@ -217,6 +225,7 @@ python text_to_image_generation.py \ ``` Here is how to generate SDXL images with two prompts on two HPUs: + ```bash python ../gaudi_spawn.py \ --world_size 2 text_to_image_generation.py \ @@ -235,14 +244,17 @@ python ../gaudi_spawn.py \ --bf16 \ --distributed ``` + > HPU graphs are recommended when generating images by batches to get the fastest possible generations. > The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. > You can enable this mode with `--use_hpu_graphs`. ### SDXL-Turbo + SDXL-Turbo is a distilled version of SDXL 1.0, trained for real-time synthesis. 
Here is how to generate images with multiple prompts: + ```bash python text_to_image_generation.py \ --model_name_or_path stabilityai/sdxl-turbo \ @@ -275,11 +287,13 @@ Before running SD3 pipeline, you need to: 1. Agree to the Terms and Conditions for using SD3 model at [HuggingFace model page](https://huggingface.co/stabilityai/stable-diffusion-3-medium) 2. Authenticate with HuggingFace using your HF Token. For authentication, run: + ```bash huggingface-cli login ``` Here is how to generate SD3 images with a single prompt: + ```bash PT_HPU_MAX_COMPOUND_OP_SIZE=1 \ python text_to_image_generation.py \ @@ -299,12 +313,100 @@ python text_to_image_generation.py \ > For improved performance of the SD3 pipeline on Gaudi, it is recommended to configure the environment > by setting PT_HPU_MAX_COMPOUND_OP_SIZE to 1. +### FLUX.1 + +FLUX.1 was introduced by Black Forest Labs [here](https://blackforestlabs.ai/announcing-black-forest-labs/). + +Here is how to run FLUX.1-schnell model (fast version of FLUX.1): + +```bash +python text_to_image_generation.py \ + --model_name_or_path black-forest-labs/FLUX.1-schnell \ + --prompts "A cat holding a sign that says hello world" \ + --num_images_per_prompt 10 \ + --batch_size 1 \ + --num_inference_steps 4 \ + --image_save_dir /tmp/flux_1_images \ + --scheduler flow_match_euler_discrete\ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +Before running FLUX.1-dev model, you need to: + +1. Agree to the Terms and Conditions for using FLUX.1-dev model at [HuggingFace model page](https://huggingface.co/black-forest-labs/FLUX.1-dev) +2. Authenticate with HuggingFace using your HF Token. For authentication, run: + +```bash +huggingface-cli login +``` + +Here is how to run FLUX.1-dev model: + +```bash +python text_to_image_generation.py \ + --model_name_or_path black-forest-labs/FLUX.1-dev \ + --prompts "A cat holding a sign that says hello world" \ + --num_images_per_prompt 10 \ + --batch_size 1 \ + --num_inference_steps 30 \ + --image_save_dir /tmp/flux_1_images \ + --scheduler flow_match_euler_discrete\ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +This model can also be quantized with some ops running in FP8 precision. 
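+
+The FLUX.1 pipeline can also be invoked directly from Python through the new `GaudiFluxPipeline` class. The snippet below is a minimal sketch taken from the pipeline's example docstring, using the FLUX.1-schnell settings shown above:
+
+```python
+import torch
+
+from optimum.habana.diffusers import GaudiFluxPipeline
+
+# Build the Gaudi FLUX.1 pipeline with HPU graphs enabled
+pipe = GaudiFluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    torch_dtype=torch.bfloat16,
+    use_habana=True,
+    use_hpu_graphs=True,
+    gaudi_config="Habana/stable-diffusion",
+)
+
+# FLUX.1-schnell needs only a few denoising steps and no classifier-free guidance
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
+image.save("flux.png")
+```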
+ +Before quantization, run stats collection using measure mode: + +```bash +QUANT_CONFIG=quantization/flux/measure_config.json \ +python text_to_image_generation.py \ + --model_name_or_path black-forest-labs/FLUX.1-dev \ + --prompts "A cat holding a sign that says hello world" \ + --num_images_per_prompt 10 \ + --batch_size 1 \ + --num_inference_steps 30 \ + --image_save_dir /tmp/flux_1_images \ + --scheduler flow_match_euler_discrete\ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 \ + --quant_mode measure +``` + +After stats collection, here is how to run FLUX.1-dev in quantization mode: + +```bash +QUANT_CONFIG=quantization/flux/quantize_config.json \ +python text_to_image_generation.py \ + --model_name_or_path black-forest-labs/FLUX.1-dev \ + --prompts "A cat holding a sign that says hello world" \ + --num_images_per_prompt 10 \ + --batch_size 1 \ + --num_inference_steps 30 \ + --image_save_dir /tmp/flux_1_images \ + --scheduler flow_match_euler_discrete\ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 \ + --quant_mode quantize +``` + ## ControlNet -ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models ](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. +ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. It is a type of model for controlling StableDiffusion by conditioning the model with an additional input image. Here is how to generate images conditioned by canny edge model: + ```bash python text_to_image_generation.py \ --model_name_or_path CompVis/stable-diffusion-v1-4 \ @@ -321,6 +423,7 @@ python text_to_image_generation.py \ ``` Here is how to generate images conditioned by canny edge model and with multiple prompts: + ```bash python text_to_image_generation.py \ --model_name_or_path CompVis/stable-diffusion-v1-4 \ @@ -337,6 +440,7 @@ python text_to_image_generation.py \ ``` Here is how to generate images conditioned by canny edge model and with two prompts on two HPUs: + ```bash python ../gaudi_spawn.py \ --world_size 2 text_to_image_generation.py \ @@ -355,6 +459,7 @@ python ../gaudi_spawn.py \ ``` Here is how to generate images conditioned by open pose model: + ```bash python text_to_image_generation.py \ --model_name_or_path CompVis/stable-diffusion-v1-4 \ @@ -372,6 +477,7 @@ python text_to_image_generation.py \ ``` Here is how to generate images with conditioned by canny edge model using Stable Diffusion 2 + ```bash python text_to_image_generation.py \ --model_name_or_path stabilityai/stable-diffusion-2-1 \ @@ -395,6 +501,7 @@ Inpainting replaces or edits specific areas of an image. For more details, please refer to [Hugging Face Diffusers doc](https://huggingface.co/docs/diffusers/en/using-diffusers/inpaint). ### Stable Diffusion Inpainting + ```bash python text_to_image_generation.py \ --model_name_or_path stabilityai/stable-diffusion-2-inpainting \ @@ -412,6 +519,7 @@ python text_to_image_generation.py \ ``` ### Stable Diffusion XL Inpainting + ```bash python text_to_image_generation.py \ --model_name_or_path diffusers/stable-diffusion-xl-1.0-inpainting-0.1\ @@ -457,10 +565,10 @@ python image_to_image_generation.py \ > The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. > You can enable this mode with `--use_hpu_graphs`. 
- ### Multiple Prompts Here is how to generate images with several prompts and one image. + ```bash python image_to_image_generation.py \ --model_name_or_path "timbrooks/instruct-pix2pix" \ @@ -482,10 +590,10 @@ python image_to_image_generation.py \ > The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. > You can enable this mode with `--use_hpu_graphs`. - ### Stable Diffusion XL Refiner Here is how to generate SDXL images with a single prompt and one image: + ```bash python image_to_image_generation.py \ --model_name_or_path "stabilityai/stable-diffusion-xl-refiner-1.0" \ @@ -505,6 +613,7 @@ python image_to_image_generation.py \ ### Stable Diffusion Image Variations Here is how to generate images with one image, it does not accept prompt input + ```bash python image_to_image_generation.py \ --model_name_or_path "lambdalabs/sd-image-variations-diffusers" \ @@ -625,6 +734,7 @@ Script `image_to_video_generation.py` showcases how to perform image-to-video ge ### Single Image Prompt Here is how to generate video with one image prompt: + ```bash PT_HPU_MAX_COMPOUND_OP_SIZE=1 \ python image_to_video_generation.py \ @@ -645,6 +755,7 @@ python image_to_video_generation.py \ ### Multiple Image Prompts Here is how to generate videos with several image prompts: + ```bash PT_HPU_MAX_COMPOUND_OP_SIZE=1 \ python image_to_video_generation.py \ diff --git a/examples/stable-diffusion/quantization/flux/measure_config.json b/examples/stable-diffusion/quantization/flux/measure_config.json new file mode 100644 index 0000000000..865078d99f --- /dev/null +++ b/examples/stable-diffusion/quantization/flux/measure_config.json @@ -0,0 +1,5 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "dump_stats_path": "quantization/flux/measure_all/fp8" +} diff --git a/examples/stable-diffusion/quantization/flux/quantize_config.json b/examples/stable-diffusion/quantization/flux/quantize_config.json new file mode 100644 index 0000000000..8fdb21fccf --- /dev/null +++ b/examples/stable-diffusion/quantization/flux/quantize_config.json @@ -0,0 +1,6 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "scale_method": "maxabs_hw_opt_weight", + "dump_stats_path": "quantization/flux/measure_all/fp8" +} diff --git a/examples/stable-diffusion/requirements.txt b/examples/stable-diffusion/requirements.txt index a63e739620..ed24d8c1b7 100644 --- a/examples/stable-diffusion/requirements.txt +++ b/examples/stable-diffusion/requirements.txt @@ -1,2 +1,3 @@ opencv-python -compel \ No newline at end of file +compel +sentencepiece diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index e83d455237..59d8989f92 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -27,6 +27,7 @@ GaudiDDIMScheduler, GaudiEulerAncestralDiscreteScheduler, GaudiEulerDiscreteScheduler, + GaudiFlowMatchEulerDiscreteScheduler, ) from optimum.habana.utils import set_seed @@ -66,7 +67,7 @@ def main(): parser.add_argument( "--scheduler", default="ddim", - choices=["default", "euler_discrete", "euler_ancestral_discrete", "ddim"], + choices=["default", "euler_discrete", "euler_ancestral_discrete", "ddim", "flow_match_euler_discrete"], type=str, help="Name of scheduler", ) @@ -287,6 +288,18 @@ def main(): help="Use rescale_betas_zero_snr for controlling image brightness", ) parser.add_argument("--optimize", action="store_true", help="Use optimized pipeline.") + 
parser.add_argument( + "--quant_mode", + default="disable", + type=str, + help="Quantization mode 'measure', 'quantize', 'quantize-mixed' or 'disable'", + ) + parser.add_argument( + "--prompts_file", + type=str, + default=None, + help="The file with prompts (for large number of images generation).", + ) args = parser.parse_args() if args.optimize and not args.use_habana: @@ -295,14 +308,21 @@ def main(): # Select stable diffuson pipeline based on input sdxl_models = ["stable-diffusion-xl", "sdxl"] sd3_models = ["stable-diffusion-3"] + flux_models = ["FLUX.1"] sdxl = True if any(model in args.model_name_or_path for model in sdxl_models) else False sd3 = True if any(model in args.model_name_or_path for model in sd3_models) else False + flux = True if any(model in args.model_name_or_path for model in flux_models) else False controlnet = True if args.control_image is not None else False inpainting = True if (args.base_image is not None) and (args.mask_image is not None) else False # Set the scheduler kwargs = {"timestep_spacing": args.timestep_spacing, "rescale_betas_zero_snr": args.use_zero_snr} - if args.scheduler == "euler_discrete": + + if flux or args.scheduler == "flow_match_euler_discrete": + scheduler = GaudiFlowMatchEulerDiscreteScheduler.from_pretrained( + args.model_name_or_path, subfolder="scheduler", **kwargs + ) + elif args.scheduler == "euler_discrete": scheduler = GaudiEulerDiscreteScheduler.from_pretrained( args.model_name_or_path, subfolder="scheduler", **kwargs ) @@ -362,16 +382,18 @@ def main(): negative_prompts = negative_prompt kwargs_call["negative_prompt"] = negative_prompts - if sdxl or sd3: + if sdxl or sd3 or flux: prompts_2 = args.prompts_2 - negative_prompts_2 = args.negative_prompts_2 if args.distributed and args.prompts_2 is not None: with distributed_state.split_between_processes(args.prompts_2) as prompt_2: prompts_2 = prompt_2 + kwargs_call["prompt_2"] = prompts_2 + + if sdxl or sd3: + negative_prompts_2 = args.negative_prompts_2 if args.distributed and args.negative_prompts_2 is not None: with distributed_state.split_between_processes(args.negative_prompts_2) as negative_prompt_2: negative_prompts_2 = negative_prompt_2 - kwargs_call["prompt_2"] = prompts_2 kwargs_call["negative_prompt_2"] = negative_prompts_2 if sd3: @@ -410,7 +432,11 @@ def main(): control_image = Image.fromarray(image) kwargs_call["image"] = control_image + kwargs_call["quant_mode"] = args.quant_mode + # Instantiate a Stable Diffusion pipeline class + import habana_frameworks.torch.core as htcore # noqa: F401 + if sdxl: # SDXL pipelines if controlnet: @@ -469,6 +495,22 @@ def main(): args.model_name_or_path, **kwargs, ) + elif flux: + # Flux pipelines + if controlnet: + # Import Flux+ControlNet pipeline + raise ValueError("Flux+ControlNet pipeline is not currenly supported") + elif inpainting: + # Import FLux Inpainting pipeline + raise ValueError("Flux Inpainting pipeline is not currenly supported") + else: + # Import Flux pipeline + from optimum.habana.diffusers import GaudiFluxPipeline + + pipeline = GaudiFluxPipeline.from_pretrained( + args.model_name_or_path, + **kwargs, + ) else: # SD pipelines (SD1.x, SD2.x) @@ -561,6 +603,14 @@ def main(): pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6) + # If prompts file is specified override prompts from the file + if args.prompts_file is not None: + lines = [] + with open(args.prompts_file, "r") as file: + lines = file.readlines() + lines = [line.strip() for line in lines] + args.prompts = lines + # Generate Images using a Stable Diffusion 
pipeline if args.distributed: with distributed_state.split_between_processes(args.prompts) as prompt: diff --git a/examples/stable-diffusion/unconditional_image_generation.py b/examples/stable-diffusion/unconditional_image_generation.py old mode 100644 new mode 100755 diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index de76c24f5e..19c9e5fb1b 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -4,6 +4,7 @@ GaudiStableVideoDiffusionControlNetPipeline, ) from .pipelines.ddpm.pipeline_ddpm import GaudiDDPMPipeline +from .pipelines.flux.pipeline_flux import GaudiFluxPipeline from .pipelines.pipeline_utils import GaudiDiffusionPipeline from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline from .pipelines.stable_diffusion.pipeline_stable_diffusion_depth2img import GaudiStableDiffusionDepth2ImgPipeline @@ -23,4 +24,9 @@ from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import GaudiStableDiffusionXLInpaintPipeline from .pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import GaudiStableVideoDiffusionPipeline from .pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import GaudiTextToVideoSDPipeline -from .schedulers import GaudiDDIMScheduler, GaudiEulerAncestralDiscreteScheduler, GaudiEulerDiscreteScheduler +from .schedulers import ( + GaudiDDIMScheduler, + GaudiEulerAncestralDiscreteScheduler, + GaudiEulerDiscreteScheduler, + GaudiFlowMatchEulerDiscreteScheduler, +) diff --git a/optimum/habana/diffusers/models/controlnet_sdv.py b/optimum/habana/diffusers/models/controlnet_sdv.py index 2cc62d7abb..70c9994bc9 100644 --- a/optimum/habana/diffusers/models/controlnet_sdv.py +++ b/optimum/habana/diffusers/models/controlnet_sdv.py @@ -13,18 +13,14 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models import UNetSpatioTemporalConditionModel from diffusers.models.attention_processor import ( - ADDED_KV_ATTENTION_PROCESSORS, - CROSS_ATTENTION_PROCESSORS, AttentionProcessor, - AttnAddedKVProcessor, - AttnProcessor, ) from diffusers.models.embeddings import ( TimestepEmbedding, @@ -550,123 +546,6 @@ def from_unet( return controlnet - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) - - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. 
- # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - def zero_module(module): for p in module.parameters(): diff --git a/optimum/habana/diffusers/pipelines/auto_pipeline.py b/optimum/habana/diffusers/pipelines/auto_pipeline.py index 77171c9502..f2cd06b6a0 100644 --- a/optimum/habana/diffusers/pipelines/auto_pipeline.py +++ b/optimum/habana/diffusers/pipelines/auto_pipeline.py @@ -29,8 +29,10 @@ from huggingface_hub.utils import validate_hf_hub_args from .controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline +from .flux.pipeline_flux import GaudiFluxPipeline from .stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline from .stable_diffusion.pipeline_stable_diffusion_inpaint import GaudiStableDiffusionInpaintPipeline +from .stable_diffusion_3.pipeline_stable_diffusion_3 import GaudiStableDiffusion3Pipeline from .stable_diffusion_xl.pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipeline from .stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import GaudiStableDiffusionXLInpaintPipeline @@ -42,6 +44,8 @@ ("stable-diffusion", GaudiStableDiffusionPipeline), ("stable-diffusion-xl", GaudiStableDiffusionXLPipeline), ("stable-diffusion-controlnet", GaudiStableDiffusionControlNetPipeline), + ("stable-diffusion-3", GaudiStableDiffusion3Pipeline), + ("flux", GaudiFluxPipeline), ] ) diff --git a/optimum/habana/diffusers/pipelines/flux/pipeline_flux.py b/optimum/habana/diffusers/pipelines/flux/pipeline_flux.py new file mode 100644 index 0000000000..b43585acfc --- /dev/null +++ b/optimum/habana/diffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,787 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
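+
+# This module ports the upstream FluxPipeline to Gaudi (HPU). The main adaptations are:
+#  - HPU graphs support: the transformer is wrapped with `wrap_in_hpu_graph`,
+#  - Gaudi attention processors using Fused SDPA (`FusedSDPA.apply`) and fused RoPE
+#    (`torch.ops.hpu.rotary_pos_embedding`),
+#  - batched inference with padding of the last (incomplete) batch,
+#  - optional FP8 measurement/quantization through Intel Neural Compressor,
+#    driven by the `QUANT_CONFIG` environment variable.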
+ +import math +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps +from diffusers.utils import BaseOutput, replace_example_docstring +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from optimum.utils import logging + +from ....transformers.gaudi_configuration import GaudiConfig +from ....utils import HabanaProfile, speed_metrics, warmup_inference_steps_time_adjustment +from ...schedulers import GaudiFlowMatchEulerDiscreteScheduler +from ..pipeline_utils import GaudiDiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class GaudiFluxPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + throughput: float + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from optimum.habana.diffusers import GaudiFluxPipeline + + >>> pipe = GaudiFluxPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", + ... torch_dtype=torch.bfloat16, + ... use_habana=True, + ... use_hpu_graphs=True, + ... gaudi_config="Habana/stable-diffusion", + ... ) + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/embeddings.py#L697 + """ + cos_, sin_ = freqs_cis # [S, D] + + cos = cos_[None, None] + sin = sin_[None, None] + cos, sin = cos.to(xq.device), sin.to(xq.device) + + xq_out = torch.ops.hpu.rotary_pos_embedding(xq, sin, cos, None, 0, 1) + xk_out = torch.ops.hpu.rotary_pos_embedding(xk, sin, cos, None, 0, 1) + + return xq_out, xk_out + + +class GaudiFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + # `sample` projections. 
+ query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + if image_rotary_emb is not None: + query, key = apply_rotary_emb(query, key, image_rotary_emb) + + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + from habana_frameworks.torch.hpex.kernels import FusedSDPA + + hidden_states = FusedSDPA.apply(query, key, value, None, 0.0, False, None, "fast", None) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class GaudiFusedFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + # `sample` projections. 
+ qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + if image_rotary_emb is not None: + query, key = apply_rotary_emb(query, key, image_rotary_emb) + + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + from habana_frameworks.torch.hpex.kernels import FusedSDPA + + hidden_states = FusedSDPA.apply(query, key, value, None, 0.0, False, None, "fast", None) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class GaudiFluxPipeline(GaudiDiffusionPipeline, FluxPipeline): + r""" + Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/pipelines/flux/pipeline_flux.py#L140 + + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
+ text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: GaudiFlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + use_habana: bool = False, + use_hpu_graphs: bool = False, + gaudi_config: Union[str, GaudiConfig] = None, + bf16_full_eval: bool = False, + ): + GaudiDiffusionPipeline.__init__( + self, + use_habana, + use_hpu_graphs, + gaudi_config, + bf16_full_eval, + ) + FluxPipeline.__init__( + self, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + ) + + for block in self.transformer.single_transformer_blocks: + block.attn.processor = GaudiFluxAttnProcessor2_0() + for block in self.transformer.transformer_blocks: + block.attn.processor = GaudiFluxAttnProcessor2_0() + + self.to(self._device) + if use_hpu_graphs: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + transformer = wrap_in_hpu_graph(transformer) + + @classmethod + def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, pooled_prompt_embeds, guidance): + # Use torch.split to generate num_batches batches of size batch_size + latents_batches = list(torch.split(latents, batch_size)) + prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size)) + if pooled_prompt_embeds is not None: + pooled_prompt_embeds_batches = list(torch.split(pooled_prompt_embeds, batch_size)) + if guidance is not None: + guidance_batches = list(torch.split(guidance, batch_size)) + + # If the last batch has less samples than batch_size, pad it with dummy samples + num_dummy_samples = 0 + if latents_batches[-1].shape[0] < batch_size: + num_dummy_samples = batch_size - latents_batches[-1].shape[0] + + # Pad latents_batches + sequence_to_stack = (latents_batches[-1],) + tuple( + torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + latents_batches[-1] = torch.vstack(sequence_to_stack) + + # Pad prompt_embeds_batches + sequence_to_stack = (prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + + # Pad pooled_prompt_embeds if necessary + if pooled_prompt_embeds is not None: + sequence_to_stack = (pooled_prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(pooled_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + 
pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + + # Pad guidance if necessary + if guidance is not None: + sequence_to_stack = (guidance_batches[-1],) + tuple( + torch.zeros_like(guidance_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + guidance_batches[-1] = torch.vstack(sequence_to_stack) + + # Stack batches in the same tensor + latents_batches = torch.stack(latents_batches) + prompt_embeds_batches = torch.stack(prompt_embeds_batches) + pooled_prompt_embeds_batches = torch.stack(pooled_prompt_embeds_batches) + guidance_batches = torch.stack(guidance_batches) if guidance is not None else None + + return ( + latents_batches, + prompt_embeds_batches, + pooled_prompt_embeds_batches, + guidance_batches, + num_dummy_samples, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + batch_size: int = 1, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + profiling_warmup_steps: Optional[int] = 0, + profiling_steps: Optional[int] = 0, + **kwargs, + ): + r""" + Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py#L531 + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + profiling_warmup_steps (`int`, *optional*): + Number of steps to ignore for profling. + profiling_steps (`int`, *optional*): + Number of steps to be captured when enabling profiling. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. 
+ """ + import habana_frameworks.torch as ht + import habana_frameworks.torch.core as htcore + + quant_mode = kwargs.get("quant_mode", None) + + if quant_mode == "quantize-mixed": + import copy + + transformer_bf16 = copy.deepcopy(self.transformer).to(self._execution_device) + + if quant_mode in ("measure", "quantize", "quantize-mixed"): + import os + + quant_config_path = os.getenv("QUANT_CONFIG") + if not quant_config_path: + raise ImportError( + "Error: QUANT_CONFIG path is not defined. Please define path to quantization configuration JSON file." + ) + elif not os.path.isfile(quant_config_path): + raise ImportError(f"Error: QUANT_CONFIG path '{quant_config_path}' is not valid") + + htcore.hpu_set_env() + + from neural_compressor.torch.quantization import FP8Config, convert, prepare + + config = FP8Config.from_json_file(quant_config_path) + if config.measure: + self.transformer = prepare(self.transformer, config) + elif config.quantize: + self.transformer = convert(self.transformer, config) + htcore.hpu_initialize(self.transformer, mark_only_scales_as_const=True) + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + num_prompts = 1 + elif prompt is not None and isinstance(prompt, list): + num_prompts = len(prompt) + else: + num_prompts = prompt_embeds.shape[0] + num_batches = math.ceil((num_images_per_prompt * num_prompts) / batch_size) + + device = self._execution_device + + # 3. Run text encoder + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + num_prompts * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + logger.info( + f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt," + f" {batch_size} sample(s) per batch, {num_batches} total batch(es)." 
+ ) + if num_batches < 3: + logger.warning("The first two iterations are slower so it is recommended to feed more batches.") + + throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3) + use_warmup_inference_steps = ( + num_batches <= throughput_warmup_steps and num_inference_steps > throughput_warmup_steps + ) + + ht.hpu.synchronize() + t0 = time.time() + t1 = t0 + + hb_profiler = HabanaProfile( + warmup=profiling_warmup_steps, + active=profiling_steps, + record_shapes=False, + ) + hb_profiler.start() + + # 5.1. Split Input data to batches (HPU-specific step) + ( + latents_batches, + text_embeddings_batches, + pooled_prompt_embeddings_batches, + guidance_batches, + num_dummy_samples, + ) = self._split_inputs_into_batches(batch_size, latents, prompt_embeds, pooled_prompt_embeds, guidance) + + outputs = { + "images": [], + } + + # 6. Denoising loop + for j in range(num_batches): + # The throughput is calculated from the 4th iteration + # because compilation occurs in the first 2-3 iterations + if j == throughput_warmup_steps: + ht.hpu.synchronize() + t1 = time.time() + + latents_batch = latents_batches[0] + latents_batches = torch.roll(latents_batches, shifts=-1, dims=0) + text_embeddings_batch = text_embeddings_batches[0] + text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0) + pooled_prompt_embeddings_batch = pooled_prompt_embeddings_batches[0] + pooled_prompt_embeddings_batches = torch.roll(pooled_prompt_embeddings_batches, shifts=-1, dims=0) + guidance_batch = None if guidance_batches is None else guidance_batches[0] + guidance_batches = None if guidance_batches is None else torch.roll(guidance_batches, shifts=-1, dims=0) + + if hasattr(self.scheduler, "_init_step_index"): + # Reset scheduler step index for next batch + self.scheduler.timesteps = timesteps + self.scheduler._init_step_index(timesteps[0]) + + # Mixed quantization + quant_mixed_step = len(timesteps) + if quant_mode == "quantize-mixed": + # 10% of steps use higher precision in mixed quant mode + quant_mixed_step = quant_mixed_step - (quant_mixed_step // 10) + print(f"Use FP8 Transformer at steps 0 to {quant_mixed_step - 1}") + print(f"Use BF16 Transformer at steps {quant_mixed_step} to {len(timesteps) - 1}") + + for i in self.progress_bar(range(len(timesteps))): + if use_warmup_inference_steps and i == throughput_warmup_steps and j == num_batches - 1: + ht.hpu.synchronize() + t1 = time.time() + + if self.interrupt: + continue + + timestep = timesteps[0] + timesteps = torch.roll(timesteps, shifts=-1, dims=0) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = timestep.expand(latents_batch.shape[0]).to(latents_batch.dtype) + + if i >= quant_mixed_step: + # Mixed quantization + noise_pred = transformer_bf16( + hidden_states=latents_batch, + timestep=timestep / 1000, + guidance=guidance_batch, + pooled_projections=pooled_prompt_embeddings_batch, + encoder_hidden_states=text_embeddings_batch, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred = self.transformer( + hidden_states=latents_batch, + timestep=timestep / 1000, + guidance=guidance_batch, + pooled_projections=pooled_prompt_embeddings_batch, + encoder_hidden_states=text_embeddings_batch, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_batch = 
self.scheduler.step(noise_pred, timestep, latents_batch, return_dict=False)[0] + + hb_profiler.step() + # htcore.mark_step(sync=True) + if num_batches > throughput_warmup_steps: + ht.hpu.synchronize() + + if not output_type == "latent": + latents_batch = self._unpack_latents(latents_batch, height, width, self.vae_scale_factor) + latents_batch = (latents_batch / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents_batch, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents_batch + + outputs["images"].append(image) + # htcore.mark_step(sync=True) + + # 7. Stage after denoising + hb_profiler.stop() + + if quant_mode == "measure": + from neural_compressor.torch.quantization import finalize_calibration + + finalize_calibration(self.transformer) + + ht.hpu.synchronize() + speed_metrics_prefix = "generation" + if use_warmup_inference_steps: + t1 = warmup_inference_steps_time_adjustment(t1, t1, num_inference_steps, throughput_warmup_steps) + speed_measures = speed_metrics( + split=speed_metrics_prefix, + start_time=t0, + num_samples=batch_size + if t1 == t0 or use_warmup_inference_steps + else (num_batches - throughput_warmup_steps) * batch_size, + num_steps=batch_size * num_inference_steps + if use_warmup_inference_steps + else (num_batches - throughput_warmup_steps) * batch_size * num_inference_steps, + start_time_after_warmup=t1, + ) + logger.info(f"Speed metrics: {speed_measures}") + + # 8 Output Images + if num_dummy_samples > 0: + # Remove dummy generations if needed + outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples] + + # Process generated images + for i, image in enumerate(outputs["images"][:]): + if i == 0: + outputs["images"].clear() + + if output_type == "pil" and isinstance(image, list): + outputs["images"] += image + elif output_type in ["np", "numpy"] and isinstance(image, np.ndarray): + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = np.concatenate((outputs["images"], image), axis=0) + else: + if len(outputs["images"]) == 0: + outputs["images"] = image + else: + outputs["images"] = torch.cat((outputs["images"], image), 0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return outputs["images"] + + return GaudiFluxPipelineOutput( + images=outputs["images"], + throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"], + ) diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py index 7f36b90ae4..6e659edff4 100644 --- a/optimum/habana/diffusers/pipelines/pipeline_utils.py +++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py @@ -55,6 +55,7 @@ "optimum.habana.diffusers.schedulers": { "GaudiDDIMScheduler": ["save_pretrained", "from_pretrained"], "GaudiEulerDiscreteScheduler": ["save_pretrained", "from_pretrained"], + "GaudiFlowMatchEulerDiscreteScheduler": ["save_pretrained", "from_pretrained"], "GaudiEulerAncestralDiscreteScheduler": ["save_pretrained", "from_pretrained"], }, } diff --git a/optimum/habana/diffusers/schedulers/__init__.py b/optimum/habana/diffusers/schedulers/__init__.py index 37eb80b1a6..48bf0bd8e9 100644 --- a/optimum/habana/diffusers/schedulers/__init__.py +++ b/optimum/habana/diffusers/schedulers/__init__.py @@ -1,3 +1,4 @@ from .scheduling_ddim import GaudiDDIMScheduler from .scheduling_euler_ancestral_discrete import GaudiEulerAncestralDiscreteScheduler from 
+from .scheduling_flow_mactch_euler_discrete import GaudiFlowMatchEulerDiscreteScheduler
diff --git a/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py b/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py
index 977b196e29..e3344455b9 100644
--- a/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/optimum/habana/diffusers/schedulers/scheduling_euler_discrete.py
@@ -54,6 +54,11 @@ class GaudiEulerDiscreteScheduler(EulerDiscreteScheduler):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -78,6 +83,8 @@ def __init__(
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -94,6 +101,8 @@
             prediction_type,
             interpolation_type,
             use_karras_sigmas,
+            use_exponential_sigmas,
+            use_beta_sigmas,
             sigma_min,
             sigma_max,
             timestep_spacing,
diff --git a/optimum/habana/diffusers/schedulers/scheduling_flow_mactch_euler_discrete.py b/optimum/habana/diffusers/schedulers/scheduling_flow_mactch_euler_discrete.py
new file mode 100644
index 0000000000..8a2ed48972
--- /dev/null
+++ b/optimum/habana/diffusers/schedulers/scheduling_flow_mactch_euler_discrete.py
@@ -0,0 +1,25 @@
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+
+
+class GaudiFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+    # TODO: overwrite original func with the following one to fix dyn error in Gaudi lazy mode
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        # indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        # pos = 1 if len(indices) > 1 else 0
+
+        # return indices[pos].item()
+
+        masked = schedule_timesteps == timestep
+        tmp = masked.cumsum(dim=0)
+        pos = (tmp == 0).sum().item()
+        if masked.sum() > 1:
+            pos += (tmp == 1).sum().item()
+        return pos
diff --git a/setup.py b/setup.py
index 4249e21924..0bb36466ee 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@
     "optimum",
     "torch",
     "accelerate >= 0.33.0, < 0.34.0",
-    "diffusers == 0.29.2",
+    "diffusers >= 0.31.0, < 0.32.0",
     "huggingface_hub >= 0.24.7",
     "sentence-transformers == 3.2.1",
 ]
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index b2526c7fa6..288264a156 100755
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -48,6 +48,7 @@
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     FlowMatchEulerDiscreteScheduler,
+    FluxTransformer2DModel,
     LCMScheduler,
     PNDMScheduler,
     SD3Transformer2DModel,
@@ -87,6 +88,7 @@
     DPTConfig,
     DPTFeatureExtractor,
     DPTForDepthEstimation,
+    T5EncoderModel,
 )
 from transformers.testing_utils import parse_flag_from_env, slow
@@ -97,6 +99,7 @@
     GaudiDiffusionPipeline,
     GaudiEulerAncestralDiscreteScheduler,
     GaudiEulerDiscreteScheduler,
+    GaudiFluxPipeline,
     GaudiStableDiffusion3Pipeline,
     GaudiStableDiffusionControlNetPipeline,
     GaudiStableDiffusionDepth2ImgPipeline,
@@ -141,6 +144,7 @@
     DEPTH2IMG_GENERATION_LATENCY_BASELINE_BF16 = 36.06376791000366
     TEXTUAL_INVERSION_SDXL_THROUGHPUT = 2.6694
     TEXTUAL_INVERSION_SDXL_RUNTIME = 74.92
+    FLUX_THROUGHPUT = 0.03
 else:
     THROUGHPUT_BASELINE_BF16 = 0.309
     THROUGHPUT_BASELINE_AUTOCAST = 0.114
@@ -6351,3 +6355,182 @@ def test_no_throughput_regression_bf16(self):
         )
         outputs = pipe(batch_size=batch_size)
         self.assertGreaterEqual(outputs.throughput, 0.95 * THROUGHPUT_UNCONDITIONAL_IMAGE_BASELINE_BF16)
+
+
+class GaudiFluxPipelineTester(TestCase):
+    """
+    Tests the Flux pipeline for Gaudi.
+ """ + + pipeline_class = GaudiFluxPipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "prompt_embeds", + "pooled_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt"]) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = FluxTransformer2DModel( + patch_size=1, + in_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=32, + pooled_projection_dim=32, + axes_dims_rope=[4, 4, 8], + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = CLIPTextModel(clip_text_encoder_config) + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + # HF issue with T5EncoderModel from tiny-random-t5 + # text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + # text_encoder_3 = None + + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 8, + "width": 8, + "max_sequence_length": 48, + "output_type": "np", + } + return inputs + + def test_flux_different_prompts(self): + pipe = self.pipeline_class( + use_habana=True, + use_hpu_graphs=True, + gaudi_config="Habana/stable-diffusion", + **self.get_dummy_components(), + ).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt_2"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + # For some reasons, they don't show large differences + assert max_diff > 1e-6 + + def test_flux_prompt_embeds(self): + pipe = self.pipeline_class( + use_habana=True, + use_hpu_graphs=True, + gaudi_config="Habana/stable-diffusion", + **self.get_dummy_components(), + ).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + (prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt( + prompt, + prompt_2=None, + device=torch_device, + max_sequence_length=inputs["max_sequence_length"], + ) 
+        output_with_embeds = pipe(
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            **inputs,
+        ).images[0]
+
+        max_diff = np.abs(output_with_prompt - output_with_embeds).max()
+        assert max_diff < 1e-4
+
+    @slow
+    @pytest.mark.skipif(not IS_GAUDI2, reason="does not fit into Gaudi1 memory")
+    def test_flux_inference(self):
+        repo_id = "black-forest-labs/FLUX.1-schnell"
+
+        pipe = self.pipeline_class.from_pretrained(
+            repo_id,
+            torch_dtype=torch.bfloat16,
+            use_habana=True,
+            use_hpu_graphs=True,
+            gaudi_config="Habana/stable-diffusion",
+        )
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        outputs = pipe(
+            prompt="A photo of a cat",
+            num_inference_steps=5,
+            guidance_scale=5.0,
+            output_type="np",
+            generator=generator,
+        )
+
+        # Check expected performance of FLUX.1 schnell model
+        self.assertGreaterEqual(outputs.throughput, 0.95 * FLUX_THROUGHPUT)
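Note (not part of the patch): the slow test above exercises the new `GaudiFluxPipeline` end to end. Below is a minimal Python sketch of the same flow outside the test harness; the prompt, step count, and the way the output is saved are illustrative placeholders rather than values prescribed by this change.

```python
import torch
from optimum.habana.diffusers import GaudiFluxPipeline

# Load FLUX.1-schnell on HPU with HPU graphs, mirroring test_flux_inference above.
pipe = GaudiFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config="Habana/stable-diffusion",
)

outputs = pipe(
    prompt="A photo of a cat",  # placeholder prompt
    num_inference_steps=5,
    generator=torch.Generator(device="cpu").manual_seed(0),
)

# GaudiFluxPipelineOutput exposes both the images and the measured throughput.
outputs.images[0].save("flux_image.png")
print(f"Throughput: {outputs.throughput:.3f} samples/s")
```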
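Note (not part of the patch): the `GaudiFlowMatchEulerDiscreteScheduler.index_for_timestep` override earlier in this patch replaces `nonzero()`, whose output shape depends on the data and therefore triggers dynamic shapes in Gaudi lazy mode, with a fixed-shape `cumsum`-based lookup. The following CPU-only sketch, with made-up timestep values, illustrates that the two formulations return the same index.

```python
import torch

timesteps = torch.tensor([1000.0, 750.0, 500.0, 250.0])  # made-up schedule
timestep = torch.tensor(500.0)

# Original formulation: nonzero() returns a tensor whose size depends on the data.
indices = (timesteps == timestep).nonzero()
pos_dynamic = indices[1 if len(indices) > 1 else 0].item()

# Gaudi formulation: every intermediate tensor keeps a fixed shape.
masked = timesteps == timestep
tmp = masked.cumsum(dim=0)
pos_static = (tmp == 0).sum().item()
if masked.sum() > 1:
    pos_static += (tmp == 1).sum().item()

assert pos_dynamic == pos_static == 2
```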