diff --git a/docs/performance/hunyuanvideo.md b/docs/performance/hunyuanvideo.md index 0832bba..0e04d27 100644 --- a/docs/performance/hunyuanvideo.md +++ b/docs/performance/hunyuanvideo.md @@ -10,6 +10,7 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil |----------|--------|---------|---------|---------| | H100 | 1,904.08 | 925.04 | 514.08 | 337.58 | | H20 | 6,639.17 | 3,400.55 | 1,762.86 | 940.97 | +| L20 | 6,043.88 | 3,271.44 | 2,080.05 | | @@ -21,5 +22,6 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil |----------|--------|---------|---------|---------| | H100 | 1,735.01 | 934.09 | 645.45 | 367.02 | | H20 | 6,621.46 | 3,400.55 | 2,310.48 | 1,214.67 | +| L20 | 6,039.08 | 3,260.62 | 2,070.96 | | - \ No newline at end of file + diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py new file mode 100644 index 0000000..0360bf2 --- /dev/null +++ b/examples/hunyuan_video_usp_example.py @@ -0,0 +1,326 @@ +# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_hunyuan_video.py +import functools +from typing import Any, Dict, Union, Optional +import logging +import time + +import torch + +from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import scale_lora_layers, unscale_lora_layers, USE_PEFT_BACKEND +from diffusers.utils import export_to_video + +from xfuser import xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_data_parallel_world_size, + get_data_parallel_rank, + get_runtime_state, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + is_dp_last_group, + initialize_runtime_state, + get_pipeline_parallel_world_size, +) + +from xfuser.model_executor.layers.attention_processor import xFuserHunyuanVideoAttnProcessor2_0 + +assert xFuserHunyuanVideoAttnProcessor2_0 is not None + + +def parallelize_transformer(pipe: DiffusionPipeline): + transformer = pipe.transformer + + @functools.wraps(transformer.__class__.forward) + def new_forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logging.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.") + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + assert batch_size % get_classifier_free_guidance_world_size( + ) == 0, f"Cannot split dim 0 of hidden_states ({batch_size}) into {get_classifier_free_guidance_world_size()} parts." 
+ + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb = self.time_text_embed(timestep, guidance, pooled_projections) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, + timestep, + encoder_attention_mask) + + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1) + hidden_states = hidden_states.flatten(1, 3) + + hidden_states = torch.chunk(hidden_states, + get_classifier_free_guidance_world_size(), + dim=0)[get_classifier_free_guidance_rank()] + hidden_states = torch.chunk(hidden_states, + get_sequence_parallel_world_size(), + dim=-2)[get_sequence_parallel_rank()] + + encoder_attention_mask = encoder_attention_mask[0].to(torch.bool) + encoder_hidden_states_indices = torch.arange( + encoder_hidden_states.shape[1], + device=encoder_hidden_states.device) + encoder_hidden_states_indices = encoder_hidden_states_indices[ + encoder_attention_mask] + encoder_hidden_states = encoder_hidden_states[ + ..., encoder_hidden_states_indices, :] + if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size( + ) != 0: + get_runtime_state().split_text_embed_in_sp = False + else: + get_runtime_state().split_text_embed_in_sp = True + + encoder_hidden_states = torch.chunk( + encoder_hidden_states, + get_classifier_free_guidance_world_size(), + dim=0)[get_classifier_free_guidance_rank()] + if get_runtime_state().split_text_embed_in_sp: + encoder_hidden_states = torch.chunk( + encoder_hidden_states, + get_sequence_parallel_world_size(), + dim=-2)[get_sequence_parallel_rank()] + + freqs_cos, freqs_sin = image_rotary_emb + + def get_rotary_emb_chunk(freqs): + freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] + return freqs + + freqs_cos = get_rotary_emb_chunk(freqs_cos) + freqs_sin = get_rotary_emb_chunk(freqs_sin) + image_rotary_emb = (freqs_cos, freqs_sin) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} + + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + None, + image_rotary_emb, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + None, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, None, + image_rotary_emb) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, None, + image_rotary_emb) + + # 5. 
Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) + + hidden_states = hidden_states.reshape(batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, -1, p_t, p, p) + + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states, ) + + return Transformer2DModelOutput(sample=hidden_states) + + new_forward = new_forward.__get__(transformer) + transformer.forward = new_forward + + for block in transformer.transformer_blocks + transformer.single_transformer_blocks: + block.attn.processor = xFuserHunyuanVideoAttnProcessor2_0() + + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + + engine_config, input_config = engine_args.create_config() + local_rank = get_world_group().local_rank + + assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion." + assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for HunyuanVideo" + + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + subfolder="transformer", + torch_dtype=torch.bfloat16, + revision="refs/pr/18", + ) + pipe = HunyuanVideoPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + transformer=transformer, + torch_dtype=torch.float16, + revision="refs/pr/18", + ) + + initialize_runtime_state(pipe, engine_config) + get_runtime_state().set_video_input_parameters( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + batch_size=1, + num_inference_steps=input_config.num_inference_steps, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, + ) + + parallelize_transformer(pipe) + + if args.enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} sequential CPU offload enabled") + elif args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} model CPU offload enabled") + else: + device = torch.device(f"cuda:{local_rank}") + pipe = pipe.to(device) + + if args.enable_tiling: + pipe.vae.enable_tiling( + # Make it runnable on GPUs with 48GB memory + # tile_sample_min_height=128, + # tile_sample_stride_height=96, + # tile_sample_min_width=128, + # tile_sample_stride_width=96, + # tile_sample_min_num_frames=32, + # tile_sample_stride_num_frames=24, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + + parameter_peak_memory = torch.cuda.max_memory_allocated( + device=f"cuda:{local_rank}") + + if engine_config.runtime_config.use_torch_compile: + torch._inductor.config.reorder_for_compute_comm_overlap = True + pipe.transformer = torch.compile(pipe.transformer, + mode="max-autotune-no-cudagraphs") + + # one step to warmup the torch compiler + output = pipe( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + prompt=input_config.prompt, + num_inference_steps=1, + 
generator=torch.Generator(device="cuda").manual_seed(
+            input_config.seed),
+    ).frames[0]
+
+    torch.cuda.reset_peak_memory_stats()
+    start_time = time.time()
+
+    output = pipe(
+        height=input_config.height,
+        width=input_config.width,
+        num_frames=input_config.num_frames,
+        prompt=input_config.prompt,
+        num_inference_steps=input_config.num_inference_steps,
+        generator=torch.Generator(device="cuda").manual_seed(
+            input_config.seed),
+    ).frames[0]
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
+    parallel_info = (
+        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
+        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
+        f"tp{engine_args.tensor_parallel_degree}_"
+        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
+    )
+    if is_dp_last_group():
+        resolution = f"{input_config.width}x{input_config.height}"
+        output_filename = f"results/hunyuan_video_{parallel_info}_{resolution}.mp4"
+        export_to_video(output, output_filename, fps=15)
+        print(f"output saved to {output_filename}")
+
+    if get_world_group().rank == get_world_group().world_size - 1:
+        print(
+            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
+        )
+    get_runtime_state().destory_distributed_env()
+
+
+# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling --enable_model_cpu_offload
+# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 544 --width 960 --num_frames 129 --enable_tiling --enable_model_cpu_offload
+if __name__ == "__main__":
+    main()
diff --git a/examples/run_hunyuan_video_usp.sh b/examples/run_hunyuan_video_usp.sh
new file mode 100755
index 0000000..c1f8813
--- /dev/null
+++ b/examples/run_hunyuan_video_usp.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+set -x
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+
+# HunyuanVideo configuration
+SCRIPT="hunyuan_video_usp_example.py"
+MODEL_ID="/cfs/dit/HunyuanVideo"
+# MODEL_ID="tencent/HunyuanVideo"
+INFERENCE_STEP=50
+
+mkdir -p ./results
+
+# HunyuanVideo specific task args
+TASK_ARGS="--height 720 --width 1280 --num_frames 129"
+
+# HunyuanVideo parallel configuration
+N_GPUS=8
+PARALLEL_ARGS="--ulysses_degree 4 --ring_degree 2"
+# CFG_ARGS="--use_cfg_parallel"
+
+# Uncomment and modify these as needed
+# PIPEFUSION_ARGS="--num_pipeline_patch 8"
+# OUTPUT_ARGS="--output_type latent"
+# PARALLEL_VAE="--use_parallel_vae"
+ENABLE_TILING="--enable_tiling"
+ENABLE_MODEL_CPU_OFFLOAD="--enable_model_cpu_offload"
+# COMPILE_FLAG="--use_torch_compile"
+
+torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
+--model $MODEL_ID \
+$PARALLEL_ARGS \
+$TASK_ARGS \
+$PIPEFUSION_ARGS \
+$OUTPUT_ARGS \
+--num_inference_steps $INFERENCE_STEP \
+--warmup_steps 0 \
+--prompt "A cat walks on the grass, realistic" \
+$CFG_ARGS \
+$PARALLEL_VAE \
+$ENABLE_TILING \
+$ENABLE_MODEL_CPU_OFFLOAD \
+$COMPILE_FLAG
diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py
index cfb4c79..dd8ec07 100644
--- 
a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -4,7 +4,7 @@ import numpy as np import torch -from diffusers import DiffusionPipeline, CogVideoXPipeline +from diffusers import DiffusionPipeline import torch.distributed from xfuser.config.config import ( @@ -103,7 +103,12 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): pipeline=pipeline, parallel_config=config.parallel_config ) self.cogvideox = False - if isinstance(pipeline, CogVideoXPipeline): + self.hunyuan_video = False + if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")): + if pipeline.__class__.__name__.startswith("CogVideoX"): + self.cogvideox = True + else: + self.hunyuan_video = True self._set_cogvideox_parameters( vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, @@ -194,7 +199,6 @@ def _set_cogvideox_parameters( self.backbone_patch_size = backbone_patch_size self.backbone_inner_dim = backbone_inner_dim self.backbone_in_channel = backbone_in_channel - self.cogvideox = True def set_patched_mode(self, patch_mode: bool): self.patch_mode = patch_mode @@ -259,6 +263,9 @@ def _video_input_size_change( self.input_config.batch_size = batch_size or self.input_config.batch_size if self.cogvideox: self._calc_cogvideox_patches_metadata() + elif self.hunyuan_video: + # TODO: implement the hunyuan video patches metadata + pass else: self._calc_patches_metadata() self._reset_recv_buffer() diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 9869a7e..31c3fff 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -15,6 +15,11 @@ CogVideoXAttnProcessor2_0 ) +try: + from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoAttnProcessor2_0 +except ImportError: + HunyuanVideoAttnProcessor2_0 = None + from diffusers.models.embeddings import apply_rotary_emb from xfuser.core.distributed import ( @@ -1143,3 +1148,200 @@ def __call__( [text_seq_length, latent_seq_length], dim=1 ) return hidden_states, encoder_hidden_states + + +if HunyuanVideoAttnProcessor2_0 is not None: + @xFuserAttentionProcessorRegister.register(HunyuanVideoAttnProcessor2_0) + class xFuserHunyuanVideoAttnProcessor2_0(HunyuanVideoAttnProcessor2_0): + def __init__(self): + super().__init__() + use_long_ctx_attn_kvcache = True + self.use_long_ctx_attn_kvcache = ( + HAS_LONG_CTX_ATTN + and use_long_ctx_attn_kvcache + and get_sequence_parallel_world_size() > 1 + ) + if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: + from xfuser.core.long_ctx_attention import ( + xFuserLongContextAttention, + xFuserUlyssesAttention, + ) + + if HAS_FLASH_ATTN: + self.hybrid_seq_parallel_attn = xFuserLongContextAttention( + use_kv_cache=self.use_long_ctx_attn_kvcache + ) + else: + self.hybrid_seq_parallel_attn = xFuserUlyssesAttention( + use_fa=False, + use_kv_cache=self.use_long_ctx_attn_kvcache, + ) + else: + self.hybrid_seq_parallel_attn = None + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attn.add_q_proj is None and 
encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + if encoder_hidden_states is not None: + num_encoder_hidden_states_tokens = encoder_hidden_states.shape[1] + num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens + else: + num_encoder_hidden_states_tokens = ( + get_runtime_state().max_condition_sequence_length + ) + num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens + + #! 
---------------------------------------- ATTENTION ---------------------------------------- + if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp: + hidden_states = USP( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: + if get_runtime_state().split_text_embed_in_sp: + encoder_query = None + encoder_key = None + encoder_value = None + else: + query, encoder_query = query.split( + [num_query_tokens, num_encoder_hidden_states_tokens], dim=2 + ) + key, encoder_key = key.split( + [num_query_tokens, num_encoder_hidden_states_tokens], dim=2 + ) + value, encoder_value = value.split( + [num_query_tokens, num_encoder_hidden_states_tokens], dim=2 + ) + + encoder_query = encoder_query.transpose(1, 2) + encoder_key = encoder_key.transpose(1, 2) + encoder_value = encoder_value.transpose(1, 2) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + hidden_states = self.hybrid_seq_parallel_attn( + None, + query, + key, + value, + dropout_p=0.0, + causal=False, + joint_tensor_query=encoder_query, + joint_tensor_key=encoder_key, + joint_tensor_value=encoder_value, + joint_strategy="rear", + ) + + hidden_states = hidden_states.flatten(2, 3) + else: + if HAS_FLASH_ATTN: + from flash_attn import flash_attn_func + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + hidden_states = flash_attn_func( + query, key, value, dropout_p=0.0, causal=False + ) + hidden_states = hidden_states.flatten(2, 3) + + else: + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states +else: + xFuserHunyuanVideoAttnProcessor2_0 = None
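
Note on the patching pattern used by `parallelize_transformer` in the new example: the wrapped forward is bound to the pipeline's transformer instance via `__get__`, so only that object is patched and `HunyuanVideoTransformer3DModel` itself is left untouched. Below is a minimal, self-contained sketch of the same idiom; the toy module and function names are illustrative only and are not part of xDiT.

import functools

import torch


class ToyTransformer(torch.nn.Module):
    """Stand-in for the real transformer; illustrative only."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states * 2


def parallelize(module: torch.nn.Module) -> None:
    @functools.wraps(module.__class__.forward)
    def new_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # A real wrapper (as in parallelize_transformer above) would shard the
        # inputs here, e.g. torch.chunk along the sequence dim per SP rank,
        # run the transformer blocks, then all-gather the outputs.
        return self.__class__.forward(self, hidden_states)

    # Bind the function as an instance method, mirroring
    # `new_forward.__get__(transformer)` in the example; the class stays unpatched.
    module.forward = new_forward.__get__(module)


toy = ToyTransformer()
parallelize(toy)
print(toy(torch.ones(2)))  # tensor([2., 2.])

Because nn.Module.__call__ resolves `self.forward` on the instance, this swap takes effect without subclassing the transformer or touching other pipelines that use the same class.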