diff --git a/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py b/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py
index 0fdc0919..47388541 100644
--- a/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py
+++ b/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py
@@ -24,7 +24,7 @@
     get_sequence_parallel_world_size,
     get_sp_group,
     is_pipeline_first_stage,
-    is_pipeline_last_stage
+    is_pipeline_last_stage,
 )
 from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
 from .register import xFuserPipelineWrapperRegister
@@ -172,7 +172,9 @@ def __call__(
             else:
                 raise ValueError("Invalid sample size")
             orig_height, orig_width = height, width
-            height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+            height, width = self.image_processor.classify_height_width_bin(
+                height, width, ratios=aspect_ratio_bin
+            )
 
         self.check_inputs(
             prompt,
@@ -200,15 +202,15 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
-#! ---------------------------------------- ADDED BELOW ----------------------------------------
-        #* set runtime state input parameters
+        #! ---------------------------------------- ADDED BELOW ----------------------------------------
+        # * set runtime state input parameters
         get_runtime_state().set_input_parameters(
             height=height,
             width=width,
             batch_size=batch_size,
             num_inference_steps=num_inference_steps,
         )
-#! ---------------------------------------- ADDED ABOVE ----------------------------------------
+        #! ---------------------------------------- ADDED ABOVE ----------------------------------------
 
         # 3. Encode input prompt
         (
@@ -230,7 +232,7 @@ def __call__(
             max_sequence_length=max_sequence_length,
         )
 
-#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
+        #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
         # * dealing with cfg degree
         if do_classifier_free_guidance:
             (
@@ -240,14 +242,14 @@ def __call__(
                 negative_prompt_embeds,
                 prompt_embeds,
                 negative_prompt_attention_mask,
-                prompt_attention_mask
+                prompt_attention_mask,
             )
 
         #! ORIGIN
         # if do_classifier_free_guidance:
         #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         #     prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
-#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
+        #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
@@ -282,7 +284,7 @@ def __call__(
             resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
             aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
 
-#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
+            #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
             if (
                 do_classifier_free_guidance
                 and get_classifier_free_guidance_world_size() == 1
@@ -294,13 +296,15 @@ def __call__(
             # if do_classifier_free_guidance:
             #     resolution = torch.cat([resolution, resolution], dim=0)
             #     aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
-#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
+            #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
 
             added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
 
         # 7. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
+        num_warmup_steps = max(
+            len(timesteps) - num_inference_steps * self.scheduler.order, 0
+        )
+        #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
         num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -351,16 +355,20 @@ def __call__(
                     callback_steps=callback_steps,
                     sync_only=True,
                 )
-#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
+        #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
 
         # 8. Decode latents (only rank 0)
-#! ---------------------------------------- ADD BELOW ----------------------------------------
+        #! ---------------------------------------- ADD BELOW ----------------------------------------
         if is_dp_last_rank():
-#! ---------------------------------------- ADD ABOVE ----------------------------------------
+            #! ---------------------------------------- ADD ABOVE ----------------------------------------
             if not output_type == "latent":
-                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+                image = self.vae.decode(
+                    latents / self.vae.config.scaling_factor, return_dict=False
+                )[0]
                 if use_resolution_binning:
-                    image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+                    image = self.image_processor.resize_and_crop_tensor(
+                        image, orig_width, orig_height
+                    )
             else:
                 image = latents
 
@@ -374,10 +382,11 @@ def __call__(
                 return (image,)
 
             return ImagePipelineOutput(images=image)
-#! ---------------------------------------- ADD BELOW ----------------------------------------
+        #! ---------------------------------------- ADD BELOW ----------------------------------------
         else:
             return None
-#! ---------------------------------------- ADD ABOVE ----------------------------------------
+
+        #! ---------------------------------------- ADD ABOVE ----------------------------------------
 
     def _scheduler_step(
         self,
@@ -448,18 +457,15 @@ def _sync_pipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
-            if (
-                sync_only
-                and is_pipeline_last_stage()
-                and i == len(timesteps) - 1
-            ):
+            if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
                 pass
             elif get_pipeline_parallel_world_size() > 1:
                 get_pp_group().pipeline_send(latents)
 
-        if (sync_only and
-            get_sequence_parallel_world_size() > 1 and
-            is_pipeline_last_stage()
+        if (
+            sync_only
+            and get_sequence_parallel_world_size() > 1
+            and is_pipeline_last_stage()
         ):
             sp_degree = get_sequence_parallel_world_size()
             sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
@@ -469,9 +475,10 @@ def _sync_pipeline(
                     sp_latents_list[sp_patch_idx][
                         :,
                         :,
-                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
-                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
-                        :
+                        get_runtime_state()
+                        .pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
+                        .pp_patches_start_idx_local[pp_patch_idx + 1],
+                        :,
                     ]
                     for sp_patch_idx in range(sp_degree)
                 ]
@@ -524,10 +531,6 @@ def _async_pipeline(
                     patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
                         idx=patch_idx
                     )
-                    if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
-                        pass
-                    else:
-                        get_pp_group().recv_next()
                 patch_latents[patch_idx] = self._backbone_forward(
                     latents=patch_latents[patch_idx],
                     prompt_embeds=prompt_embeds,
@@ -548,6 +551,14 @@ def _async_pipeline(
                 else:
                     get_pp_group().pipeline_isend(patch_latents[patch_idx])
 
+                if is_pipeline_first_stage() and i == 0:
+                    pass
+                else:
+                    if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
+                        pass
+                    else:
+                        get_pp_group().recv_next()
+
                 get_runtime_state().next_patch()
 
             if i == len(timesteps) - 1 or (
@@ -570,15 +581,20 @@ def _async_pipeline(
             latents = torch.cat(patch_latents, dim=2)
             if get_sequence_parallel_world_size() > 1:
                 sp_degree = get_sequence_parallel_world_size()
-                sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
+                sp_latents_list = get_sp_group().all_gather(
+                    latents, separate_tensors=True
+                )
                 latents_list = []
                 for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
                     latents_list += [
                         sp_latents_list[sp_patch_idx][
                             ...,
-                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
-                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
-                            :
+                            get_runtime_state()
+                            .pp_patches_start_idx_local[
+                                pp_patch_idx
+                            ] : get_runtime_state()
+                            .pp_patches_start_idx_local[pp_patch_idx + 1],
+                            :,
                         ]
                         for sp_patch_idx in range(sp_degree)
                     ]
@@ -646,4 +662,4 @@ def _backbone_forward(
         else:
             latents = noise_pred
 
-        return latents
\ No newline at end of file
+        return latents
diff --git a/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py b/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py
index a17ceacb..091e9893 100644
--- a/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py
+++ b/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py
@@ -24,7 +24,7 @@
     get_sequence_parallel_world_size,
     get_sp_group,
     is_pipeline_first_stage,
-    is_pipeline_last_stage
+    is_pipeline_last_stage,
 )
 from .base_pipeline import xFuserPipelineBaseWrapper
 from .register import xFuserPipelineWrapperRegister
@@ -172,7 +172,9 @@ def __call__(
             else:
                 raise ValueError("Invalid sample size")
             orig_height, orig_width = height, width
-            height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+            height, width = self.image_processor.classify_height_width_bin(
+                height, width, ratios=aspect_ratio_bin
+            )
 
         self.check_inputs(
             prompt,
@@ -201,7 +203,7 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
-        #* set runtime state input parameters
+        # * set runtime state input parameters
         get_runtime_state().set_input_parameters(
             height=height,
             width=width,
@@ -238,7 +240,7 @@ def __call__(
                 negative_prompt_embeds,
                 prompt_embeds,
                 negative_prompt_attention_mask,
-                prompt_attention_mask
+                prompt_attention_mask,
             )
 
         # 4. Prepare timesteps
@@ -266,7 +268,9 @@ def __call__(
         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
 
         # 7. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        num_warmup_steps = max(
+            len(timesteps) - num_inference_steps * self.scheduler.order, 0
+        )
         num_pipeline_warmup_steps = get_runtime_state().runtime_config.warmup_steps
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -318,12 +322,16 @@ def __call__(
                     sync_only=True,
                 )
 
-        #* 8. Decode latents (only the last rank in a dp group)
+        # * 8. Decode latents (only the last rank in a dp group)
         if is_dp_last_rank():
             if not output_type == "latent":
-                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+                image = self.vae.decode(
+                    latents / self.vae.config.scaling_factor, return_dict=False
+                )[0]
                 if use_resolution_binning:
-                    image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+                    image = self.image_processor.resize_and_crop_tensor(
+                        image, orig_width, orig_height
+                    )
             else:
                 image = latents
 
@@ -409,18 +417,15 @@ def _sync_pipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
-            if (
-                sync_only
-                and is_pipeline_last_stage()
-                and i == len(timesteps) - 1
-            ):
+            if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
                 pass
            elif get_pipeline_parallel_world_size() > 1:
                 get_pp_group().pipeline_send(latents)
 
-        if (sync_only and
-            get_sequence_parallel_world_size() > 1 and
-            is_pipeline_last_stage()
+        if (
+            sync_only
+            and get_sequence_parallel_world_size() > 1
+            and is_pipeline_last_stage()
         ):
             sp_degree = get_sequence_parallel_world_size()
             sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
@@ -430,9 +435,10 @@ def _sync_pipeline(
                     sp_latents_list[sp_patch_idx][
                         :,
                         :,
-                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
-                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
-                        :
+                        get_runtime_state()
+                        .pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
+                        .pp_patches_start_idx_local[pp_patch_idx + 1],
+                        :,
                     ]
                     for sp_patch_idx in range(sp_degree)
                 ]
@@ -485,10 +491,6 @@ def _async_pipeline(
                     patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
                        idx=patch_idx
                     )
-                    if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
-                        pass
-                    else:
-                        get_pp_group().recv_next()
                 patch_latents[patch_idx] = self._backbone_forward(
                     latents=patch_latents[patch_idx],
                     prompt_embeds=prompt_embeds,
@@ -509,6 +511,14 @@ def _async_pipeline(
                 else:
                     get_pp_group().pipeline_isend(patch_latents[patch_idx])
 
+                if is_pipeline_first_stage() and i == 0:
+                    pass
+                else:
+                    if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
+                        pass
+                    else:
+                        get_pp_group().recv_next()
+
                 get_runtime_state().next_patch()
 
             if i == len(timesteps) - 1 or (
@@ -531,15 +541,20 @@ def _async_pipeline(
             latents = torch.cat(patch_latents, dim=2)
             if get_sequence_parallel_world_size() > 1:
                 sp_degree = get_sequence_parallel_world_size()
-                sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
+                sp_latents_list = get_sp_group().all_gather(
+                    latents, separate_tensors=True
+                )
                 latents_list = []
                 for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
                     latents_list += [
                         sp_latents_list[sp_patch_idx][
                             ...,
-                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
-                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
-                            :
+                            get_runtime_state()
+                            .pp_patches_start_idx_local[
+                                pp_patch_idx
+                            ] : get_runtime_state()
+                            .pp_patches_start_idx_local[pp_patch_idx + 1],
+                            :,
                         ]
                         for sp_patch_idx in range(sp_degree)
                     ]
@@ -607,4 +622,4 @@ def _backbone_forward(
         else:
             latents = noise_pred
 
-        return latents
\ No newline at end of file
+        return latents
diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
index 180a1b5e..4a1ea443 100644
--- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
+++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
@@ -19,10 +19,10 @@
 from diffusers import StableDiffusion3Pipeline
 from diffusers.utils import is_torch_xla_available
 from diffusers.pipelines.stable_diffusion_3.pipeline_output import (
-    StableDiffusion3PipelineOutput
+    StableDiffusion3PipelineOutput,
 )
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
-    retrieve_timesteps
+    retrieve_timesteps,
 )
 
 from xfuser.config import EngineConfig, InputConfig
@@ -37,7 +37,7 @@
     get_pp_group,
     get_sequence_parallel_world_size,
     get_sp_group,
-    is_dp_last_rank
+    is_dp_last_rank,
 )
 from .base_pipeline import xFuserPipelineBaseWrapper
 from .register import xFuserPipelineWrapperRegister
@@ -65,12 +65,10 @@ def from_pretrained(
         )
         return cls(pipeline, engine_config)
 
-    def prepare_run(self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1):
-        prompt = (
-            [""] * input_config.batch_size
-            if input_config.batch_size > 1
-            else ""
-        )
+    def prepare_run(
+        self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1
+    ):
+        prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else ""
         warmup_steps = get_runtime_state().runtime_config.warmup_steps
         get_runtime_state().runtime_config.warmup_steps = sync_steps
         self.__call__(
@@ -266,15 +264,15 @@ def __call__(
 
         device = self._execution_device
 
-#! ---------------------------------------- ADDED BELOW ----------------------------------------
-        #* set runtime state input parameters
+        #! ---------------------------------------- ADDED BELOW ----------------------------------------
+        # * set runtime state input parameters
         get_runtime_state().set_input_parameters(
             height=height,
             width=width,
             batch_size=batch_size,
             num_inference_steps=num_inference_steps,
         )
-#! ---------------------------------------- ADDED ABOVE ----------------------------------------
+        #! ---------------------------------------- ADDED ABOVE ----------------------------------------
 
         (
             prompt_embeds,
@@ -299,7 +297,7 @@ def __call__(
         )
 
         if self.do_classifier_free_guidance:
-#! ---------------------------------------- MODIFIED BELOW ----------------------------------------
+            #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
             (
                 prompt_embeds,
                 pooled_prompt_embeds,
@@ -313,11 +311,15 @@ def __call__(
             #! ORIGIN
             # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             # pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-#! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
+        #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------
 
         # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps
+        )
+        num_warmup_steps = max(
+            len(timesteps) - num_inference_steps * self.scheduler.order, 0
+        )
         self._num_timesteps = len(timesteps)
 
         # 5. Prepare latent variables
@@ -375,13 +377,15 @@ def __call__(
                     sync_only=True,
                 )
 
-        #* 8. Decode latents (only the last rank in a dp group)
+        # * 8. Decode latents (only the last rank in a dp group)
         if is_dp_last_rank():
             if output_type == "latent":
                 image = latents
             else:
-                latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+                latents = (
+                    latents / self.vae.config.scaling_factor
+                ) + self.vae.config.shift_factor
 
                 image = self.vae.decode(latents, return_dict=False)[0]
                 image = self.image_processor.postprocess(image, output_type=output_type)
@@ -398,7 +402,9 @@ def __call__(
 
     def set_sd3_extra_comm_tensor(self, prompt_embeds: torch.Tensor):
         prompt_embeds_shape = prompt_embeds.shape
-        encoder_hidden_states_shape = prompt_embeds_shape[:-1] + (self.transformer.config.caption_projection_dim,)
+        encoder_hidden_states_shape = prompt_embeds_shape[:-1] + (
+            self.transformer.config.caption_projection_dim,
+        )
         self._set_extra_comm_tensor_for_pipeline(
             [
                 (
@@ -441,13 +447,17 @@ def _sync_pipeline(
             else:
                 latents = get_pp_group().pipeline_recv()
                 if not is_pipeline_first_stage():
-                    encoder_hidden_states = get_pp_group().pipeline_recv(0, "encoder_hidden_states")
-
+                    encoder_hidden_states = get_pp_group().pipeline_recv(
+                        0, "encoder_hidden_states"
+                    )
 
             latents, encoder_hidden_states = self._backbone_forward(
                 latents=latents,
-                encoder_hidden_states=prompt_embeds
-                if is_pipeline_first_stage() else encoder_hidden_states,
+                encoder_hidden_states=(
+                    prompt_embeds
+                    if is_pipeline_first_stage()
+                    else encoder_hidden_states
+                ),
                 pooled_prompt_embeds=pooled_prompt_embeds,
                 t=t,
             )
@@ -469,30 +479,31 @@ def _sync_pipeline(
                 latents = callback_outputs.pop("latents", latents)
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                negative_prompt_embeds = callback_outputs.pop(
+                    "negative_prompt_embeds", negative_prompt_embeds
+                )
                 negative_pooled_prompt_embeds = callback_outputs.pop(
                     "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                 )
 
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+            if i == len(timesteps) - 1 or (
+                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+            ):
                 progress_bar.update()
 
             if XLA_AVAILABLE:
                 xm.mark_step()
 
-            if (
-                sync_only
-                and is_pipeline_last_stage()
-                and i == len(timesteps) - 1
-            ):
+            if sync_only and is_pipeline_last_stage() and i == len(timesteps) - 1:
                 pass
             elif get_pipeline_parallel_world_size() > 1:
                 get_pp_group().pipeline_send(latents)
                 if not is_pipeline_last_stage():
                     get_pp_group().pipeline_send(encoder_hidden_states)
 
-        if (sync_only and
-            get_sequence_parallel_world_size() > 1 and
-            is_pipeline_last_stage()
+        if (
+            sync_only
+            and get_sequence_parallel_world_size() > 1
+            and is_pipeline_last_stage()
         ):
             sp_degree = get_sequence_parallel_world_size()
             sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
@@ -502,9 +513,10 @@ def _sync_pipeline(
                     sp_latents_list[sp_patch_idx][
                         :,
                         :,
-                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
-                        get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
-                        :
+                        get_runtime_state()
+                        .pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
+                        .pp_patches_start_idx_local[pp_patch_idx + 1],
+                        :,
                     ]
                     for sp_patch_idx in range(sp_degree)
                 ]
@@ -528,11 +540,17 @@ def _init_sd3_async_pipeline(
                 if num_pipeline_warmup_steps > 0
                 else latents
             )
-            patch_latents = list(latents.split(get_runtime_state().pp_patches_height, dim=2))
+            patch_latents = list(
+                latents.split(get_runtime_state().pp_patches_height, dim=2)
+            )
         elif is_pipeline_last_stage():
-            patch_latents = list(latents.split(get_runtime_state().pp_patches_height, dim=2))
+            patch_latents = list(
+                latents.split(get_runtime_state().pp_patches_height, dim=2)
+            )
         else:
-            patch_latents = [None for _ in range(get_runtime_state().num_pipeline_patch)]
+            patch_latents = [
+                None for _ in range(get_runtime_state().num_pipeline_patch)
+            ]
 
         recv_timesteps = (
             num_timesteps - 1 if is_pipeline_first_stage() else num_timesteps
@@ -596,30 +614,26 @@ def _async_pipeline(
                         first_async_recv = False
 
                     if not is_pipeline_first_stage() and patch_idx == 0:
-                        last_encoder_hidden_states = get_pp_group().get_pipeline_recv_data(
-                            idx=patch_idx, name="encoder_hidden_states"
+                        last_encoder_hidden_states = (
+                            get_pp_group().get_pipeline_recv_data(
+                                idx=patch_idx, name="encoder_hidden_states"
+                            )
                         )
                     patch_latents[patch_idx] = get_pp_group().get_pipeline_recv_data(
                         idx=patch_idx
                     )
-                    if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
-                        pass
-                    elif is_pipeline_first_stage():
-                        get_pp_group().recv_next()
-                    else:
-                        # recv encoder_hidden_state
-                        if patch_idx == num_pipeline_patch - 1:
-                            get_pp_group().recv_next()
-                        # recv latents
-                        get_pp_group().recv_next()
-                patch_latents[patch_idx], next_encoder_hidden_states = self._backbone_forward(
-                    latents=patch_latents[patch_idx],
-                    encoder_hidden_states=prompt_embeds
-                    if is_pipeline_first_stage()
-                    else last_encoder_hidden_states,
-                    pooled_prompt_embeds=pooled_prompt_embeds,
-                    t=t,
+                patch_latents[patch_idx], next_encoder_hidden_states = (
+                    self._backbone_forward(
+                        latents=patch_latents[patch_idx],
+                        encoder_hidden_states=(
+                            prompt_embeds
+                            if is_pipeline_first_stage()
+                            else last_encoder_hidden_states
+                        ),
+                        pooled_prompt_embeds=pooled_prompt_embeds,
+                        t=t,
+                    )
                 )
 
                 if is_pipeline_last_stage():
                     latents_dtype = patch_latents[patch_idx].dtype
@@ -638,13 +652,20 @@ def _async_pipeline(
                         callback_kwargs = {}
                         for k in callback_on_step_end_tensor_inputs:
                             callback_kwargs[k] = locals()[k]
-                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                        callback_outputs = callback_on_step_end(
+                            self, i, t, callback_kwargs
+                        )
 
                         latents = callback_outputs.pop("latents", latents)
-                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                        prompt_embeds = callback_outputs.pop(
+                            "prompt_embeds", prompt_embeds
+                        )
+                        negative_prompt_embeds = callback_outputs.pop(
+                            "negative_prompt_embeds", negative_prompt_embeds
+                        )
                         negative_pooled_prompt_embeds = callback_outputs.pop(
-                            "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                            "negative_pooled_prompt_embeds",
+                            negative_pooled_prompt_embeds,
                         )
 
                 if i != len(timesteps) - 1:
@@ -654,9 +675,24 @@ def _async_pipeline(
                         get_pp_group().pipeline_isend(next_encoder_hidden_states)
                         get_pp_group().pipeline_isend(patch_latents[patch_idx])
 
+                if is_pipeline_first_stage() and i == 0:
+                    pass
+                else:
+                    if i == len(timesteps) - 1 and patch_idx == num_pipeline_patch - 1:
+                        pass
+                    elif is_pipeline_first_stage():
+                        get_pp_group().recv_next()
+                    else:
+                        # recv encoder_hidden_state
+                        if patch_idx == num_pipeline_patch - 1:
+                            get_pp_group().recv_next()
+                        # recv latents
+                        get_pp_group().recv_next()
+
                 get_runtime_state().next_patch()
 
-            if i == len(timesteps) - 1 or ((i + num_pipeline_warmup_steps + 1) > num_warmup_steps
+            if i == len(timesteps) - 1 or (
+                (i + num_pipeline_warmup_steps + 1) > num_warmup_steps
                 and (i + num_pipeline_warmup_steps + 1) % self.scheduler.order == 0
             ):
                 progress_bar.update()
@@ -669,15 +705,20 @@ def _async_pipeline(
             latents = torch.cat(patch_latents, dim=2)
             if get_sequence_parallel_world_size() > 1:
                 sp_degree = get_sequence_parallel_world_size()
-                sp_latents_list = get_sp_group().all_gather(latents, separate_tensors=True)
+                sp_latents_list = get_sp_group().all_gather(
+                    latents, separate_tensors=True
+                )
                 latents_list = []
                 for pp_patch_idx in range(get_runtime_state().num_pipeline_patch):
                     latents_list += [
                         sp_latents_list[sp_patch_idx][
                             ...,
-                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx]:
-                            get_runtime_state().pp_patches_start_idx_local[pp_patch_idx+1],
-                            :
+                            get_runtime_state()
+                            .pp_patches_start_idx_local[
+                                pp_patch_idx
+                            ] : get_runtime_state()
+                            .pp_patches_start_idx_local[pp_patch_idx + 1],
+                            :,
                         ]
                         for sp_patch_idx in range(sp_degree)
                     ]
@@ -716,7 +757,9 @@ def _backbone_forward(
             noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
                 noise_pred, separate_tensors=True
            )
-            latents = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            latents = noise_pred_uncond + self.guidance_scale * (
+                noise_pred_text - noise_pred_uncond
+            )
         else:
             latents = noise_pred
 
@@ -733,4 +776,4 @@ def _scheduler_step(
             t,
             latents,
             return_dict=False,
-        )[0]
\ No newline at end of file
+        )[0]