Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable warm up for VAE #300

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions xfuser/model_executor/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
if scheduler is not None:
pipeline.scheduler = self._convert_scheduler(scheduler)

if vae is not None and engine_config.runtime_config.use_parallel_vae:
if vae is not None and engine_config.runtime_config.use_parallel_vae and not self.use_naive_forward():
pipeline.vae = self._convert_vae(vae)

super().__init__(module=pipeline)
Expand Down Expand Up @@ -167,17 +167,20 @@ def data_parallel_fn(self, *args, **kwargs):

return data_parallel_fn

@staticmethod
def check_to_use_naive_forward(func):
@wraps(func)
def check_naive_forward_fn(self, *args, **kwargs):
if (
def use_naive_forward(self):
return (
get_pipeline_parallel_world_size() == 1
and get_classifier_free_guidance_world_size() == 1
and get_sequence_parallel_world_size() == 1
and get_tensor_model_parallel_world_size() == 1
and get_fast_attn_enable() == False
):
)

@staticmethod
def check_to_use_naive_forward(func):
@wraps(func)
def check_naive_forward_fn(self, *args, **kwargs):
if self.use_naive_forward():
return self.module(*args, **kwargs)
else:
return func(self, *args, **kwargs)
Expand Down Expand Up @@ -237,7 +240,6 @@ def prepare_run(
prompt=prompt,
use_resolution_binning=input_config.use_resolution_binning,
num_inference_steps=steps,
output_type="latent",
generator=torch.Generator(device="cuda").manual_seed(42),
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
Expand Down Expand Up @@ -441,7 +443,7 @@ def is_dp_last_group(self):
"""Return True if in the last data parallel group, False otherwise.
Also include parallel vae situation.
"""
if get_runtime_state().runtime_config.use_parallel_vae:
if get_runtime_state().runtime_config.use_parallel_vae and not self.use_naive_forward():
return get_world_group().rank == 0
else:
return is_dp_last_group()
Expand Down
1 change: 0 additions & 1 deletion xfuser/model_executor/pipelines/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def prepare_run(
width=input_config.width,
prompt=prompt,
num_inference_steps=steps,
output_type="latent",
max_sequence_length=input_config.max_sequence_length,
generator=torch.Generator(device="cuda").manual_seed(42),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def prepare_run(
width=input_config.width,
prompt=prompt,
num_inference_steps=steps,
output_type="latent",
generator=torch.Generator(device="cuda").manual_seed(42),
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps
Expand Down
Loading