Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add hunyuan_video_usp_example.py #401

Merged
merged 7 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 303 additions & 0 deletions examples/hunyuan_video_usp_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_hunyuan_video.py
import functools
from typing import Any, Dict, Union
import logging
import time

import torch

from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import export_to_video

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_world_size,
get_data_parallel_rank,
get_runtime_state,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_cfg_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
is_dp_last_group,
initialize_runtime_state,
get_pipeline_parallel_world_size,
)

from xfuser.model_executor.layers.attention_processor import xFuserHunyuanVideoAttnProcessor2_0

assert xFuserHunyuanVideoAttnProcessor2_0 is not None


def parallelize_transformer(pipe: DiffusionPipeline):
transformer = pipe.transformer

@functools.wraps(transformer.__class__.forward)
def new_forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
guidance: torch.Tensor = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape

assert batch_size % get_classifier_free_guidance_world_size(
) == 0, f"Cannot split dim 0 of hidden_states ({batch_size}) into {get_classifier_free_guidance_world_size()} parts."

p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p

# 1. RoPE
image_rotary_emb = self.rope(hidden_states)

# 2. Conditional embeddings
temb = self.time_text_embed(timestep, guidance, pooled_projections)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states,
timestep,
encoder_attention_mask)

encoder_attention_mask = encoder_attention_mask[0].to(torch.bool)
encoder_hidden_states_indices = torch.arange(
encoder_hidden_states.shape[1],
device=encoder_hidden_states.device)
encoder_hidden_states_indices = encoder_hidden_states_indices[
encoder_attention_mask]
encoder_hidden_states = encoder_hidden_states[
..., encoder_hidden_states_indices, :]
if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size(
) != 0:
get_runtime_state().split_text_embed_in_sp = False
else:
get_runtime_state().split_text_embed_in_sp = True

hidden_states = torch.chunk(hidden_states,
get_classifier_free_guidance_world_size(),
dim=0)[get_classifier_free_guidance_rank()]
hidden_states = torch.chunk(hidden_states,
get_sequence_parallel_world_size(),
dim=-2)[get_sequence_parallel_rank()]
encoder_hidden_states = torch.chunk(
encoder_hidden_states,
get_classifier_free_guidance_world_size(),
dim=0)[get_classifier_free_guidance_rank()]
if get_runtime_state().split_text_embed_in_sp:
encoder_hidden_states = torch.chunk(
encoder_hidden_states,
get_sequence_parallel_world_size(),
dim=-2)[get_sequence_parallel_rank()]

freqs_cos, freqs_sin = image_rotary_emb

def get_rotary_emb_chunk(freqs):
freqs = torch.chunk(freqs,
get_sequence_parallel_world_size(),
dim=0)[get_sequence_parallel_rank()]
return freqs

freqs_cos = get_rotary_emb_chunk(freqs_cos)
freqs_sin = get_rotary_emb_chunk(freqs_sin)
image_rotary_emb = (freqs_cos, freqs_sin)

# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):

def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}

for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
None,
image_rotary_emb,
**ckpt_kwargs,
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
None,
image_rotary_emb,
**ckpt_kwargs,
)

else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, None,
image_rotary_emb)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, None,
image_rotary_emb)

# 5. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = get_sp_group().all_gather(hidden_states, dim=-2)
hidden_states = get_cfg_group().all_gather(hidden_states, dim=0)

hidden_states = hidden_states.reshape(batch_size,
post_patch_num_frames,
post_patch_height,
post_patch_width, -1, p_t, p, p)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if not return_dict:
return (hidden_states, )

return Transformer2DModelOutput(sample=hidden_states)

new_forward = new_forward.__get__(transformer)
transformer.forward = new_forward

for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
block.attn.processor = xFuserHunyuanVideoAttnProcessor2_0()


def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
feifeibear marked this conversation as resolved.
Show resolved Hide resolved
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)

engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank

assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for HunyuanVideo"

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
subfolder="transformer",
torch_dtype=torch.bfloat16,
revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
transformer=transformer,
torch_dtype=torch.float16,
revision="refs/pr/18",
)

if args.enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
logging.info(f"rank {local_rank} sequential CPU offload enabled")
elif args.enable_model_cpu_offload:
pipe.enable_model_cpu_offload(gpu_id=local_rank)
logging.info(f"rank {local_rank} model CPU offload enabled")
else:
device = torch.device(f"cuda:{local_rank}")
pipe = pipe.to(device)

if args.enable_tiling:
pipe.vae.enable_tiling(
# Make it runnable on GPUs with 48GB memory
tile_sample_min_height=128,
tile_sample_stride_height=96,
tile_sample_min_width=128,
tile_sample_stride_width=96,
tile_sample_min_num_frames=32,
tile_sample_stride_num_frames=24,
)

if args.enable_slicing:
pipe.vae.enable_slicing()

parameter_peak_memory = torch.cuda.max_memory_allocated(
device=f"cuda:{local_rank}")

initialize_runtime_state(pipe, engine_config)
get_runtime_state().set_video_input_parameters(
height=input_config.height,
width=input_config.width,
num_frames=input_config.num_frames,
batch_size=1,
num_inference_steps=input_config.num_inference_steps,
split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
)

parallelize_transformer(pipe)

if engine_config.runtime_config.use_torch_compile:
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(pipe.transformer,
mode="max-autotune-no-cudagraphs")

# one step to warmup the torch compiler
output = pipe(
height=input_config.height,
width=input_config.width,
num_frames=input_config.num_frames,
prompt=input_config.prompt,
num_inference_steps=1,
generator=torch.Generator(device="cuda").manual_seed(
input_config.seed),
).frames[0]

torch.cuda.reset_peak_memory_stats()
start_time = time.time()

output = pipe(
height=input_config.height,
width=input_config.width,
num_frames=input_config.num_frames,
prompt=input_config.prompt,
num_inference_steps=input_config.num_inference_steps,
generator=torch.Generator(device="cuda").manual_seed(
input_config.seed),
).frames[0]

end_time = time.time()
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

parallel_info = (
f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
f"tp{engine_args.tensor_parallel_degree}_"
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if is_dp_last_group():
resolution = f"{input_config.width}x{input_config.height}"
output_filename = f"results/hunyuan_video_{parallel_info}_{resolution}.mp4"
export_to_video(output, output_filename, fps=15)
print(f"output saved to {output_filename}")

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9} GB"
)
get_runtime_state().destory_distributed_env()


# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling
if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions xfuser/core/distributed/runtime_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import torch
from diffusers import DiffusionPipeline, CogVideoXPipeline
from diffusers import DiffusionPipeline
import torch.distributed

from xfuser.config.config import (
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
pipeline=pipeline, parallel_config=config.parallel_config
)
self.cogvideox = False
if isinstance(pipeline, CogVideoXPipeline):
if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")):
self._set_cogvideox_parameters(
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
Expand Down
Loading
Loading