
Adding CogVideoX into xDiT #211

Merged
merged 30 commits into from Aug 30, 2024
Changes from 25 commits
Commits
30 commits
d360e07
adding cogvideox; debugging
dannyxiaocn Aug 21, 2024
90e93fb
Merge branch 'main' into cogvideox
dannyxiaocn Aug 21, 2024
62ed758
debugging nccl
dannyxiaocn Aug 21, 2024
189c9ef
debugging
dannyxiaocn Aug 21, 2024
217642b
impl cogvideox patch embed, but still not fix problem
dannyxiaocn Aug 22, 2024
4421282
runnable, but oom
dannyxiaocn Aug 23, 2024
e0ef26f
just for run on A100, apply_rope is problem for new diffusers
dannyxiaocn Aug 23, 2024
20f6448
minor fix
dannyxiaocn Aug 23, 2024
c2a1e30
Merge branch 'xdit-project:main' into cogvideox
dannyxiaocn Aug 26, 2024
c7ec0f7
fix cpu offload problem but super slow
dannyxiaocn Aug 26, 2024
994421c
Merge branch 'cogvideox' of https://github.com/dannyxiaocn/xDiT into …
dannyxiaocn Aug 26, 2024
1592e93
fixing
dannyxiaocn Aug 28, 2024
8a8a84f
fixing
dannyxiaocn Aug 28, 2024
799839e
Merge branch 'xdit-project:main' into cogvideox
dannyxiaocn Aug 29, 2024
e695541
runnable in performance(memory & speed), but still not correct for re…
dannyxiaocn Aug 29, 2024
a320e4a
pos_embed parallelized, but still bugs in final generated videos
dannyxiaocn Aug 30, 2024
510d604
Merge branch 'xdit-project:main' into cogvideox
dannyxiaocn Aug 30, 2024
f11a4fb
fixing conflict in framework
dannyxiaocn Aug 30, 2024
68ec178
Delete cogvideox.sh
dannyxiaocn Aug 30, 2024
150a684
minor fix
dannyxiaocn Aug 30, 2024
8d01fe5
Merge branch 'cogvideox' of https://github.com/dannyxiaocn/xDiT into …
dannyxiaocn Aug 30, 2024
47751b4
minor fix
dannyxiaocn Aug 30, 2024
38a0b41
readme
dannyxiaocn Aug 30, 2024
f3eccdd
update diffusers version
dannyxiaocn Aug 30, 2024
564cd95
version fix
dannyxiaocn Aug 30, 2024
f8130aa
example update
dannyxiaocn Aug 30, 2024
86d6f1c
Merge branch 'cogvideox' of https://github.com/dannyxiaocn/xDiT into …
dannyxiaocn Aug 30, 2024
4aaad00
example update
dannyxiaocn Aug 30, 2024
2ea56b4
fixing
dannyxiaocn Aug 30, 2024
b18a0df
fixing
dannyxiaocn Aug 30, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ profile/
xfuser.egg-info/
dist/*
latte_output.mp4
latte.sh
latte.sh
cogvideox.sh
2 changes: 2 additions & 0 deletions README.md
@@ -80,6 +80,7 @@ The overview of xDiT is shown as follows.

<h2 id="updates">📢 Updates</h2>

* 🎉**August 30, 2024**: Support CogVideoX. The inference scripts are [examples/cogvideox_example](examples/cogvideox_example.py).
* 🎉**August 26, 2024**: We apply torch.compile and the [onediff](https://github.com/siliconflow/onediff) nexfort backend to accelerate GPU kernel speed.
* 🎉**August 9, 2024**: Support Latte sequence parallel version. The inference scripts are [examples/latte_example](examples/latte_example.py).
* 🎉**August 8, 2024**: Support Flux sequence parallel version. The inference scripts are [examples/flux_example](examples/flux_example.py).
@@ -97,6 +98,7 @@ The overview of xDiT is shown as follows.

| Model Name | CFG | SP | PipeFusion |
| --- | --- | --- | --- |
| [🎬 CogVideoX](https://huggingface.co/THUDM/CogVideoX-2b) | ❎ | ❎ | ❎ |
| [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) | ❎ | ✔️ | ❎ |
| [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ |
| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ❎ |
Empty file modified docs/developer/The_implement_design_of_xdit_framework.md
100755 → 100644
Empty file.
Empty file modified docs/developer/The_implement_design_of_xdit_framework_zh.md
100755 → 100644
Empty file.
65 changes: 65 additions & 0 deletions examples/cogvideox_example.py
@@ -0,0 +1,65 @@
import time
import torch
import torch.distributed
from xfuser import xFuserCogVideoXPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_runtime_state,
)
from diffusers.utils import export_to_video


def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank


pipe = xFuserCogVideoXPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.float16,
).to(f"cuda:{local_rank}")


# Offload weights to CPU between stages: slower, but avoids OOM
# (see commit c7ec0f7, "fix cpu offload problem but super slow")
pipe.enable_model_cpu_offload(gpu_id=local_rank)
pipe.vae.enable_tiling()

torch.cuda.reset_peak_memory_stats()
start_time = time.time()


output = pipe(
height=input_config.height,
width=input_config.width,
num_frames=49,
prompt=input_config.prompt,
num_inference_steps=input_config.num_inference_steps,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
guidance_scale=6,
).frames[0]


end_time = time.time()
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

if get_data_parallel_rank() == get_data_parallel_world_size() - 1:
export_to_video(output, "results/output.mp4", fps=8)

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
)
get_runtime_state().destory_distributed_env()


if __name__ == '__main__':
main()
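
For orientation, the sequence length that xDiT must parallelize here is fixed by the VAE compression factors and the transformer patch size. A minimal worked sketch of that arithmetic, assuming CogVideoX-2b's published defaults (spatial compression 8, temporal compression 4, patch size 2, 480x720 output) — at runtime these values are read from the model configs, not hard-coded:

# Hypothetical worked example (not part of the PR): latent/token geometry
# for the settings above, assuming CogVideoX-2b config values.
height, width, num_frames = 480, 720, 49
vae_spatial, vae_temporal, patch_size = 8, 4, 2  # assumed from the model config

latent_h = height // vae_spatial                 # 60
latent_w = width // vae_spatial                  # 90
latent_f = (num_frames - 1) // vae_temporal + 1  # 13 latent frames

tokens = latent_f * (latent_h // patch_size) * (latent_w // patch_size)
print(tokens)  # 17550 tokens, split across sequence/pipeline parallel ranks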
Empty file modified legacy/scripts/benchmark.sh
100755 → 100644
Empty file.
2 changes: 2 additions & 0 deletions xfuser/__init__.py
@@ -5,6 +5,7 @@
xFuserFluxPipeline,
xFuserLattePipeline,
xFuserHunyuanDiTPipeline,
xFuserCogVideoXPipeline,
)
from xfuser.config import xFuserArgs, EngineConfig

@@ -15,6 +16,7 @@
"xFuserFluxPipeline",
"xFuserLattePipeline",
"xFuserHunyuanDiTPipeline",
"xFuserCogVideoXPipeline",
"xFuserArgs",
"EngineConfig",
]
2 changes: 1 addition & 1 deletion xfuser/config/config.py
@@ -216,7 +216,7 @@ def to_dict(self):
class InputConfig:
height: int = 1024
width: int = 1024
video_length: int = 16
num_frames: int = 16
use_resolution_binning: bool = True
batch_size: Optional[int] = None
prompt: Union[str, List[str]] = ""
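
The `video_length` field is renamed to `num_frames`, matching the diffusers CogVideoX API. A hypothetical usage sketch (assuming `InputConfig` is a plain dataclass, as its field declarations suggest):

from xfuser.config.config import InputConfig  # module path per this diff

cfg = InputConfig(height=480, width=720, num_frames=49)
# Callers that previously read cfg.video_length now read cfg.num_frames;
# they are updated elsewhere in this PR (see set_video_input_parameters below).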
171 changes: 155 additions & 16 deletions xfuser/core/distributed/runtime_state.py
@@ -4,7 +4,7 @@

import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers import DiffusionPipeline, CogVideoXPipeline
import torch.distributed

from xfuser.config.config import (
@@ -83,6 +83,8 @@ class DiTRuntimeState(RuntimeState):
patch_mode: bool
pipeline_patch_idx: int
vae_scale_factor: int
vae_scale_factor_spatial: int
vae_scale_factor_temporal: int
backbone_patch_size: int
pp_patches_height: Optional[List[int]]
pp_patches_start_idx_local: Optional[List[int]]
@@ -102,13 +104,24 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
self._check_model_and_parallel_config(
pipeline=pipeline, parallel_config=config.parallel_config
)
self._set_model_parameters(
vae_scale_factor=pipeline.vae_scale_factor,
backbone_patch_size=pipeline.transformer.config.patch_size,
backbone_in_channel=pipeline.transformer.config.in_channels,
backbone_inner_dim=pipeline.transformer.config.num_attention_heads
* pipeline.transformer.config.attention_head_dim,
)
self.cogvideox = False
if isinstance(pipeline, CogVideoXPipeline):
self._set_cogvideox_parameters(
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
backbone_patch_size=pipeline.transformer.config.patch_size,
backbone_in_channel=pipeline.transformer.config.in_channels,
backbone_inner_dim=pipeline.transformer.config.num_attention_heads
* pipeline.transformer.config.attention_head_dim,
)
else:
self._set_model_parameters(
vae_scale_factor=pipeline.vae_scale_factor,
backbone_patch_size=pipeline.transformer.config.patch_size,
backbone_in_channel=pipeline.transformer.config.in_channels,
backbone_inner_dim=pipeline.transformer.config.num_attention_heads
* pipeline.transformer.config.attention_head_dim,
)
self.pipeline_comm_extra_tensors_info = []

def set_input_parameters(
@@ -140,7 +153,7 @@ def set_video_input_parameters(
self,
height: Optional[int] = None,
width: Optional[int] = None,
video_length: Optional[int] = None,
num_frames: Optional[int] = None,
batch_size: Optional[int] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
@@ -156,12 +169,27 @@
if (
(height and self.input_config.height != height)
or (width and self.input_config.width != width)
or (video_length and self.input_config.video_length != video_length)
or (num_frames and self.input_config.num_frames != num_frames)
or (batch_size and self.input_config.batch_size != batch_size)
):
self._video_input_size_change(height, width, video_length, batch_size)
self._video_input_size_change(height, width, num_frames, batch_size)

self.ready = True

def _set_cogvideox_parameters(
self,
vae_scale_factor_spatial: int,
vae_scale_factor_temporal: int,
backbone_patch_size: int,
backbone_inner_dim: int,
backbone_in_channel: int,
):
self.vae_scale_factor_spatial = vae_scale_factor_spatial
self.vae_scale_factor_temporal = vae_scale_factor_temporal
self.backbone_patch_size = backbone_patch_size
self.backbone_inner_dim = backbone_inner_dim
self.backbone_in_channel = backbone_in_channel
self.cogvideox = True

def set_patched_mode(self, patch_mode: bool):
self.patch_mode = patch_mode
@@ -217,16 +245,19 @@ def _video_input_size_change(
self,
height: Optional[int] = None,
width: Optional[int] = None,
video_length: Optional[int] = None,
num_frames: Optional[int] = None,
batch_size: Optional[int] = None,
):
self.input_config.height = height or self.input_config.height
self.input_config.width = width or self.input_config.width
self.input_config.video_length = video_length or self.input_config.video_length
self.input_config.num_frames = num_frames or self.input_config.num_frames
self.input_config.batch_size = batch_size or self.input_config.batch_size
self._calc_patches_metadata()
if self.cogvideox:
self._calc_cogvideox_patches_metadata()
else:
self._calc_patches_metadata()
self._reset_recv_buffer()

def _calc_patches_metadata(self):
num_sp_patches = get_sequence_parallel_world_size()
sp_patch_idx = get_sequence_parallel_rank()
@@ -332,6 +363,114 @@ def _calc_patches_metadata(self):
pp_patches_token_start_end_idx_global
)
self.pp_patches_token_num = pp_patches_token_num

def _calc_cogvideox_patches_metadata(self):

num_sp_patches = get_sequence_parallel_world_size()
sp_patch_idx = get_sequence_parallel_rank()
patch_size = self.backbone_patch_size
vae_scale_factor_spatial = self.vae_scale_factor_spatial
latents_height = self.input_config.height // vae_scale_factor_spatial
latents_width = self.input_config.width // vae_scale_factor_spatial
latents_frames = (self.input_config.num_frames - 1) // self.vae_scale_factor_temporal + 1

if latents_height % num_sp_patches != 0:
raise ValueError(
"The height of the input is not divisible by the number of sequence parallel devices"
)

self.num_pipeline_patch = self.parallel_config.pp_config.num_pipeline_patch
# Pipeline patches
pipeline_patches_height = (
latents_height + self.num_pipeline_patch - 1
) // self.num_pipeline_patch
# make sure pipeline_patches_height is a multiple of (num_sp_patches * patch_size)
pipeline_patches_height = (
(pipeline_patches_height + (num_sp_patches * patch_size) - 1)
// (patch_size * num_sp_patches)
) * (patch_size * num_sp_patches)
# get the number of pipeline patches that matches the patch height requirement
num_pipeline_patch = (
latents_height + pipeline_patches_height - 1
) // pipeline_patches_height
if num_pipeline_patch != self.num_pipeline_patch:
logger.warning(
f"Pipeline patches num changed from "
f"{self.num_pipeline_patch} to {num_pipeline_patch} due "
f"to input size and parallelisation requirements"
)
pipeline_patches_height_list = [
pipeline_patches_height for _ in range(num_pipeline_patch - 1)
]
the_last_pp_patch_height = latents_height - pipeline_patches_height * (
num_pipeline_patch - 1
)
if the_last_pp_patch_height % (patch_size * num_sp_patches) != 0:
raise ValueError(
f"The height of the last pipeline patch is {the_last_pp_patch_height}, "
f"which is not a multiple of (patch_size * num_sp_patches): "
f"{patch_size} * {num_sp_patches}. Please try to adjust 'num_pipeline_patches "
f"or sp_degree argument so that the condition are met "
)
pipeline_patches_height_list.append(the_last_pp_patch_height)

# Sequence parallel patches
# len: sp_degree * num_pipeline_patches
flatten_patches_height = [
pp_patch_height // num_sp_patches
for _ in range(num_sp_patches)
for pp_patch_height in pipeline_patches_height_list
]
flatten_patches_start_idx = [0] + [
sum(flatten_patches_height[:i])
for i in range(1, len(flatten_patches_height) + 1)
]
pp_sp_patches_height = [
flatten_patches_height[
pp_patch_idx * num_sp_patches : (pp_patch_idx + 1) * num_sp_patches
]
for pp_patch_idx in range(num_pipeline_patch)
]
pp_sp_patches_start_idx = [
flatten_patches_start_idx[
pp_patch_idx * num_sp_patches : (pp_patch_idx + 1) * num_sp_patches + 1
]
for pp_patch_idx in range(num_pipeline_patch)
]

pp_patches_height = [
sp_patches_height[sp_patch_idx]
for sp_patches_height in pp_sp_patches_height
]
pp_patches_start_idx_local = [0] + [
sum(pp_patches_height[:i]) for i in range(1, len(pp_patches_height) + 1)
]
pp_patches_start_end_idx_global = [
sp_patches_start_idx[sp_patch_idx : sp_patch_idx + 2]
for sp_patches_start_idx in pp_sp_patches_start_idx
]
pp_patches_token_start_end_idx_global = [
[
(latents_width // patch_size) * (start_idx // patch_size) * latents_frames,
(latents_width // patch_size) * (end_idx // patch_size) * latents_frames,
]
for start_idx, end_idx in pp_patches_start_end_idx_global
]
pp_patches_token_num = [
end - start for start, end in pp_patches_token_start_end_idx_global
]
pp_patches_token_start_idx_local = [
sum(pp_patches_token_num[:i]) for i in range(len(pp_patches_token_num) + 1)
]
self.num_pipeline_patch = num_pipeline_patch
self.pp_patches_height = pp_patches_height
self.pp_patches_start_idx_local = pp_patches_start_idx_local
self.pp_patches_start_end_idx_global = pp_patches_start_end_idx_global
self.pp_patches_token_start_idx_local = pp_patches_token_start_idx_local
self.pp_patches_token_start_end_idx_global = (
pp_patches_token_start_end_idx_global
)
self.pp_patches_token_num = pp_patches_token_num

def _reset_recv_buffer(self):
get_pp_group().reset_buffer()
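
To make the height partitioning in _calc_cogvideox_patches_metadata concrete, here is a standalone sketch of its rounding logic with assumed values (latents_height 60 as for 480p CogVideoX, patch size 2, 2 sequence-parallel ranks, 2 requested pipeline patches); it mirrors the code above but is illustration only:

latents_height = 60      # 480 // 8, assumed CogVideoX-2b geometry
patch_size = 2
num_sp_patches = 2       # sequence-parallel world size
num_pipeline_patch = 2   # requested pipeline patches

# Even split, rounded up...
pp_h = (latents_height + num_pipeline_patch - 1) // num_pipeline_patch  # 30
# ...then rounded up again to a multiple of patch_size * num_sp_patches
step = patch_size * num_sp_patches                                      # 4
pp_h = ((pp_h + step - 1) // step) * step                               # 32

# The number of pipeline patches may shrink after rounding (warned above)
num_pp = (latents_height + pp_h - 1) // pp_h                            # 2
last_h = latents_height - pp_h * (num_pp - 1)                           # 28
assert last_h % step == 0  # otherwise the code raises ValueError
print([pp_h] * (num_pp - 1) + [last_h])  # [32, 28]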
@@ -382,4 +521,4 @@ def initialize_runtime_state(pipeline: DiffusionPipeline, engine_config: EngineC
"Runtime state is already initialized, reinitializing with pipeline..."
)
if hasattr(pipeline, "transformer"):
_RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config)
_RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config)