
Commit

Fix compatibility issue between parallel vae and naive forward; Enable warmup for vae (xdit-project#300)
gty111 authored and feifeibear committed Oct 25, 2024
1 parent b9f10ac commit c63d09c
Showing 3 changed files with 11 additions and 11 deletions.
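The core of the fix: the naive-forward condition moves out of the check_to_use_naive_forward decorator into a standalone use_naive_forward() method, so the VAE-conversion guard in __init__ and is_dp_last_group() can reuse it. When every parallel degree is 1 (single-GPU naive forward), the parallel VAE wrapper is never exercised, so converting the VAE or gating output to rank 0 would break plain runs. A minimal self-contained sketch of the pattern follows; the class name and stub getters are illustrative stand-ins, not xDiT's actual imports:

    from functools import wraps

    # Stubs standing in for xDiT's parallel-state getters (assumed here;
    # the real ones come from xfuser's distributed runtime).
    def get_pipeline_parallel_world_size(): return 1
    def get_classifier_free_guidance_world_size(): return 1
    def get_sequence_parallel_world_size(): return 1
    def get_tensor_model_parallel_world_size(): return 1
    def get_fast_attn_enable(): return False

    class PipelineWrapper:
        def __init__(self, module):
            self.module = module  # the wrapped diffusers pipeline

        def use_naive_forward(self):
            # True when no parallelism is active and fast attention is off,
            # i.e. the wrapped pipeline can be called directly.
            return (
                get_pipeline_parallel_world_size() == 1
                and get_classifier_free_guidance_world_size() == 1
                and get_sequence_parallel_world_size() == 1
                and get_tensor_model_parallel_world_size() == 1
                and get_fast_attn_enable() == False
            )

        @staticmethod
        def check_to_use_naive_forward(func):
            @wraps(func)
            def check_naive_forward_fn(self, *args, **kwargs):
                if self.use_naive_forward():
                    # Bypass the parallel path entirely.
                    return self.module(*args, **kwargs)
                return func(self, *args, **kwargs)
            return check_naive_forward_fn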
20 changes: 11 additions & 9 deletions xfuser/model_executor/pipelines/base_pipeline.py
@@ -89,7 +89,7 @@ def __init__(
         if scheduler is not None:
             pipeline.scheduler = self._convert_scheduler(scheduler)
 
-        if vae is not None and engine_config.runtime_config.use_parallel_vae:
+        if vae is not None and engine_config.runtime_config.use_parallel_vae and not self.use_naive_forward():
             pipeline.vae = self._convert_vae(vae)
 
         super().__init__(module=pipeline)
@@ -167,17 +167,20 @@ def data_parallel_fn(self, *args, **kwargs):
 
         return data_parallel_fn
 
-    @staticmethod
-    def check_to_use_naive_forward(func):
-        @wraps(func)
-        def check_naive_forward_fn(self, *args, **kwargs):
-            if (
+    def use_naive_forward(self):
+        return (
                 get_pipeline_parallel_world_size() == 1
                 and get_classifier_free_guidance_world_size() == 1
                 and get_sequence_parallel_world_size() == 1
                 and get_tensor_model_parallel_world_size() == 1
                 and get_fast_attn_enable() == False
-            ):
+        )
+
+    @staticmethod
+    def check_to_use_naive_forward(func):
+        @wraps(func)
+        def check_naive_forward_fn(self, *args, **kwargs):
+            if self.use_naive_forward():
                 return self.module(*args, **kwargs)
             else:
                 return func(self, *args, **kwargs)
@@ -237,7 +240,6 @@ def prepare_run(
                 prompt=prompt,
                 use_resolution_binning=input_config.use_resolution_binning,
                 num_inference_steps=steps,
-                output_type="latent",
                 generator=torch.Generator(device="cuda").manual_seed(42),
             )
         get_runtime_state().runtime_config.warmup_steps = warmup_steps
@@ -441,7 +443,7 @@ def is_dp_last_group(self):
         """Return True if in the last data parallel group, False otherwise.
         Also include parallel vae situation.
         """
-        if get_runtime_state().runtime_config.use_parallel_vae:
+        if get_runtime_state().runtime_config.use_parallel_vae and not self.use_naive_forward():
             return get_world_group().rank == 0
         else:
             return is_dp_last_group()
1 change: 0 additions & 1 deletion xfuser/model_executor/pipelines/pipeline_flux.py
@@ -72,7 +72,6 @@ def prepare_run(
                 width=input_config.width,
                 prompt=prompt,
                 num_inference_steps=steps,
-                output_type="latent",
                 max_sequence_length=input_config.max_sequence_length,
                 generator=torch.Generator(device="cuda").manual_seed(42),
             )
1 change: 0 additions & 1 deletion
@@ -77,7 +77,6 @@ def prepare_run(
                 width=input_config.width,
                 prompt=prompt,
                 num_inference_steps=steps,
-                output_type="latent",
                 generator=torch.Generator(device="cuda").manual_seed(42),
             )
         get_runtime_state().runtime_config.warmup_steps = warmup_steps
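All three prepare_run changes delete the same line, output_type="latent". Without it, the warmup call falls through to the pipeline's default output type, so latents are decoded through the VAE during the warmup pass and the VAE gets warmed up too; that is the "Enable warmup for vae" half of the commit title. An illustrative shape of the warmup call after the change, mirroring the diff (argument names taken from the diff, height assumed from context):

    pipeline(
        height=input_config.height,
        width=input_config.width,
        prompt=prompt,
        num_inference_steps=steps,
        # output_type is no longer forced to "latent", so the VAE
        # decoder runs during warmup as well.
        generator=torch.Generator(device="cuda").manual_seed(42),
    )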
