[Fix]: image saving bugs #251

Merged · 1 commit · Sep 3, 2024
36 changes: 17 additions & 19 deletions examples/cogvideox_example.py
@@ -5,10 +5,11 @@
from xfuser import xFuserCogVideoXPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_rank,
get_world_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
)
from diffusers.utils import export_to_video

@@ -19,22 +20,21 @@ def main():
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank


pipe = xFuserCogVideoXPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.bfloat16,
)
if args.enable_sequential_cpu_offload:
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.bfloat16,
)
if args.enable_sequential_cpu_offload:
pipe.enable_model_cpu_offload(gpu_id=local_rank)
pipe.vae.enable_tiling()
else:
else:
pipe = pipe.to(f"cuda:{local_rank}")

torch.cuda.reset_peak_memory_stats()
start_time = time.time()

output = pipe(
height=input_config.height,
width=input_config.width,
@@ -44,20 +44,18 @@ def main():
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
guidance_scale=6,
).frames[0]

end_time = time.time()
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
if get_data_parallel_rank() == get_data_parallel_world_size() - 1:

if is_dp_last_group():
export_to_video(output, "results/output.mp4", fps=8)

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
)
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
get_runtime_state().destory_distributed_env()


if __name__ == '__main__':
main()
if __name__ == "__main__":
main()
8 changes: 3 additions & 5 deletions examples/flux_example.py
@@ -32,7 +32,7 @@ def main():
else:
pipe = pipe.to(f"cuda:{local_rank}")

pipe.prepare_run(input_config, max_sequence_length=256)
pipe.prepare_run(input_config)

torch.cuda.reset_peak_memory_stats()
start_time = time.time()
@@ -57,10 +57,8 @@ def main():
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
for i, image in enumerate(output.images):
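The same index arithmetic recurs in the HunyuanDiT and PixArt examples below. As a hedged sketch of how these per-group quantities are typically used (the file-naming pattern and the loop body are illustrative assumptions; the diff above is cut off before them), assuming `is_dp_last_group()` selects the rank that holds the final images within each data-parallel group:

```python
# Illustrative sketch only: how the corrected indices map each data-parallel
# group's images back to a globally consistent numbering.
from xfuser.core.distributed import (
    get_data_parallel_rank,
    get_data_parallel_world_size,
    is_dp_last_group,
)

def save_dp_batch(output, input_config):
    dp_group_index = get_data_parallel_rank()       # index of this rank's data-parallel group
    num_dp_groups = get_data_parallel_world_size()  # number of data-parallel groups
    # Ceiling division: each group renders at most dp_batch_size prompts of the global batch.
    dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
    if is_dp_last_group():
        for i, image in enumerate(output.images):
            # Hypothetical naming scheme: offset the local index by this group's slice start.
            global_idx = dp_group_index * dp_batch_size + i
            image.save(f"results/image_{global_idx}.png")
```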
7 changes: 3 additions & 4 deletions examples/hunyuandit_example.py
@@ -9,6 +9,7 @@
is_dp_last_group,
get_data_parallel_world_size,
get_runtime_state,
get_data_parallel_rank,
)


@@ -46,10 +47,8 @@ def main():
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if not os.path.exists("results"):
3 changes: 2 additions & 1 deletion examples/latte_example.py
@@ -9,6 +9,7 @@
get_data_parallel_rank,
get_data_parallel_world_size,
get_runtime_state,
is_dp_last_group,
)
import imageio

@@ -53,7 +54,7 @@ def main():
f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if get_data_parallel_rank() == get_data_parallel_world_size() - 1:
if is_dp_last_group():
videos = output.frames.cpu()
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
7 changes: 3 additions & 4 deletions examples/pixartalpha_example.py
@@ -9,6 +9,7 @@
is_dp_last_group,
get_data_parallel_world_size,
get_runtime_state,
get_data_parallel_rank,
)


@@ -46,10 +47,8 @@ def main():
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if not os.path.exists("results"):
7 changes: 3 additions & 4 deletions examples/pixartsigma_example.py
@@ -9,6 +9,7 @@
is_dp_last_group,
get_data_parallel_world_size,
get_runtime_state,
get_data_parallel_rank,
)


@@ -46,10 +47,8 @@ def main():
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
global_rank = get_world_group().rank
dp_group_world_size = get_data_parallel_world_size()
dp_group_index = global_rank // dp_group_world_size
num_dp_groups = engine_config.parallel_config.dp_degree
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if not os.path.exists("results"):
27 changes: 5 additions & 22 deletions tests/parallel_test.py
@@ -1,18 +1,10 @@
from xfuser.parallel import xDiTParallel

import time
import os
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers import StableDiffusion3Pipeline, FluxPipeline

from xfuser import xFuserArgs
from xfuser.parallel import xDiTParallel
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
is_dp_last_group,
get_data_parallel_world_size,
get_runtime_state,
)
from xfuser.core.distributed import get_world_group


def main():
@@ -29,8 +21,6 @@ def main():

paralleler = xDiTParallel(pipe, engine_config, input_config)

torch.cuda.reset_peak_memory_stats()
start_time = time.time()
paralleler(
height=input_config.height,
width=input_config.height,
@@ -39,15 +29,8 @@ def main():
output_type=input_config.output_type,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
)
end_time = time.time()
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

paralleler.save("results/", "stable_diffusion_3")

if get_world_group().rank == get_world_group().world_size - 1:
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
get_runtime_state().destory_distributed_env()
if input_config.output_type == "pil":
paralleler.save("results", "stable_diffusion_3")


if __name__ == "__main__":
4 changes: 3 additions & 1 deletion xfuser/__init__.py
@@ -8,6 +8,7 @@
xFuserCogVideoXPipeline,
)
from xfuser.config import xFuserArgs, EngineConfig
from xfuser.parallel import xDiTParallel

__all__ = [
"xFuserPixArtAlphaPipeline",
@@ -19,4 +20,5 @@
"xFuserCogVideoXPipeline",
"xFuserArgs",
"EngineConfig",
]
"xDiTParallel",
]
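Since `xDiTParallel` is now exported from the top-level `xfuser` namespace, a minimal usage sketch, condensed from `tests/parallel_test.py` above (the `add_cli_args` plumbing and the `float16` dtype are assumptions, not shown in this diff):

```python
import torch
from diffusers import StableDiffusion3Pipeline

from xfuser import xFuserArgs, xDiTParallel
from xfuser.config import FlexibleArgumentParser

def main():
    parser = FlexibleArgumentParser(description="xDiTParallel usage sketch")
    xFuserArgs.add_cli_args(parser)  # assumed to register the CLI flags on the parser
    args = parser.parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()

    pipe = StableDiffusion3Pipeline.from_pretrained(
        engine_config.model_config.model,
        torch_dtype=torch.float16,  # dtype is an assumption, not taken from this PR
    )

    # Wrap the pipeline; xDiTParallel drives the distributed execution.
    paralleler = xDiTParallel(pipe, engine_config, input_config)
    paralleler(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    if input_config.output_type == "pil":
        paralleler.save("results", "stable_diffusion_3")

if __name__ == "__main__":
    main()
```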
8 changes: 8 additions & 0 deletions xfuser/config/args.py
@@ -87,6 +87,7 @@ class xFuserArgs:
width: int = 1024
num_frames: int = 49
num_inference_steps: int = 20
max_sequence_length: int = 256
prompt: Union[str, List[str]] = ""
negative_prompt: Union[str, List[str]] = ""
no_use_resolution_binning: bool = False
@@ -218,6 +219,12 @@ def add_cli_args(parser: FlexibleArgumentParser):
default=20,
help="Number of inference steps.",
)
input_group.add_argument(
"--max_sequence_length",
type=int,
default=256,
help="Max sequencen length of prompt",
)
runtime_group.add_argument(
"--seed", type=int, default=42, help="Random seed for operations."
)
@@ -302,6 +309,7 @@ def create_config(
prompt=self.prompt,
negative_prompt=self.negative_prompt,
num_inference_steps=self.num_inference_steps,
max_sequence_length=self.max_sequence_length,
seed=self.seed,
output_type=self.output_type,
)
1 change: 1 addition & 0 deletions xfuser/config/config.py
@@ -222,6 +222,7 @@ class InputConfig:
prompt: Union[str, List[str]] = ""
negative_prompt: Union[str, List[str]] = ""
num_inference_steps: int = 20
max_sequence_length: int = 256
seed: int = 42
output_type: str = "pil"

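With `--max_sequence_length` plumbed through `xFuserArgs` and `InputConfig`, examples no longer need the hard-coded `max_sequence_length=256` that `flux_example.py` previously passed to `prepare_run`. A hedged illustration of the downstream use; which call site actually consumes `input_config.max_sequence_length` is not shown in this diff:

```python
# Illustrative only: reading the new field from InputConfig at inference time.
import torch

def run(pipe, input_config):
    return pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        max_sequence_length=input_config.max_sequence_length,  # previously hard-coded to 256
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
```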
28 changes: 15 additions & 13 deletions xfuser/core/distributed/runtime_state.py
@@ -92,10 +92,6 @@ class DiTRuntimeState(RuntimeState):
pp_patches_token_start_idx_local: Optional[List[int]]
pp_patches_token_start_end_idx_global: Optional[List[List[int]]]
pp_patches_token_num: Optional[List[int]]
# Storing the shape of a tensor that is not latent but requires pp communication
# torch.Size: size of tensor
# int: number of recv buffer it needs
pipeline_comm_extra_tensors_info: List[Tuple[str, List[int], int]]

def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
super().__init__(config)
@@ -122,7 +118,6 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
backbone_inner_dim=pipeline.transformer.config.num_attention_heads
* pipeline.transformer.config.attention_head_dim,
)
self.pipeline_comm_extra_tensors_info = []

def set_input_parameters(
self,
@@ -144,6 +139,7 @@ def set_input_parameters(
(height and self.input_config.height != height)
or (width and self.input_config.width != width)
or (batch_size and self.input_config.batch_size != batch_size)
or not self.ready
):
self._input_size_change(height, width, batch_size)

@@ -175,7 +171,7 @@ def set_video_input_parameters(
self._video_input_size_change(height, width, num_frames, batch_size)

self.ready = True

def _set_cogvideox_parameters(
self,
vae_scale_factor_spatial: int,
@@ -257,7 +253,7 @@ def _video_input_size_change(
else:
self._calc_patches_metadata()
self._reset_recv_buffer()

def _calc_patches_metadata(self):
num_sp_patches = get_sequence_parallel_world_size()
sp_patch_idx = get_sequence_parallel_rank()
@@ -363,16 +359,18 @@ def _calc_patches_metadata(self):
pp_patches_token_start_end_idx_global
)
self.pp_patches_token_num = pp_patches_token_num

def _calc_cogvideox_patches_metadata(self):

num_sp_patches = get_sequence_parallel_world_size()
sp_patch_idx = get_sequence_parallel_rank()
patch_size = self.backbone_patch_size
vae_scale_factor_spatial = self.vae_scale_factor_spatial
latents_height = self.input_config.height // vae_scale_factor_spatial
latents_width = self.input_config.width // vae_scale_factor_spatial
latents_frames = (self.input_config.num_frames - 1) // self.vae_scale_factor_temporal + 1
latents_frames = (
self.input_config.num_frames - 1
) // self.vae_scale_factor_temporal + 1

if latents_height % num_sp_patches != 0:
raise ValueError(
@@ -451,8 +449,12 @@ def _calc_cogvideox_patches_metadata(self):
]
pp_patches_token_start_end_idx_global = [
[
(latents_width // patch_size) * (start_idx // patch_size) * latents_frames,
(latents_width // patch_size) * (end_idx // patch_size) * latents_frames,
(latents_width // patch_size)
* (start_idx // patch_size)
* latents_frames,
(latents_width // patch_size)
* (end_idx // patch_size)
* latents_frames,
]
for start_idx, end_idx in pp_patches_start_end_idx_global
]
@@ -521,4 +523,4 @@ def initialize_runtime_state(pipeline: DiffusionPipeline, engine_config: EngineConfig):
"Runtime state is already initialized, reinitializing with pipeline..."
)
if hasattr(pipeline, "transformer"):
_RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config)
_RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config)
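The `or not self.ready` clause added to `set_input_parameters` (see the hunk above) makes the first call run `_input_size_change` even when the requested height, width, and batch size match the defaults already stored in `input_config`. A self-contained toy reconstruction of that control flow; the field defaults and the helper body are placeholders, not the library's code, and "ready starts out False" is an assumption:

```python
from dataclasses import dataclass

@dataclass
class _InputCfg:
    height: int = 1024
    width: int = 1024
    batch_size: int = 1

class _RuntimeStateSketch:
    def __init__(self):
        self.input_config = _InputCfg()
        self.ready = False  # assumed initial state

    def set_input_parameters(self, height=None, width=None, batch_size=None):
        if (
            (height and self.input_config.height != height)
            or (width and self.input_config.width != width)
            or (batch_size and self.input_config.batch_size != batch_size)
            or not self.ready  # new clause: force metadata computation on the first call
        ):
            self._input_size_change(height, width, batch_size)
        self.ready = True

    def _input_size_change(self, height, width, batch_size):
        # Placeholder for the patch/metadata recomputation done by DiTRuntimeState.
        pass
```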
12 changes: 0 additions & 12 deletions xfuser/model_executor/pipelines/base_pipeline.py
@@ -302,18 +302,6 @@ def _convert_vae(
def __call__(self):
pass

def _set_extra_comm_tensor_for_pipeline(
self, extra_tensors_shape_dict: List[Tuple[str, List[int], int]] = []
):
if (
get_runtime_state().pipeline_comm_extra_tensors_info
== extra_tensors_shape_dict
):
return
for name, shape, cnt in extra_tensors_shape_dict:
get_pp_group().set_extra_tensors_recv_buffer(name, shape, cnt)
get_runtime_state().pipeline_comm_extra_tensors_info = extra_tensors_shape_dict

def _init_sync_pipeline(self, latents: torch.Tensor):
get_runtime_state().set_patched_mode(patch_mode=False)
