From 5f8610016264a8dd7a20c197771043536fa9609e Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Tue, 3 Sep 2024 20:54:00 +0800 Subject: [PATCH] fix: bugs --- examples/cogvideox_example.py | 36 +++--- examples/flux_example.py | 8 +- examples/hunyuandit_example.py | 7 +- examples/latte_example.py | 3 +- examples/pixartalpha_example.py | 7 +- examples/pixartsigma_example.py | 7 +- tests/parallel_test.py | 27 +---- xfuser/__init__.py | 4 +- xfuser/config/args.py | 8 ++ xfuser/config/config.py | 1 + xfuser/core/distributed/runtime_state.py | 28 ++--- .../model_executor/pipelines/base_pipeline.py | 12 -- .../pipelines/pipeline_cogvideox.py | 105 ++++++++++++------ .../model_executor/pipelines/pipeline_flux.py | 10 +- .../pipelines/pipeline_hunyuandit.py | 1 + .../pipelines/pipeline_latte.py | 8 +- .../pipelines/pipeline_stable_diffusion_3.py | 1 + xfuser/parallel.py | 17 +-- 18 files changed, 154 insertions(+), 136 deletions(-) diff --git a/examples/cogvideox_example.py b/examples/cogvideox_example.py index 48323b2c..8355c5c9 100644 --- a/examples/cogvideox_example.py +++ b/examples/cogvideox_example.py @@ -5,10 +5,11 @@ from xfuser import xFuserCogVideoXPipeline, xFuserArgs from xfuser.config import FlexibleArgumentParser from xfuser.core.distributed import ( - get_world_group, - get_data_parallel_rank, + get_world_group, + get_data_parallel_rank, get_data_parallel_world_size, get_runtime_state, + is_dp_last_group, ) from diffusers.utils import export_to_video @@ -19,22 +20,21 @@ def main(): engine_args = xFuserArgs.from_cli_args(args) engine_config, input_config = engine_args.create_config() local_rank = get_world_group().local_rank - pipe = xFuserCogVideoXPipeline.from_pretrained( - pretrained_model_name_or_path=engine_config.model_config.model, - engine_config=engine_config, - torch_dtype=torch.bfloat16, - ) - if args.enable_sequential_cpu_offload: + pretrained_model_name_or_path=engine_config.model_config.model, + engine_config=engine_config, + torch_dtype=torch.bfloat16, + ) + if args.enable_sequential_cpu_offload: pipe.enable_model_cpu_offload(gpu_id=local_rank) pipe.vae.enable_tiling() - else: + else: pipe = pipe.to(f"cuda:{local_rank}") torch.cuda.reset_peak_memory_stats() start_time = time.time() - + output = pipe( height=input_config.height, width=input_config.width, @@ -44,20 +44,18 @@ def main(): generator=torch.Generator(device="cuda").manual_seed(input_config.seed), guidance_scale=6, ).frames[0] - + end_time = time.time() elapsed_time = end_time - start_time peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") - - if get_data_parallel_rank() == get_data_parallel_world_size() - 1: + + if is_dp_last_group(): export_to_video(output, "results/output.mp4", fps=8) - + if get_world_group().rank == get_world_group().world_size - 1: - print( - f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB" - ) + print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") get_runtime_state().destory_distributed_env() -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/examples/flux_example.py b/examples/flux_example.py index b55aa1e3..5c7c0435 100644 --- a/examples/flux_example.py +++ b/examples/flux_example.py @@ -32,7 +32,7 @@ def main(): else: pipe = pipe.to(f"cuda:{local_rank}") - pipe.prepare_run(input_config, max_sequence_length=256) + pipe.prepare_run(input_config) torch.cuda.reset_peak_memory_stats() start_time = time.time() @@ -57,10 +57,8 @@ def main(): f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" ) if input_config.output_type == "pil": - global_rank = get_world_group().rank - dp_group_world_size = get_data_parallel_world_size() - dp_group_index = global_rank // dp_group_world_size - num_dp_groups = engine_config.parallel_config.dp_degree + dp_group_index = get_data_parallel_rank() + num_dp_groups = get_data_parallel_world_size() dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups if is_dp_last_group(): for i, image in enumerate(output.images): diff --git a/examples/hunyuandit_example.py b/examples/hunyuandit_example.py index a9ccf3d7..c6f56d84 100644 --- a/examples/hunyuandit_example.py +++ b/examples/hunyuandit_example.py @@ -9,6 +9,7 @@ is_dp_last_group, get_data_parallel_world_size, get_runtime_state, + get_data_parallel_rank, ) @@ -46,10 +47,8 @@ def main(): f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" ) if input_config.output_type == "pil": - global_rank = get_world_group().rank - dp_group_world_size = get_data_parallel_world_size() - dp_group_index = global_rank // dp_group_world_size - num_dp_groups = engine_config.parallel_config.dp_degree + dp_group_index = get_data_parallel_rank() + num_dp_groups = get_data_parallel_world_size() dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups if is_dp_last_group(): if not os.path.exists("results"): diff --git a/examples/latte_example.py b/examples/latte_example.py index a5a0198a..212879e9 100644 --- a/examples/latte_example.py +++ b/examples/latte_example.py @@ -9,6 +9,7 @@ get_data_parallel_rank, get_data_parallel_world_size, get_runtime_state, + is_dp_last_group, ) import imageio @@ -53,7 +54,7 @@ def main(): f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" ) - if get_data_parallel_rank() == get_data_parallel_world_size() - 1: + if is_dp_last_group(): videos = output.frames.cpu() global_rank = get_world_group().rank dp_group_world_size = get_data_parallel_world_size() diff --git a/examples/pixartalpha_example.py b/examples/pixartalpha_example.py index 28086129..041c12be 100644 --- a/examples/pixartalpha_example.py +++ b/examples/pixartalpha_example.py @@ -9,6 +9,7 @@ is_dp_last_group, get_data_parallel_world_size, get_runtime_state, + get_data_parallel_rank, ) @@ -46,10 +47,8 @@ def main(): f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" ) if input_config.output_type == "pil": - global_rank = get_world_group().rank - dp_group_world_size = get_data_parallel_world_size() - dp_group_index = global_rank // dp_group_world_size - num_dp_groups = engine_config.parallel_config.dp_degree + dp_group_index = get_data_parallel_rank() + num_dp_groups = get_data_parallel_world_size() dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups if is_dp_last_group(): if not os.path.exists("results"): diff --git a/examples/pixartsigma_example.py b/examples/pixartsigma_example.py index 2d1af80a..234b39ab 100644 --- a/examples/pixartsigma_example.py +++ b/examples/pixartsigma_example.py @@ -9,6 +9,7 @@ is_dp_last_group, get_data_parallel_world_size, get_runtime_state, + get_data_parallel_rank, ) @@ -46,10 +47,8 @@ def main(): f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" ) if input_config.output_type == "pil": - global_rank = get_world_group().rank - dp_group_world_size = get_data_parallel_world_size() - dp_group_index = global_rank // dp_group_world_size - num_dp_groups = engine_config.parallel_config.dp_degree + dp_group_index = get_data_parallel_rank() + num_dp_groups = get_data_parallel_world_size() dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups if is_dp_last_group(): if not os.path.exists("results"): diff --git a/tests/parallel_test.py b/tests/parallel_test.py index 7992c74f..78be80f0 100644 --- a/tests/parallel_test.py +++ b/tests/parallel_test.py @@ -1,18 +1,10 @@ -from xfuser.parallel import xDiTParallel - -import time -import os import torch -from diffusers import StableDiffusion3Pipeline +from diffusers import StableDiffusion3Pipeline, FluxPipeline from xfuser import xFuserArgs +from xfuser.parallel import xDiTParallel from xfuser.config import FlexibleArgumentParser -from xfuser.core.distributed import ( - get_world_group, - is_dp_last_group, - get_data_parallel_world_size, - get_runtime_state, -) +from xfuser.core.distributed import get_world_group def main(): @@ -29,8 +21,6 @@ def main(): paralleler = xDiTParallel(pipe, engine_config, input_config) - torch.cuda.reset_peak_memory_stats() - start_time = time.time() paralleler( height=input_config.height, width=input_config.height, @@ -39,15 +29,8 @@ def main(): output_type=input_config.output_type, generator=torch.Generator(device="cuda").manual_seed(input_config.seed), ) - end_time = time.time() - elapsed_time = end_time - start_time - peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") - - paralleler.save("results/", "stable_diffusion_3") - - if get_world_group().rank == get_world_group().world_size - 1: - print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") - get_runtime_state().destory_distributed_env() + if input_config.output_type == "pil": + paralleler.save("results", "stable_diffusion_3") if __name__ == "__main__": diff --git a/xfuser/__init__.py b/xfuser/__init__.py index eb0e9f95..9df092de 100644 --- a/xfuser/__init__.py +++ b/xfuser/__init__.py @@ -8,6 +8,7 @@ xFuserCogVideoXPipeline, ) from xfuser.config import xFuserArgs, EngineConfig +from xfuser.parallel import xDiTParallel __all__ = [ "xFuserPixArtAlphaPipeline", @@ -19,4 +20,5 @@ "xFuserCogVideoXPipeline", "xFuserArgs", "EngineConfig", -] \ No newline at end of file + "xDiTParallel", +] diff --git a/xfuser/config/args.py b/xfuser/config/args.py index bd63c183..7ca8b551 100644 --- a/xfuser/config/args.py +++ b/xfuser/config/args.py @@ -87,6 +87,7 @@ class xFuserArgs: width: int = 1024 num_frames: int = 49 num_inference_steps: int = 20 + max_sequence_length: int = 256 prompt: Union[str, List[str]] = "" negative_prompt: Union[str, List[str]] = "" no_use_resolution_binning: bool = False @@ -218,6 +219,12 @@ def add_cli_args(parser: FlexibleArgumentParser): default=20, help="Number of inference steps.", ) + input_group.add_argument( + "--max_sequence_length", + type=int, + default=256, + help="Max sequencen length of prompt", + ) runtime_group.add_argument( "--seed", type=int, default=42, help="Random seed for operations." ) @@ -302,6 +309,7 @@ def create_config( prompt=self.prompt, negative_prompt=self.negative_prompt, num_inference_steps=self.num_inference_steps, + max_sequence_length=self.max_sequence_length, seed=self.seed, output_type=self.output_type, ) diff --git a/xfuser/config/config.py b/xfuser/config/config.py index e2d38edf..85064bce 100644 --- a/xfuser/config/config.py +++ b/xfuser/config/config.py @@ -222,6 +222,7 @@ class InputConfig: prompt: Union[str, List[str]] = "" negative_prompt: Union[str, List[str]] = "" num_inference_steps: int = 20 + max_sequence_length: int = 256 seed: int = 42 output_type: str = "pil" diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py index 4a92b61c..a2c9c631 100644 --- a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -92,10 +92,6 @@ class DiTRuntimeState(RuntimeState): pp_patches_token_start_idx_local: Optional[List[int]] pp_patches_token_start_end_idx_global: Optional[List[List[int]]] pp_patches_token_num: Optional[List[int]] - # Storing the shape of a tensor that is not latent but requires pp communication - # torch.Size: size of tensor - # int: number of recv buffer it needs - pipeline_comm_extra_tensors_info: List[Tuple[str, List[int], int]] def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): super().__init__(config) @@ -122,7 +118,6 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): backbone_inner_dim=pipeline.transformer.config.num_attention_heads * pipeline.transformer.config.attention_head_dim, ) - self.pipeline_comm_extra_tensors_info = [] def set_input_parameters( self, @@ -144,6 +139,7 @@ def set_input_parameters( (height and self.input_config.height != height) or (width and self.input_config.width != width) or (batch_size and self.input_config.batch_size != batch_size) + or not self.ready ): self._input_size_change(height, width, batch_size) @@ -175,7 +171,7 @@ def set_video_input_parameters( self._video_input_size_change(height, width, num_frames, batch_size) self.ready = True - + def _set_cogvideox_parameters( self, vae_scale_factor_spatial: int, @@ -257,7 +253,7 @@ def _video_input_size_change( else: self._calc_patches_metadata() self._reset_recv_buffer() - + def _calc_patches_metadata(self): num_sp_patches = get_sequence_parallel_world_size() sp_patch_idx = get_sequence_parallel_rank() @@ -363,16 +359,18 @@ def _calc_patches_metadata(self): pp_patches_token_start_end_idx_global ) self.pp_patches_token_num = pp_patches_token_num - + def _calc_cogvideox_patches_metadata(self): - + num_sp_patches = get_sequence_parallel_world_size() sp_patch_idx = get_sequence_parallel_rank() patch_size = self.backbone_patch_size vae_scale_factor_spatial = self.vae_scale_factor_spatial latents_height = self.input_config.height // vae_scale_factor_spatial latents_width = self.input_config.width // vae_scale_factor_spatial - latents_frames = (self.input_config.num_frames - 1) // self.vae_scale_factor_temporal + 1 + latents_frames = ( + self.input_config.num_frames - 1 + ) // self.vae_scale_factor_temporal + 1 if latents_height % num_sp_patches != 0: raise ValueError( @@ -451,8 +449,12 @@ def _calc_cogvideox_patches_metadata(self): ] pp_patches_token_start_end_idx_global = [ [ - (latents_width // patch_size) * (start_idx // patch_size) * latents_frames, - (latents_width // patch_size) * (end_idx // patch_size) * latents_frames, + (latents_width // patch_size) + * (start_idx // patch_size) + * latents_frames, + (latents_width // patch_size) + * (end_idx // patch_size) + * latents_frames, ] for start_idx, end_idx in pp_patches_start_end_idx_global ] @@ -521,4 +523,4 @@ def initialize_runtime_state(pipeline: DiffusionPipeline, engine_config: EngineC "Runtime state is already initialized, reinitializing with pipeline..." ) if hasattr(pipeline, "transformer"): - _RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config) \ No newline at end of file + _RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config) diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py index a0990efa..0bcb5219 100644 --- a/xfuser/model_executor/pipelines/base_pipeline.py +++ b/xfuser/model_executor/pipelines/base_pipeline.py @@ -302,18 +302,6 @@ def _convert_vae( def __call__(self): pass - def _set_extra_comm_tensor_for_pipeline( - self, extra_tensors_shape_dict: List[Tuple[str, List[int], int]] = [] - ): - if ( - get_runtime_state().pipeline_comm_extra_tensors_info - == extra_tensors_shape_dict - ): - return - for name, shape, cnt in extra_tensors_shape_dict: - get_pp_group().set_extra_tensors_recv_buffer(name, shape, cnt) - get_runtime_state().pipeline_comm_extra_tensors_info = extra_tensors_shape_dict - def _init_sync_pipeline(self, latents: torch.Tensor): get_runtime_state().set_patched_mode(patch_mode=False) diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py index b4491aa1..cd384311 100644 --- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py +++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py @@ -4,7 +4,10 @@ import torch import torch.distributed from diffusers import CogVideoXPipeline -from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipelineOutput, retrieve_timesteps +from diffusers.pipelines.cogvideo.pipeline_cogvideox import ( + CogVideoXPipelineOutput, + retrieve_timesteps, +) from diffusers.schedulers import CogVideoXDPMScheduler, CogVideoXDDIMScheduler from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.utils import deprecate @@ -24,7 +27,7 @@ get_world_group, get_cfg_group, get_sp_group, - get_runtime_state, + get_runtime_state, initialize_runtime_state, get_data_parallel_rank, ) @@ -32,6 +35,7 @@ from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper from .register import xFuserPipelineWrapperRegister + @xFuserPipelineWrapperRegister.register(CogVideoXPipeline) class xFuserCogVideoXPipeline(xFuserPipelineBaseWrapper): @@ -70,10 +74,15 @@ def __call__( output_type: str = "pil", return_dict: bool = True, callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + Union[ + Callable[[int, int, Dict], None], + PipelineCallback, + MultiPipelineCallbacks, + ] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 226, + **kwargs, ) -> Union[CogVideoXPipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -159,11 +168,15 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + height = ( + height + or self.transformer.config.sample_size * self.vae_scale_factor_spatial + ) + width = ( + width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + ) num_videos_per_prompt = 1 - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -186,14 +199,12 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - - + get_runtime_state().set_video_input_parameters( height=height, width=width, @@ -201,7 +212,6 @@ def __call__( batch_size=batch_size, num_inference_steps=num_inference_steps, ) - # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -216,10 +226,11 @@ def __call__( ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) self._num_timesteps = len(timesteps) # 5. Prepare latents. @@ -235,14 +246,15 @@ def __call__( generator, latents, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + with self.progress_bar(total=num_inference_steps) as progress_bar: latents = self._init_video_sync_pipeline(latents) # for DPM-solver++ @@ -251,12 +263,16 @@ def __call__( if self.interrupt: continue - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - + # predict noise model_output noise_pred = self.transformer( hidden_states=latent_model_input, @@ -265,20 +281,34 @@ def __call__( return_dict=False, )[0] noise_pred = noise_pred.float() - + # perform guidance if use_dynamic_cfg: self._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ( + 1 + - math.cos( + math.pi + * ( + (num_inference_steps - t.item()) + / num_inference_steps + ) + ** 5.0 + ) + ) + / 2 ) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) - # compute the previous noisy sample x_t -> x_t-1 if not isinstance(self.scheduler.module, CogVideoXDPMScheduler): - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] else: latents, old_pred_original_sample = self.scheduler.step( noise_pred, @@ -290,7 +320,6 @@ def __call__( return_dict=False, ) latents = latents.to(prompt_embeds.dtype) - # call the callback, if provided if callback_on_step_end is not None: @@ -301,11 +330,15 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() - + if get_sequence_parallel_world_size() > 1: sp_degree = get_sequence_parallel_world_size() sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True) @@ -315,21 +348,23 @@ def __call__( sp_latents_list[sp_patch_idx][ :, :, - get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]: - get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1], - : + get_runtime_state() + .pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state() + .pp_patches_start_idx_local[pp_patch_idx + 1], + :, ] for sp_patch_idx in range(sp_degree) ] latents = torch.cat(latents_list, dim=-2) - - if get_data_parallel_rank() == get_data_parallel_world_size() - 1: + + if is_dp_last_group(): if not (output_type == "latents" or output_type == "latent"): video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = self.video_processor.postprocess_video( + video=video, output_type=output_type + ) else: video = latents - # Offload all models self.maybe_free_model_hooks() @@ -345,4 +380,4 @@ def interrupt(self): @property def guidance_scale(self): - return self._guidance_scale \ No newline at end of file + return self._guidance_scale diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py index 26526aa4..43751a94 100644 --- a/xfuser/model_executor/pipelines/pipeline_flux.py +++ b/xfuser/model_executor/pipelines/pipeline_flux.py @@ -29,10 +29,9 @@ get_pp_group, get_sequence_parallel_world_size, get_sp_group, - get_data_parallel_rank, - get_data_parallel_world_size, is_pipeline_first_stage, is_pipeline_last_stage, + is_dp_last_group, ) from .base_pipeline import xFuserPipelineBaseWrapper from .register import xFuserPipelineWrapperRegister @@ -61,7 +60,6 @@ def from_pretrained( def prepare_run( self, input_config: InputConfig, - max_sequence_length, steps: int = 3, sync_steps: int = 1, ): @@ -74,7 +72,7 @@ def prepare_run( prompt=prompt, num_inference_steps=steps, output_type="latent", - max_sequence_length=max_sequence_length, + max_sequence_length=input_config.max_sequence_length, generator=torch.Generator(device="cuda").manual_seed(42), ) get_runtime_state().runtime_config.warmup_steps = warmup_steps @@ -121,6 +119,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -218,7 +217,6 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - device = self._execution_device #! ---------------------------------------- ADDED BELOW ---------------------------------------- @@ -311,7 +309,7 @@ def __call__( sync_only=True, ) - if get_data_parallel_rank() == get_data_parallel_world_size() - 1: + if is_dp_last_group(): if output_type == "latent": image = latents diff --git a/xfuser/model_executor/pipelines/pipeline_hunyuandit.py b/xfuser/model_executor/pipelines/pipeline_hunyuandit.py index 03277fc2..7d8c784d 100644 --- a/xfuser/model_executor/pipelines/pipeline_hunyuandit.py +++ b/xfuser/model_executor/pipelines/pipeline_hunyuandit.py @@ -111,6 +111,7 @@ def __call__( target_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), use_resolution_binning: bool = True, + **kwargs, ): r""" The call function to the pipeline for generation with HunyuanDiT. diff --git a/xfuser/model_executor/pipelines/pipeline_latte.py b/xfuser/model_executor/pipelines/pipeline_latte.py index 191ff52f..c773f870 100644 --- a/xfuser/model_executor/pipelines/pipeline_latte.py +++ b/xfuser/model_executor/pipelines/pipeline_latte.py @@ -35,6 +35,7 @@ get_sp_group, get_runtime_state, initialize_runtime_state, + is_dp_last_group, ) from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper @@ -90,6 +91,7 @@ def __call__( enable_temporal_attentions: bool = True, decode_chunk_size: Optional[int] = None, num_pipeline_warmup_steps: Optional[int] = 3, + **kwargs, ) -> Union[LattePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -258,9 +260,9 @@ def __call__( ) with self.progress_bar(total=num_inference_steps) as progress_bar: - + latents = self._init_video_sync_pipeline(latents) - + for i, t in enumerate(timesteps): if self.interrupt: continue @@ -356,7 +358,7 @@ def __call__( ] latents = torch.cat(latents_list, dim=-2) - if get_data_parallel_rank() == get_data_parallel_world_size() - 1: + if is_dp_last_group(): if not (output_type == "latents" or output_type == "latent"): video = self.decode_latents(latents, num_frames, decode_chunk_size=14) video = self.video_processor.postprocess_video( diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py index 396797f4..0e80cbc0 100644 --- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py +++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py @@ -137,6 +137,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" Function invoked when calling the pipeline for generation. diff --git a/xfuser/parallel.py b/xfuser/parallel.py index 4d2bbf89..51eed42e 100644 --- a/xfuser/parallel.py +++ b/xfuser/parallel.py @@ -1,6 +1,5 @@ import os -from typing import Any, Type, Union -from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from pathlib import Path from xfuser.config.config import InputConfig from xfuser.core.distributed import ( @@ -13,6 +12,7 @@ get_data_parallel_world_size, is_dp_last_group, ) +from xfuser.core.distributed.runtime_state import get_runtime_state from xfuser.logger import init_logger from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper from xfuser.model_executor.pipelines.register import xFuserPipelineWrapperRegister @@ -42,10 +42,13 @@ def save(self, directory: str, prefix: str): f"ulysses{self.config.parallel_config.ulysses_degree}_ring{self.config.parallel_config.ring_degree}_" f"pp{self.config.parallel_config.pp_degree}_patch{self.config.parallel_config.pp_config.num_pipeline_patch}" ) - prefix = f"{directory}/{prefix}_result_{parallel_info}_dprank{dp_rank}" if is_dp_last_group(): - if not os.path.exists("results"): - os.mkdir("results") + path = Path(f"{directory}") + path.mkdir(mode=755, parents=True, exist_ok=True) + path = path / f"{prefix}_result_{parallel_info}_dprank{dp_rank}" for i, image in enumerate(self.result.images): - image.save(f"{prefix}_image{i}.png") - print(f"{prefix}_image{i}.png") + image.save(f"{str(path)}_image{i}.png") + print(f"{str(path)}_image{i}.png") + + def __del__(self): + get_runtime_state().destory_distributed_env()