From c63d09cde1b4d5710b4050e806e6e628a202480f Mon Sep 17 00:00:00 2001
From: TianYu GUO
Date: Fri, 11 Oct 2024 10:36:56 +0800
Subject: [PATCH] Fix compatibility issue between parallel vae and naive forward; Enable warmup for vae (#300)

---
 .../model_executor/pipelines/base_pipeline.py | 20 +++++++++++---------
 .../model_executor/pipelines/pipeline_flux.py |  1 -
 .../pipelines/pipeline_stable_diffusion_3.py  |  1 -
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py
index 34c4310c..c71bafd5 100644
--- a/xfuser/model_executor/pipelines/base_pipeline.py
+++ b/xfuser/model_executor/pipelines/base_pipeline.py
@@ -89,7 +89,7 @@ def __init__(
         if scheduler is not None:
             pipeline.scheduler = self._convert_scheduler(scheduler)

-        if vae is not None and engine_config.runtime_config.use_parallel_vae:
+        if vae is not None and engine_config.runtime_config.use_parallel_vae and not self.use_naive_forward():
             pipeline.vae = self._convert_vae(vae)

         super().__init__(module=pipeline)
@@ -167,17 +167,20 @@ def data_parallel_fn(self, *args, **kwargs):

         return data_parallel_fn

-    @staticmethod
-    def check_to_use_naive_forward(func):
-        @wraps(func)
-        def check_naive_forward_fn(self, *args, **kwargs):
-            if (
+    def use_naive_forward(self):
+        return (
                 get_pipeline_parallel_world_size() == 1
                 and get_classifier_free_guidance_world_size() == 1
                 and get_sequence_parallel_world_size() == 1
                 and get_tensor_model_parallel_world_size() == 1
                 and get_fast_attn_enable() == False
-            ):
+        )
+
+    @staticmethod
+    def check_to_use_naive_forward(func):
+        @wraps(func)
+        def check_naive_forward_fn(self, *args, **kwargs):
+            if self.use_naive_forward():
                 return self.module(*args, **kwargs)
             else:
                 return func(self, *args, **kwargs)
@@ -237,7 +240,6 @@ def prepare_run(
                 prompt=prompt,
                 use_resolution_binning=input_config.use_resolution_binning,
                 num_inference_steps=steps,
-                output_type="latent",
                 generator=torch.Generator(device="cuda").manual_seed(42),
             )
         get_runtime_state().runtime_config.warmup_steps = warmup_steps
@@ -441,7 +443,7 @@ def is_dp_last_group(self):
         """Return True if in the last data parallel group, False otherwise.
         Also include parallel vae situation.
         """
-        if get_runtime_state().runtime_config.use_parallel_vae:
+        if get_runtime_state().runtime_config.use_parallel_vae and not self.use_naive_forward():
             return get_world_group().rank == 0
         else:
             return is_dp_last_group()
diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py
index ce2c1c91..f430c836 100644
--- a/xfuser/model_executor/pipelines/pipeline_flux.py
+++ b/xfuser/model_executor/pipelines/pipeline_flux.py
@@ -72,7 +72,6 @@ def prepare_run(
             width=input_config.width,
             prompt=prompt,
             num_inference_steps=steps,
-            output_type="latent",
             max_sequence_length=input_config.max_sequence_length,
             generator=torch.Generator(device="cuda").manual_seed(42),
         )
diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
index d4408a0f..3b6d471b 100644
--- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
+++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
@@ -77,7 +77,6 @@ def prepare_run(
             width=input_config.width,
             prompt=prompt,
             num_inference_steps=steps,
-            output_type="latent",
             generator=torch.Generator(device="cuda").manual_seed(42),
         )
         get_runtime_state().runtime_config.warmup_steps = warmup_steps
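
Below is a minimal, self-contained sketch (not part of the patch) of the pattern the base_pipeline.py hunks introduce: the "no parallelism is active" condition moves out of the check_to_use_naive_forward decorator into a reusable use_naive_forward() method, so the same check can also gate parallel-VAE conversion and the is_dp_last_group logic. The module-level getters here are placeholder stubs standing in for xfuser's real distributed-state helpers, and PipelineWrapperSketch / FluxLikePipeline are hypothetical names used only for illustration.

from functools import wraps

# Placeholder stubs for xfuser's distributed-state getters (illustration only).
def get_pipeline_parallel_world_size(): return 1
def get_classifier_free_guidance_world_size(): return 1
def get_sequence_parallel_world_size(): return 1
def get_tensor_model_parallel_world_size(): return 1
def get_fast_attn_enable(): return False


class PipelineWrapperSketch:
    def __init__(self, module):
        self.module = module  # the wrapped, non-parallel pipeline callable

    def use_naive_forward(self):
        # Single source of truth for "no parallelism active"; reused by the
        # decorator below and (in the patch) by the parallel-VAE checks.
        return (
            get_pipeline_parallel_world_size() == 1
            and get_classifier_free_guidance_world_size() == 1
            and get_sequence_parallel_world_size() == 1
            and get_tensor_model_parallel_world_size() == 1
            and get_fast_attn_enable() is False
        )

    @staticmethod
    def check_to_use_naive_forward(func):
        @wraps(func)
        def check_naive_forward_fn(self, *args, **kwargs):
            if self.use_naive_forward():
                return self.module(*args, **kwargs)  # plain single-GPU path
            return func(self, *args, **kwargs)       # parallelized path
        return check_naive_forward_fn


class FluxLikePipeline(PipelineWrapperSketch):
    @PipelineWrapperSketch.check_to_use_naive_forward
    def __call__(self, *args, **kwargs):
        return f"parallel forward: {kwargs}"


if __name__ == "__main__":
    pipe = FluxLikePipeline(module=lambda **kw: f"naive forward: {kw}")
    # All world sizes are 1 in the stubs, so this takes the naive path.
    print(pipe(prompt="a cat"))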