refactor: let irecv emit after isend (#176)
taozhiwei authored Aug 8, 2024
1 parent 93721f1 commit f6fe00d
Showing 3 changed files with 210 additions and 136 deletions.
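
Aside from formatting-only changes, the functional change in both pipeline files is in _async_pipeline: the non-blocking receive for the next patch (get_pp_group().recv_next()) was previously posted right after the incoming patch was consumed, before the backbone forward and the pipeline_isend() of the freshly computed patch; with this commit it is posted only after the isend has been issued. Below is a minimal, self-contained sketch of that ordering change (hypothetical helper names, not the real xfuser API; the first-stage and last-patch guards of the real loop are omitted, only the relative order of the calls is modelled):

# Toy model (hypothetical stub names) of the ordering change this commit makes
# in _async_pipeline: with irecv_after_isend=True the irecv is posted only
# after the isend, mirroring the new behaviour.
def run_patch_loop(num_steps: int, num_patches: int, irecv_after_isend: bool) -> list[str]:
    """Return the order in which compute/communication calls are posted by one stage."""
    trace: list[str] = []
    for step in range(num_steps):
        for patch in range(num_patches):
            tag = f"step {step}, patch {patch}"
            trace.append(f"consume received patch    ({tag})")
            if not irecv_after_isend:
                # behaviour before this commit: irecv posted before compute and isend
                trace.append(f"post irecv for next patch ({tag})")
            trace.append(f"backbone forward          ({tag})")
            trace.append(f"post isend of this patch  ({tag})")
            if irecv_after_isend:
                # behaviour after this commit (#176): irecv posted only after isend
                trace.append(f"post irecv for next patch ({tag})")
    return trace

if __name__ == "__main__":
    for call in run_patch_loop(num_steps=1, num_patches=2, irecv_after_isend=True):
        print(call)
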
94 changes: 55 additions & 39 deletions xfuser/model_executor/pipelines/pipeline_pixart_alpha.py
@@ -24,7 +24,7 @@
get_sequence_parallel_world_size,
get_sp_group,
is_pipeline_first_stage,
is_pipeline_last_stage
is_pipeline_last_stage,
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@@ -172,7 +172,9 @@ def __call__(
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
height, width = self.image_processor.classify_height_width_bin(
height, width, ratios=aspect_ratio_bin
)

self.check_inputs(
prompt,
@@ -200,15 +202,15 @@ def __call__(
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

#! ---------------------------------------- ADDED BELOW ----------------------------------------
#* set runtime state input parameters
#! ---------------------------------------- ADDED BELOW ----------------------------------------
# * set runtime state input parameters
get_runtime_state().set_input_parameters(
height=height,
width=width,
batch_size=batch_size,
num_inference_steps=num_inference_steps,
)
#! ---------------------------------------- ADDED ABOVE ----------------------------------------
#! ---------------------------------------- ADDED ABOVE ----------------------------------------

# 3. Encode input prompt
(
@@ -230,7 +232,7 @@ def __call__(
max_sequence_length=max_sequence_length,
)

#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
# * dealing with cfg degree
if do_classifier_free_guidance:
(
@@ -240,14 +242,14 @@ def __call__(
negative_prompt_embeds,
prompt_embeds,
negative_prompt_attention_mask,
prompt_attention_mask
prompt_attention_mask,
)

#! ORIGIN
# if do_classifier_free_guidance:
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
@@ -282,7 +284,7 @@ def __call__(
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
if (
do_classifier_free_guidance
and get_classifier_free_guidance_world_size() == 1
@@ -294,13 +296,15 @@ def __call__(
# if do_classifier_free_guidance:
# resolution = torch.cat([resolution, resolution], dim=0)
# aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps

with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -351,16 +355,20 @@ def __call__(
callback_steps=callback_steps,
sync_only=True,
)
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

# 8. Decode latents (only rank 0)
#! ---------------------------------------- ADD BELOW ----------------------------------------
#! ---------------------------------------- ADD BELOW ----------------------------------------
if is_dp_last_rank():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
image = self.image_processor.resize_and_crop_tensor(
image, orig_width, orig_height
)
else:
image = latents

@@ -374,10 +382,11 @@ def __call__(
return (image,)

return ImagePipelineOutput(images=image)
#! ---------------------------------------- ADD BELOW ----------------------------------------
#! ---------------------------------------- ADD BELOW ----------------------------------------
else:
return None
#! ---------------------------------------- ADD ABOVE ----------------------------------------

#! ---------------------------------------- ADD ABOVE ----------------------------------------

def _scheduler_step(
self,
@@ -448,18 +457,15 @@ def _sync_pipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if (
sync_only
and is_pipeline_last_stage()
and i == len(timesteps) - 1
):
if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
pass
elif get_pipeline_parallel_world_size() > 1:
get_pp_group().pipeline_send(latents)

if (sync_only and
get_sequence_parallel_world_size() > 1 and
is_pipeline_last_stage()
if (
sync_only
and get_sequence_parallel_world_size() > 1
and is_pipeline_last_stage()
):
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
@@ -469,9 +475,10 @@ def _sync_pipeline(
sp_latents_list[sp_patch_idx][
:,
:,
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
:
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
@@ -524,10 +531,6 @@ def _async_pipeline(
patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
idx=patch_idx
)
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
else:
get_pp_group().recv_next()
patch_latents[patch_idx] = self._backbone_forward(
latents=patch_latents[patch_idx],
prompt_embeds=prompt_embeds,
@@ -548,6 +551,14 @@
else:
get_pp_group().pipeline_isend(patch_latents[patch_idx])

if is_pipeline_first_stage() and i == 0:
pass
else:
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
else:
get_pp_group().recv_next()

get_runtime_state().next_patch()

if i == len(timesteps) - 1 or (
@@ -570,15 +581,20 @@ def _async_pipeline(
latents = torch.cat(patch_latents, dim=2)
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
sp_latents_list = get_sp_group().all_gather(
latents, separate_tensors=True
)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
...,
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
:
get_runtime_state()
.pp_patches_start_idx_local[
pp_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
@@ -646,4 +662,4 @@ def _backbone_forward(
else:
latents = noise_pred

return latents
return latents
71 changes: 43 additions & 28 deletions xfuser/model_executor/pipelines/pipeline_pixart_sigma.py
@@ -24,7 +24,7 @@
get_sequence_parallel_world_size,
get_sp_group,
is_pipeline_first_stage,
is_pipeline_last_stage
is_pipeline_last_stage,
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@@ -172,7 +172,9 @@ def __call__(
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
height, width = self.image_processor.classify_height_width_bin(
height, width, ratios=aspect_ratio_bin
)

self.check_inputs(
prompt,
@@ -201,7 +203,7 @@ def __call__(
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

#* set runtime state input parameters
# * set runtime state input parameters
get_runtime_state().set_input_parameters(
height=height,
width=width,
@@ -238,7 +240,7 @@ def __call__(
negative_prompt_embeds,
prompt_embeds,
negative_prompt_attention_mask,
prompt_attention_mask
prompt_attention_mask,
)

# 4. Prepare timesteps
@@ -266,7 +268,9 @@ def __call__(
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}

# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
num_warmup_steps = max(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps

with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -318,12 +322,16 @@ def __call__(
sync_only=True,
)

#* 8. Decode latents (only the last rank in a dp group)
# * 8. Decode latents (only the last rank in a dp group)
if is_dp_last_rank():
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
image = self.image_processor.resize_and_crop_tensor(
image, orig_width, orig_height
)
else:
image = latents

@@ -409,18 +417,15 @@ def _sync_pipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if (
sync_only
and is_pipeline_last_stage()
and i == len(timesteps) - 1
):
if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
pass
elif get_pipeline_parallel_world_size() > 1:
get_pp_group().pipeline_send(latents)

if (sync_only and
get_sequence_parallel_world_size() > 1 and
is_pipeline_last_stage()
if (
sync_only
and get_sequence_parallel_world_size() > 1
and is_pipeline_last_stage()
):
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
@@ -430,9 +435,10 @@ def _sync_pipeline(
sp_latents_list[sp_patch_idx][
:,
:,
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
:
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
@@ -485,10 +491,6 @@ def _async_pipeline(
patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
idx=patch_idx
)
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
else:
get_pp_group().recv_next()
patch_latents[patch_idx] = self._backbone_forward(
latents=patch_latents[patch_idx],
prompt_embeds=prompt_embeds,
@@ -509,6 +511,14 @@
else:
get_pp_group().pipeline_isend(patch_latents[patch_idx])

if is_pipeline_first_stage() and i == 0:
pass
else:
if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
pass
else:
get_pp_group().recv_next()

get_runtime_state().next_patch()

if i == len(timesteps) - 1 or (
@@ -531,15 +541,20 @@ def _async_pipeline(
latents = torch.cat(patch_latents, dim=2)
if get_sequence_parallel_world_size() > 1:
sp_degree = get_sequence_parallel_world_size()
sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
sp_latents_list = get_sp_group().all_gather(
latents, separate_tensors=True
)
latents_list = []
for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
latents_list += [
sp_latents_list[sp_patch_idx][
...,
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
:
get_runtime_state()
.pp_patches_start_idx_local[
pp_patch_idx
] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
]
@@ -607,4 +622,4 @@ def _backbone_forward(
else:
latents = noise_pred

return latents
return latents