From df8f8534ec8dd1a18aef6fea31de4b897497c195 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Thu, 19 Dec 2024 17:38:08 +0800 Subject: [PATCH 1/7] add hunyuan_video_usp_example.py --- examples/hunyuan_video_usp_example.py | 302 ++++++++++++++++++ xfuser/core/distributed/runtime_state.py | 4 +- .../layers/attention_processor.py | 202 ++++++++++++ 3 files changed, 506 insertions(+), 2 deletions(-) create mode 100644 examples/hunyuan_video_usp_example.py diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py new file mode 100644 index 0000000..1b79632 --- /dev/null +++ b/examples/hunyuan_video_usp_example.py @@ -0,0 +1,302 @@ +import functools +from typing import Any, Dict, Union +import logging +import time + +import torch + +from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import export_to_video + +from xfuser import xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_data_parallel_world_size, + get_data_parallel_rank, + get_runtime_state, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + is_dp_last_group, + initialize_runtime_state, + get_pipeline_parallel_world_size, +) + +from xfuser.model_executor.layers.attention_processor import xFuserHunyuanVideoAttnProcessor2_0 + +assert xFuserHunyuanVideoAttnProcessor2_0 is not None + + +def parallelize_transformer(pipe: DiffusionPipeline): + transformer = pipe.transformer + + @functools.wraps(transformer.__class__.forward) + def new_forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + assert batch_size % get_classifier_free_guidance_world_size( + ) == 0, f"Cannot split dim 0 of hidden_states ({batch_size}) into {get_classifier_free_guidance_world_size()} parts." + + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb = self.time_text_embed(timestep, guidance, pooled_projections) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, + timestep, + encoder_attention_mask) + + encoder_attention_mask = encoder_attention_mask[0].to(torch.bool) + encoder_hidden_states_indices = torch.arange( + encoder_hidden_states.shape[1], + device=encoder_hidden_states.device) + encoder_hidden_states_indices = encoder_hidden_states_indices[ + encoder_attention_mask] + encoder_hidden_states = encoder_hidden_states[ + ..., encoder_hidden_states_indices, :] + if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size( + ) != 0: + get_runtime_state().split_text_embed_in_sp = False + else: + get_runtime_state().split_text_embed_in_sp = True + + hidden_states = torch.chunk(hidden_states, + get_classifier_free_guidance_world_size(), + dim=0)[get_classifier_free_guidance_rank()] + hidden_states = torch.chunk(hidden_states, + get_sequence_parallel_world_size(), + dim=-2)[get_sequence_parallel_rank()] + encoder_hidden_states = torch.chunk( + encoder_hidden_states, + get_classifier_free_guidance_world_size(), + dim=0)[get_classifier_free_guidance_rank()] + if get_runtime_state().split_text_embed_in_sp: + encoder_hidden_states = torch.chunk( + encoder_hidden_states, + get_sequence_parallel_world_size(), + dim=-2)[get_sequence_parallel_rank()] + + freqs_cos, freqs_sin = image_rotary_emb + + def get_rotary_emb_chunk(freqs): + freqs = torch.chunk(freqs, + get_sequence_parallel_world_size(), + dim=0)[get_sequence_parallel_rank()] + return freqs + + freqs_cos = get_rotary_emb_chunk(freqs_cos) + freqs_sin = get_rotary_emb_chunk(freqs_sin) + image_rotary_emb = (freqs_cos, freqs_sin) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} + + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + None, + image_rotary_emb, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + None, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, None, + image_rotary_emb) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, None, + image_rotary_emb) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) + + hidden_states = hidden_states.reshape(batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, -1, p_t, p, p) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (hidden_states, ) + + return Transformer2DModelOutput(sample=hidden_states) + + new_forward = new_forward.__get__(transformer) + transformer.forward = new_forward + + for block in transformer.transformer_blocks + transformer.single_transformer_blocks: + block.attn.processor = xFuserHunyuanVideoAttnProcessor2_0() + + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + + engine_config, input_config = engine_args.create_config() + local_rank = get_world_group().local_rank + + assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion." + assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for HunyuanVideo" + + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + subfolder="transformer", + torch_dtype=torch.bfloat16, + revision="refs/pr/18", + ) + pipe = HunyuanVideoPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + transformer=transformer, + torch_dtype=torch.float16, + revision="refs/pr/18", + ) + + if args.enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} sequential CPU offload enabled") + elif args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} model CPU offload enabled") + else: + device = torch.device(f"cuda:{local_rank}") + pipe = pipe.to(device) + + if args.enable_tiling: + pipe.vae.enable_tiling( + # Make it runnable on GPUs with 48GB memory + tile_sample_min_height=128, + tile_sample_stride_height=96, + tile_sample_min_width=128, + tile_sample_stride_width=96, + tile_sample_min_num_frames=32, + tile_sample_stride_num_frames=24, + ) + + if args.enable_slicing: + pipe.vae.enable_slicing() + + parameter_peak_memory = torch.cuda.max_memory_allocated( + device=f"cuda:{local_rank}") + + initialize_runtime_state(pipe, engine_config) + get_runtime_state().set_video_input_parameters( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + batch_size=1, + num_inference_steps=input_config.num_inference_steps, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, + ) + + parallelize_transformer(pipe) + + if engine_config.runtime_config.use_torch_compile: + torch._inductor.config.reorder_for_compute_comm_overlap = True + pipe.transformer = torch.compile(pipe.transformer, + mode="max-autotune-no-cudagraphs") + + # one step to warmup the torch compiler + output = pipe( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + prompt=input_config.prompt, + num_inference_steps=1, + generator=torch.Generator(device="cuda").manual_seed( + input_config.seed), + ).frames[0] + + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + + output = pipe( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + prompt=input_config.prompt, + num_inference_steps=input_config.num_inference_steps, + generator=torch.Generator(device="cuda").manual_seed( + input_config.seed), + ).frames[0] + + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + parallel_info = ( + f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" + f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" + f"tp{engine_args.tensor_parallel_degree}_" + f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" + ) + if is_dp_last_group(): + resolution = f"{input_config.width}x{input_config.height}" + output_filename = f"results/hunyuan_video_{parallel_info}_{resolution}.mp4" + export_to_video(output, output_filename, fps=15) + print(f"output saved to {output_filename}") + + if get_world_group().rank == get_world_group().world_size - 1: + print( + f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9} GB" + ) + get_runtime_state().destory_distributed_env() + + +# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling +if __name__ == "__main__": + main() diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py index cfb4c79..7de0606 100644 --- a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -4,7 +4,7 @@ import numpy as np import torch -from diffusers import DiffusionPipeline, CogVideoXPipeline +from diffusers import DiffusionPipeline import torch.distributed from xfuser.config.config import ( @@ -103,7 +103,7 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): pipeline=pipeline, parallel_config=config.parallel_config ) self.cogvideox = False - if isinstance(pipeline, CogVideoXPipeline): + if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")): self._set_cogvideox_parameters( vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 9869a7e..31c3fff 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -15,6 +15,11 @@ CogVideoXAttnProcessor2_0 ) +try: + from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoAttnProcessor2_0 +except ImportError: + HunyuanVideoAttnProcessor2_0 = None + from diffusers.models.embeddings import apply_rotary_emb from xfuser.core.distributed import ( @@ -1143,3 +1148,200 @@ def __call__( [text_seq_length, latent_seq_length], dim=1 ) return hidden_states, encoder_hidden_states + + +if HunyuanVideoAttnProcessor2_0 is not None: + @xFuserAttentionProcessorRegister.register(HunyuanVideoAttnProcessor2_0) + class xFuserHunyuanVideoAttnProcessor2_0(HunyuanVideoAttnProcessor2_0): + def __init__(self): + super().__init__() + use_long_ctx_attn_kvcache = True + self.use_long_ctx_attn_kvcache = ( + HAS_LONG_CTX_ATTN + and use_long_ctx_attn_kvcache + and get_sequence_parallel_world_size() > 1 + ) + if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: + from xfuser.core.long_ctx_attention import ( + xFuserLongContextAttention, + xFuserUlyssesAttention, + ) + + if HAS_FLASH_ATTN: + self.hybrid_seq_parallel_attn = xFuserLongContextAttention( + use_kv_cache=self.use_long_ctx_attn_kvcache + ) + else: + self.hybrid_seq_parallel_attn = xFuserUlyssesAttention( + use_fa=False, + use_kv_cache=self.use_long_ctx_attn_kvcache, + ) + else: + self.hybrid_seq_parallel_attn = None + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + if encoder_hidden_states is not None: + num_encoder_hidden_states_tokens = encoder_hidden_states.shape[1] + num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens + else: + num_encoder_hidden_states_tokens = ( + get_runtime_state().max_condition_sequence_length + ) + num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens + + #! ---------------------------------------- ATTENTION ---------------------------------------- + if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp: + hidden_states = USP( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: + if get_runtime_state().split_text_embed_in_sp: + encoder_query = None + encoder_key = None + encoder_value = None + else: + query, encoder_query = query.split( + [num_query_tokens, num_encoder_hidden_states_tokens], dim=2 + ) + key, encoder_key = key.split( + [num_query_tokens, num_encoder_hidden_states_tokens], dim=2 + ) + value, encoder_value = value.split( + [num_query_tokens, num_encoder_hidden_states_tokens], dim=2 + ) + + encoder_query = encoder_query.transpose(1, 2) + encoder_key = encoder_key.transpose(1, 2) + encoder_value = encoder_value.transpose(1, 2) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + hidden_states = self.hybrid_seq_parallel_attn( + None, + query, + key, + value, + dropout_p=0.0, + causal=False, + joint_tensor_query=encoder_query, + joint_tensor_key=encoder_key, + joint_tensor_value=encoder_value, + joint_strategy="rear", + ) + + hidden_states = hidden_states.flatten(2, 3) + else: + if HAS_FLASH_ATTN: + from flash_attn import flash_attn_func + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + hidden_states = flash_attn_func( + query, key, value, dropout_p=0.0, causal=False + ) + hidden_states = hidden_states.flatten(2, 3) + + else: + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states +else: + xFuserHunyuanVideoAttnProcessor2_0 = None From 8bc8e8d03b7c812064fac83f2f14d8739c585192 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Thu, 19 Dec 2024 17:40:06 +0800 Subject: [PATCH 2/7] fix --- examples/hunyuan_video_usp_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py index 1b79632..808d0a3 100644 --- a/examples/hunyuan_video_usp_example.py +++ b/examples/hunyuan_video_usp_example.py @@ -1,3 +1,4 @@ +# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_hunyuan_video.py import functools from typing import Any, Dict, Union import logging From 88102b9f132d7fb9dbdb8661fe440be7efbc02c0 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Fri, 20 Dec 2024 11:20:10 +0800 Subject: [PATCH 3/7] make hunyuan video work with --enable_model_cpu_offload --- examples/hunyuan_video_usp_example.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py index 808d0a3..9895a3f 100644 --- a/examples/hunyuan_video_usp_example.py +++ b/examples/hunyuan_video_usp_example.py @@ -207,6 +207,18 @@ def main(): revision="refs/pr/18", ) + initialize_runtime_state(pipe, engine_config) + get_runtime_state().set_video_input_parameters( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + batch_size=1, + num_inference_steps=input_config.num_inference_steps, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, + ) + + parallelize_transformer(pipe) + if args.enable_sequential_cpu_offload: pipe.enable_sequential_cpu_offload(gpu_id=local_rank) logging.info(f"rank {local_rank} sequential CPU offload enabled") @@ -234,18 +246,6 @@ def main(): parameter_peak_memory = torch.cuda.max_memory_allocated( device=f"cuda:{local_rank}") - initialize_runtime_state(pipe, engine_config) - get_runtime_state().set_video_input_parameters( - height=input_config.height, - width=input_config.width, - num_frames=input_config.num_frames, - batch_size=1, - num_inference_steps=input_config.num_inference_steps, - split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, - ) - - parallelize_transformer(pipe) - if engine_config.runtime_config.use_torch_compile: torch._inductor.config.reorder_for_compute_comm_overlap = True pipe.transformer = torch.compile(pipe.transformer, @@ -299,5 +299,6 @@ def main(): # mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling +# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 544 --width 960 --num_frames 129 --enable_tiling --enable_model_cpu_offload if __name__ == "__main__": main() From dea542ee77fc90306ad2a20d16f9111f1829be59 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Fri, 20 Dec 2024 11:39:38 +0800 Subject: [PATCH 4/7] prefer enable_model_cpu_offload for hunyuan video --- examples/hunyuan_video_usp_example.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py index 9895a3f..b04da1c 100644 --- a/examples/hunyuan_video_usp_example.py +++ b/examples/hunyuan_video_usp_example.py @@ -232,12 +232,12 @@ def main(): if args.enable_tiling: pipe.vae.enable_tiling( # Make it runnable on GPUs with 48GB memory - tile_sample_min_height=128, - tile_sample_stride_height=96, - tile_sample_min_width=128, - tile_sample_stride_width=96, - tile_sample_min_num_frames=32, - tile_sample_stride_num_frames=24, + # tile_sample_min_height=128, + # tile_sample_stride_height=96, + # tile_sample_min_width=128, + # tile_sample_stride_width=96, + # tile_sample_min_num_frames=32, + # tile_sample_stride_num_frames=24, ) if args.enable_slicing: @@ -298,7 +298,7 @@ def main(): get_runtime_state().destory_distributed_env() -# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling +# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling --enable_model_cpu_offload # mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 544 --width 960 --num_frames 129 --enable_tiling --enable_model_cpu_offload if __name__ == "__main__": main() From aedab06b2c11cc72b19e364bca797009039f872a Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Fri, 20 Dec 2024 13:20:45 +0800 Subject: [PATCH 5/7] scatter among height rather than time --- examples/hunyuan_video_usp_example.py | 34 ++++++++++++++++----------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py index b04da1c..a276a38 100644 --- a/examples/hunyuan_video_usp_example.py +++ b/examples/hunyuan_video_usp_example.py @@ -67,6 +67,15 @@ def new_forward( timestep, encoder_attention_mask) + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1) + hidden_states = torch.chunk(hidden_states, + get_classifier_free_guidance_world_size(), + dim=0)[get_classifier_free_guidance_rank()] + hidden_states = torch.chunk(hidden_states, + get_sequence_parallel_world_size(), + dim=2)[get_sequence_parallel_rank()] + hidden_states = hidden_states.flatten(1, 3) + encoder_attention_mask = encoder_attention_mask[0].to(torch.bool) encoder_hidden_states_indices = torch.arange( encoder_hidden_states.shape[1], @@ -81,12 +90,6 @@ def new_forward( else: get_runtime_state().split_text_embed_in_sp = True - hidden_states = torch.chunk(hidden_states, - get_classifier_free_guidance_world_size(), - dim=0)[get_classifier_free_guidance_rank()] - hidden_states = torch.chunk(hidden_states, - get_sequence_parallel_world_size(), - dim=-2)[get_sequence_parallel_rank()] encoder_hidden_states = torch.chunk( encoder_hidden_states, get_classifier_free_guidance_world_size(), @@ -100,9 +103,11 @@ def new_forward( freqs_cos, freqs_sin = image_rotary_emb def get_rotary_emb_chunk(freqs): - freqs = torch.chunk(freqs, - get_sequence_parallel_world_size(), - dim=0)[get_sequence_parallel_rank()] + dim_thw = freqs.shape[-1] + freqs = freqs.reshape(num_frames, -1, dim_thw) + freqs = freqs.chunk(get_sequence_parallel_world_size(), dim=-2)[ + get_sequence_parallel_rank()] + freqs = freqs.reshape(-1, dim_thw) return freqs freqs_cos = get_rotary_emb_chunk(freqs_cos) @@ -161,13 +166,14 @@ def custom_forward(*inputs): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) - hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) - - hidden_states = hidden_states.reshape(batch_size, + hidden_states = hidden_states.reshape(batch_size // get_classifier_free_guidance_world_size(), post_patch_num_frames, - post_patch_height, + post_patch_height // get_sequence_parallel_world_size(), post_patch_width, -1, p_t, p, p) + + hidden_states = get_sp_group().all_gather(hidden_states, dim=2) + hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) From f91943e4a65458386afb38d85fd528c96d7e9770 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Fri, 20 Dec 2024 17:30:16 +0800 Subject: [PATCH 6/7] update hunyuanvideo performance on single L20 --- docs/performance/hunyuanvideo.md | 3 ++- examples/run_hunyuan_video_usp.sh | 43 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100755 examples/run_hunyuan_video_usp.sh diff --git a/docs/performance/hunyuanvideo.md b/docs/performance/hunyuanvideo.md index 0832bba..e214923 100644 --- a/docs/performance/hunyuanvideo.md +++ b/docs/performance/hunyuanvideo.md @@ -10,6 +10,7 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil |----------|--------|---------|---------|---------| | H100 | 1,904.08 | 925.04 | 514.08 | 337.58 | | H20 | 6,639.17 | 3,400.55 | 1,762.86 | 940.97 | +| L20 | 6,043.88 | | | | @@ -22,4 +23,4 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil | H100 | 1,735.01 | 934.09 | 645.45 | 367.02 | | H20 | 6,621.46 | 3,400.55 | 2,310.48 | 1,214.67 | - \ No newline at end of file + diff --git a/examples/run_hunyuan_video_usp.sh b/examples/run_hunyuan_video_usp.sh new file mode 100755 index 0000000..c1f8813 --- /dev/null +++ b/examples/run_hunyuan_video_usp.sh @@ -0,0 +1,43 @@ +#!/bin/bash +set -x + +export PYTHONPATH=$PWD:$PYTHONPATH + +# CogVideoX configuration +SCRIPT="hunyuan_video_usp_example.py" +MODEL_ID="/cfs/dit/HunyuanVideo" +# MODEL_ID="tencent/HunyuanVideo" +INFERENCE_STEP=50 + +mkdir -p ./results + +# CogVideoX specific task args +TASK_ARGS="--height 720 --width 1280 --num_frames 129" + +# CogVideoX parallel configuration +N_GPUS=8 +PARALLEL_ARGS="--ulysses_degree 4 --ring_degree 2" +# CFG_ARGS="--use_cfg_parallel" + +# Uncomment and modify these as needed +# PIPEFUSION_ARGS="--num_pipeline_patch 8" +# OUTPUT_ARGS="--output_type latent" +# PARALLLEL_VAE="--use_parallel_vae" +ENABLE_TILING="--enable_tiling" +ENABLE_MODEL_CPU_OFFLOAD="--enable_model_cpu_offload" +# COMPILE_FLAG="--use_torch_compile" + +torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--warmup_steps 0 \ +--prompt "A cat walks on the grass, realistic" \ +$CFG_ARGS \ +$PARALLLEL_VAE \ +$ENABLE_TILING \ +$ENABLE_MODEL_CPU_OFFLOAD \ +$COMPILE_FLAG From bc18d27864e7bcfc3a12913f0a73d097da84d7a3 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 23 Dec 2024 17:14:56 +0800 Subject: [PATCH 7/7] make hunyuan video work with more resolutions and update the performance table --- docs/performance/hunyuanvideo.md | 3 +- examples/hunyuan_video_usp_example.py | 42 ++++++++++++++++-------- xfuser/core/distributed/runtime_state.py | 9 ++++- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/docs/performance/hunyuanvideo.md b/docs/performance/hunyuanvideo.md index e214923..0e04d27 100644 --- a/docs/performance/hunyuanvideo.md +++ b/docs/performance/hunyuanvideo.md @@ -10,7 +10,7 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil |----------|--------|---------|---------|---------| | H100 | 1,904.08 | 925.04 | 514.08 | 337.58 | | H20 | 6,639.17 | 3,400.55 | 1,762.86 | 940.97 | -| L20 | 6,043.88 | | | | +| L20 | 6,043.88 | 3,271.44 | 2,080.05 | | @@ -22,5 +22,6 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil |----------|--------|---------|---------|---------| | H100 | 1,735.01 | 934.09 | 645.45 | 367.02 | | H20 | 6,621.46 | 3,400.55 | 2,310.48 | 1,214.67 | +| L20 | 6,039.08 | 3,260.62 | 2,070.96 | | diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py index a276a38..0360bf2 100644 --- a/examples/hunyuan_video_usp_example.py +++ b/examples/hunyuan_video_usp_example.py @@ -1,6 +1,6 @@ # from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_hunyuan_video.py import functools -from typing import Any, Dict, Union +from typing import Any, Dict, Union, Optional import logging import time @@ -8,6 +8,7 @@ from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import scale_lora_layers, unscale_lora_layers, USE_PEFT_BACKEND from diffusers.utils import export_to_video from xfuser import xFuserArgs @@ -45,8 +46,22 @@ def new_forward( encoder_attention_mask: torch.Tensor, pooled_projections: torch.Tensor, guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logging.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.") + batch_size, num_channels, num_frames, height, width = hidden_states.shape assert batch_size % get_classifier_free_guidance_world_size( @@ -68,13 +83,14 @@ def new_forward( encoder_attention_mask) hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1) + hidden_states = hidden_states.flatten(1, 3) + hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()] hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), - dim=2)[get_sequence_parallel_rank()] - hidden_states = hidden_states.flatten(1, 3) + dim=-2)[get_sequence_parallel_rank()] encoder_attention_mask = encoder_attention_mask[0].to(torch.bool) encoder_hidden_states_indices = torch.arange( @@ -103,11 +119,7 @@ def new_forward( freqs_cos, freqs_sin = image_rotary_emb def get_rotary_emb_chunk(freqs): - dim_thw = freqs.shape[-1] - freqs = freqs.reshape(num_frames, -1, dim_thw) - freqs = freqs.chunk(get_sequence_parallel_world_size(), dim=-2)[ - get_sequence_parallel_rank()] - freqs = freqs.reshape(-1, dim_thw) + freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] return freqs freqs_cos = get_rotary_emb_chunk(freqs_cos) @@ -166,17 +178,21 @@ def custom_forward(*inputs): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size // get_classifier_free_guidance_world_size(), + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) + + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, - post_patch_height // get_sequence_parallel_world_size(), + post_patch_height, post_patch_width, -1, p_t, p, p) - hidden_states = get_sp_group().all_gather(hidden_states, dim=2) - hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) - hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (hidden_states, ) diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py index 7de0606..dd8ec07 100644 --- a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -103,7 +103,12 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): pipeline=pipeline, parallel_config=config.parallel_config ) self.cogvideox = False + self.hunyuan_video = False if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")): + if pipeline.__class__.__name__.startswith("CogVideoX"): + self.cogvideox = True + else: + self.hunyuan_video = True self._set_cogvideox_parameters( vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, @@ -194,7 +199,6 @@ def _set_cogvideox_parameters( self.backbone_patch_size = backbone_patch_size self.backbone_inner_dim = backbone_inner_dim self.backbone_in_channel = backbone_in_channel - self.cogvideox = True def set_patched_mode(self, patch_mode: bool): self.patch_mode = patch_mode @@ -259,6 +263,9 @@ def _video_input_size_change( self.input_config.batch_size = batch_size or self.input_config.batch_size if self.cogvideox: self._calc_cogvideox_patches_metadata() + elif self.hunyuan_video: + # TODO: implement the hunyuan video patches metadata + pass else: self._calc_patches_metadata() self._reset_recv_buffer()