From 69a36d36beb1cbbb2092254b054cd4d864fd9168 Mon Sep 17 00:00:00 2001
From: Jinzhe Pan
Date: Thu, 12 Sep 2024 22:14:09 +0800
Subject: [PATCH] Revert "CogVideoX support with USP (#261)"

This reverts commit 9484590ed20d5cc1aa08f43dcca304811c3ecd9d.
---
 setup.py                                      |  40 ++--
 xfuser/core/distributed/runtime_state.py      |  13 +-
 .../long_ctx_attention/hybrid/attn_layer.py   |   6 +-
 .../layers/attention_processor.py             | 221 ++----------------
 .../transformers/cogvideox_transformer_3d.py  |  46 ++--
 .../pipelines/pipeline_cogvideox.py           |  48 +---
 .../model_executor/pipelines/pipeline_flux.py |   2 +-
 .../pipelines/pipeline_hunyuandit.py          |  10 +-
 8 files changed, 84 insertions(+), 302 deletions(-)

diff --git a/setup.py b/setup.py
index d04d473a..3628a2f0 100644
--- a/setup.py
+++ b/setup.py
@@ -12,34 +12,40 @@ def get_cuda_version():
     except Exception as e:
         return 'no_cuda'
 
+def get_install_requires(cuda_version):
+    if cuda_version == 'cu124':
+        sys.stderr.write("WARNING: Manual installation required for CUDA 12.4 specific PyTorch version.\n")
+        sys.stderr.write("Please install PyTorch for CUDA 12.4 using the following command:\n")
+        sys.stderr.write("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n")
+
+    return [
+        "torch==2.3.0",
+        "diffusers>=0.30.0",
+        "transformers>=4.39.1",
+        "sentencepiece>=0.1.99",
+        "accelerate==0.33.0",
+        "beautifulsoup4>=4.12.3",
+        "distvae",
+        "yunchang==0.3",
+        "flash_attn>=2.6.3",
+        "pytest",
+        "flask",
+    ]
+
 if __name__ == "__main__":
     with open("README.md", "r") as f:
         long_description = f.read()
     fp = open("xfuser/__version__.py", "r").read()
     version = eval(fp.strip().split()[-1])
 
+    cuda_version = get_cuda_version()
+
     setup(
         name="xfuser",
         author="xDiT Team",
         author_email="fangjiarui123@gmail.com",
         packages=find_packages(),
-        install_requires=[
-            "torch>=2.3.0",
-            "accelerate==0.33.0",
-            "diffusers>=0.30.0",
-            "transformers>=4.39.1",
-            "sentencepiece>=0.1.99",
-            "beautifulsoup4>=4.12.3",
-            "distvae",
-            "yunchang==0.3",
-            "pytest",
-            "flask",
-        ],
-        extras_require={
-            "all": [
-                "flash_attn>=2.6.3",
-            ],
-        },
+        install_requires=get_install_requires(cuda_version),
         url="https://github.com/xdit-project/xDiT.",
         description="xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters",
         long_description=long_description,
diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py
index 5f4085ab..a2c9c631 100644
--- a/xfuser/core/distributed/runtime_state.py
+++ b/xfuser/core/distributed/runtime_state.py
@@ -136,8 +136,7 @@ def set_input_parameters(
         self.input_config.seed = seed
         set_random_seed(seed)
         if (
-            not self.ready
-            or (height and self.input_config.height != height)
+            (height and self.input_config.height != height)
             or (width and self.input_config.width != width)
             or (batch_size and self.input_config.batch_size != batch_size)
             or not self.ready
@@ -164,8 +163,7 @@ def set_video_input_parameters(
         self.input_config.seed = seed
         set_random_seed(seed)
         if (
-            not self.ready
-            or (height and self.input_config.height != height)
+            (height and self.input_config.height != height)
             or (width and self.input_config.width != width)
             or (num_frames and self.input_config.num_frames != num_frames)
             or (batch_size and self.input_config.batch_size != batch_size)
@@ -363,6 +361,7 @@ def _calc_patches_metadata(self):
         self.pp_patches_token_num = pp_patches_token_num
 
     def _calc_cogvideox_patches_metadata(self):
+
         num_sp_patches = get_sequence_parallel_world_size()
         sp_patch_idx = get_sequence_parallel_rank()
         patch_size = self.backbone_patch_size
@@ -451,9 +450,11 @@ def _calc_cogvideox_patches_metadata(self):
         pp_patches_token_start_end_idx_global = [
             [
                 (latents_width // patch_size)
-                * (start_idx // patch_size),
+                * (start_idx // patch_size)
+                * latents_frames,
                 (latents_width // patch_size)
-                * (end_idx // patch_size),
+                * (end_idx // patch_size)
+                * latents_frames,
             ]
             for start_idx, end_idx in pp_patches_start_end_idx_global
         ]
diff --git a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
index 60d71264..e30beb79 100644
--- a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
+++ b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -228,9 +228,6 @@ def forward(
         key: Tensor,
         value: Tensor,
         *,
-        joint_tensor_query,
-        joint_tensor_key,
-        joint_tensor_value,
         dropout_p=0.0,
         softmax_scale=None,
         causal=False,
@@ -238,6 +235,9 @@ def forward(
         alibi_slopes=None,
         deterministic=False,
         return_attn_probs=False,
+        joint_tensor_query=None,
+        joint_tensor_key=None,
+        joint_tensor_value=None,
         joint_strategy="front",
     ) -> Tensor:
         """forward
diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py
index 1affaaa2..c6481ebd 100644
--- a/xfuser/model_executor/layers/attention_processor.py
+++ b/xfuser/model_executor/layers/attention_processor.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Optional, Union, Tuple
+from typing import Optional
 
 import torch
 from torch import nn
@@ -12,12 +12,9 @@
     JointAttnProcessor2_0,
     FluxAttnProcessor2_0,
     FluxSingleAttnProcessor2_0,
+    apply_rope,
     HunyuanAttnProcessor2_0,
 )
-try:
-    from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0
-except ImportError:
-    CogVideoXAttnProcessor2_0 = None
 
 from xfuser.core.distributed import (
     get_sequence_parallel_world_size,
@@ -45,62 +42,11 @@ def is_v100():
     device_name = torch.cuda.get_device_name(torch.cuda.current_device())
     return "V100" in device_name
 
-
 def torch_compile_disable_if_v100(func):
     if is_v100():
         return torch.compiler.disable(func)
     return func
 
-
-def apply_rotary_emb(
-    x: torch.Tensor,
-    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-    use_real: bool = True,
-    use_real_unbind_dim: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
-    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
-    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
-    tensors contain rotary embeddings and are returned as real tensors.
-
-    Args:
-        x (`torch.Tensor`):
-            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
-        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-    """
-    if use_real:
-        cos, sin = freqs_cis  # [S, D]
-        cos = cos[None, None]
-        sin = sin[None, None]
-        cos, sin = cos.to(x.device), sin.to(x.device)
-
-        if use_real_unbind_dim == -1:
-            # Used for flux, cogvideox, hunyuan-dit
-            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
-            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-        elif use_real_unbind_dim == -2:
-            # Used for Stable Audio
-            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
-            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
-        else:
-            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
-
-        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
-
-        return out
-    else:
-        # used for lumina
-        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-        freqs_cis = freqs_cis.unsqueeze(2)
-        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
-
-        return x_out.type_as(x)
-
-
 class xFuserAttentionBaseWrapper(xFuserLayerBaseWrapper):
     def __init__(
         self,
@@ -702,8 +648,11 @@ def __call__(
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
         if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
 
         #! ---------------------------------------- KV CACHE ----------------------------------------
         if not self.use_long_ctx_attn_kvcache:
@@ -875,8 +824,11 @@ def __call__(
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
 
         #! ---------------------------------------- KV CACHE ----------------------------------------
         if not self.use_long_ctx_attn_kvcache:
@@ -980,6 +932,8 @@ def __call__(
         image_rotary_emb: Optional[torch.Tensor] = None,
         latte_temporal_attention: Optional[bool] = False,
     ) -> torch.Tensor:
+        from diffusers.models.embeddings import apply_rotary_emb
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1136,150 +1090,3 @@ def __call__(
             hidden_states = hidden_states / attn.rescale_output_factor
 
         return hidden_states
-
-
-if CogVideoXAttnProcessor2_0 is not None:
-
-    @xFuserAttentionProcessorRegister.register(CogVideoXAttnProcessor2_0)
-    class xFuserCogVideoXAttnProcessor2_0(CogVideoXAttnProcessor2_0):
-        r"""
-        Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-        query and key vectors, but does not include spatial normalization.
-        """
-
-        def __init__(self):
-            super().__init__()
-            use_long_ctx_attn_kvcache = True
-            self.use_long_ctx_attn_kvcache = (
-                HAS_LONG_CTX_ATTN
-                and use_long_ctx_attn_kvcache
-                and get_sequence_parallel_world_size() > 1
-            )
-            if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
-                from xfuser.core.long_ctx_attention import (
-                    xFuserLongContextAttention,
-                    xFuserUlyssesAttention,
-                )
-
-                if HAS_FLASH_ATTN:
-                    self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
-                        use_kv_cache=self.use_long_ctx_attn_kvcache
-                    )
-                else:
-                    self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                        use_fa=False,
-                        use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    )
-            else:
-                self.hybrid_seq_parallel_attn = None
-
-        def __call__(
-            self,
-            attn: Attention,
-            hidden_states: torch.Tensor,
-            encoder_hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            image_rotary_emb: Optional[torch.Tensor] = None,
-            *args,
-            **kwargs,
-        ) -> torch.Tensor:
-            text_seq_length = encoder_hidden_states.size(1)
-
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-            batch_size, sequence_length, _ = (
-                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-            )
-
-            if attention_mask is not None:
-                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-            query = attn.to_q(hidden_states)
-            key = attn.to_k(hidden_states)
-            value = attn.to_v(hidden_states)
-
-            inner_dim = key.shape[-1]
-            head_dim = inner_dim // attn.heads
-
-            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            if attn.norm_q is not None:
-                query = attn.norm_q(query)
-            if attn.norm_k is not None:
-                key = attn.norm_k(key)
-
-            # Apply RoPE if needed
-            if image_rotary_emb is not None:
-                query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-                if not attn.is_cross_attention:
-                    key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-            #! ---------------------------------------- KV CACHE ----------------------------------------
-            if not self.use_long_ctx_attn_kvcache:
-                key, value = get_cache_manager().update_and_get_kv_cache(
-                    new_kv=[key, value],
-                    layer=attn,
-                    slice_dim=2,
-                    layer_type="attn",
-                )
-            #! ---------------------------------------- KV CACHE ----------------------------------------
-
-            #! ---------------------------------------- ATTENTION ----------------------------------------
-            if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
-                query = query.transpose(1, 2)
-                key = key.transpose(1, 2)
-                value = value.transpose(1, 2)
-                hidden_states = self.hybrid_seq_parallel_attn(
-                    attn,
-                    query,
-                    key,
-                    value,
-                    dropout_p=0.0,
-                    causal=False,
-                    joint_strategy="none",
-                )
-                hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
-            else:
-                if HAS_FLASH_ATTN:
-                    from flash_attn import flash_attn_func
-
-                    query = query.transpose(1, 2)
-                    key = key.transpose(1, 2)
-                    value = value.transpose(1, 2)
-                    hidden_states = flash_attn_func(
-                        query, key, value, dropout_p=0.0, causal=False
-                    )
-                    hidden_states = hidden_states.reshape(
-                        batch_size, -1, attn.heads * head_dim
-                    )
-
-                else:
-                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                    # TODO: add support for attn.scale when we move to Torch 2.1
-                    hidden_states = F.scaled_dot_product_attention(
-                        query, key, value, dropout_p=0.0, is_causal=False
-                    )
-                    hidden_states = hidden_states.transpose(1, 2).reshape(
-                        batch_size, -1, attn.heads * head_dim
-                    )
-
-            #! ORIGIN
-            # hidden_states = F.scaled_dot_product_attention(
-            #     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-            # )
-            # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            #! ---------------------------------------- ATTENTION ----------------------------------------
-
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states, hidden_states = hidden_states.split(
-                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-            )
-            return hidden_states, encoder_hidden_states
diff --git a/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py b/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py
index 48793152..9dec9f06 100644
--- a/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py
+++ b/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py
@@ -52,11 +52,9 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
     ):
         batch_size, num_frames, channels, height, width = hidden_states.shape
-
         # 1. Time embedding
         timesteps = timestep
         t_emb = self.time_proj(timesteps)
@@ -69,13 +67,31 @@
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        # print(f"device: {torch.distributed.get_rank()}: hidden_states: {hidden_states.shape}")
+
+        # 3. Position embedding
+        seq_length = height * width * num_frames // (self.config.patch_size**2) * get_sequence_parallel_world_size()
+
+        pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
+        # print(f"device: {torch.distributed.get_rank()}: pos_embeds: {pos_embeds.shape}")
+        txt_pos_embeds = pos_embeds[:, : self.config.max_text_seq_length]
+        # print(f"device: {torch.distributed.get_rank()}: txt_pos_embeds: {txt_pos_embeds.shape}")
+        img_pos_embeds = pos_embeds[:, self.config.max_text_seq_length \
+            + get_runtime_state().pp_patches_token_start_end_idx_global[0][0]
+            : self.config.max_text_seq_length + \
+            get_runtime_state().pp_patches_token_start_end_idx_global[0][1]]
+        # print(f"device: {torch.distributed.get_rank()}: get_runtime_state().pp_patches_token_start_end_idx_global: {get_runtime_state().pp_patches_token_start_end_idx_global}")
+        # print(f"device: {torch.distributed.get_rank()}: img_pos_embeds: {img_pos_embeds.shape}")
+        pos_embeds = torch.cat([txt_pos_embeds, img_pos_embeds], dim=1)
+        # print(f"device: {torch.distributed.get_rank()}: pos_embeds: {pos_embeds.shape}")
+
+        hidden_states = hidden_states + pos_embeds
 
         hidden_states = self.embedding_dropout(hidden_states)
 
-        text_seq_length = encoder_hidden_states.shape[1]
-        encoder_hidden_states = hidden_states[:, :text_seq_length]
-        hidden_states = hidden_states[:, text_seq_length:]
+        encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
+        hidden_states = hidden_states[:, self.config.max_text_seq_length :]
 
-        # 3. Transformer blocks
+        # 4. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
@@ -91,7 +107,6 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     emb,
-                    image_rotary_emb,
                     **ckpt_kwargs,
                 )
             else:
@@ -99,27 +114,20 @@
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
-                    image_rotary_emb=image_rotary_emb,
                 )
 
-        if not self.config.use_rotary_positional_embeddings:
-            # CogVideoX-2B
-            hidden_states = self.norm_final(hidden_states)
-        else:
-            # CogVideoX-5B
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-            hidden_states = self.norm_final(hidden_states)
-            hidden_states = hidden_states[:, text_seq_length:]
+        hidden_states = self.norm_final(hidden_states)
 
-        # 4. Final block
+        # 5. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        # 5. Unpatchify
+        # 6. Unpatchify
         p = self.config.patch_size
+        # print(f"device: {torch.distributed.get_rank()}: hidden_states: {hidden_states.shape}")
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
         if not return_dict:
             return (output,)
-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output)
\ No newline at end of file
diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
index 4ca79c07..cd384311 100644
--- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py
+++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
@@ -8,8 +8,9 @@
     CogVideoXPipelineOutput,
     retrieve_timesteps,
 )
-from diffusers.schedulers import CogVideoXDPMScheduler
+from diffusers.schedulers import CogVideoXDPMScheduler, CogVideoXDDIMScheduler
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.utils import deprecate
 
 import math
@@ -29,9 +30,6 @@
     get_runtime_state,
     initialize_runtime_state,
     get_data_parallel_rank,
-    is_pipeline_first_stage,
-    is_pipeline_last_stage,
-    is_dp_last_group,
 )
 from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
 
@@ -252,20 +250,13 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7. Create rotary embeds if required
-        image_rotary_emb = (
-            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
-            if self.transformer.config.use_rotary_positional_embeddings
-            else None
-        )
-
-        # 8. Denoising loop
+        # 7. Denoising loop
         num_warmup_steps = max(
             len(timesteps) - num_inference_steps * self.scheduler.order, 0
         )
-        latents, image_rotary_emb = self._init_sync_pipeline(latents, image_rotary_emb, latents.size(1))
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
+            latents = self._init_video_sync_pipeline(latents)
             # for DPM-solver++
             old_pred_original_sample = None
             for i, t in enumerate(timesteps):
@@ -287,7 +278,6 @@
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                 )[0]
                 noise_pred = noise_pred.float()
@@ -375,8 +365,6 @@
                 )
             else:
                 video = latents
-        else:
-            video = [None for _ in range(batch_size)]
 
         # Offload all models
         self.maybe_free_model_hooks()
@@ -386,34 +374,6 @@
             return (video,)
 
         return CogVideoXPipelineOutput(frames=video)
 
-    def _init_sync_pipeline(
-        self,
-        latents: torch.Tensor,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        latents_frames: Optional[int] = None,
-    ):
-        latents = super()._init_video_sync_pipeline(latents)
-        if image_rotary_emb is not None:
-            assert latents_frames is not None
-            d = image_rotary_emb[0].shape[-1]
-            image_rotary_emb = (
-                torch.cat(
-                    [
-                        image_rotary_emb[0].reshape(latents_frames, -1, d)[:, start_token_idx:end_token_idx].reshape(-1, d)
-                        for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
-                    ],
-                    dim=0,
-                ),
-                torch.cat(
-                    [
-                        image_rotary_emb[1].reshape(latents_frames, -1, d)[:, start_token_idx:end_token_idx].reshape(-1, d)
-                        for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
-                    ],
-                    dim=0,
-                ),
-            )
-        return latents, image_rotary_emb
-
     @property
     def interrupt(self):
         return self._interrupt
diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py
index 0c91c0e2..43751a94 100644
--- a/xfuser/model_executor/pipelines/pipeline_flux.py
+++ b/xfuser/model_executor/pipelines/pipeline_flux.py
@@ -389,7 +389,7 @@ def _sync_pipeline(
                     0, "encoder_hidden_states"
                 )
 
-            # handle guidance
+            # handle guidance 
             if self.transformer.config.guidance_embeds:
                 guidance = torch.tensor([guidance_scale], device=self._execution_device)
                 guidance = guidance.expand(latents.shape[0])
diff --git a/xfuser/model_executor/pipelines/pipeline_hunyuandit.py b/xfuser/model_executor/pipelines/pipeline_hunyuandit.py
index ebe87501..7d8c784d 100644
--- a/xfuser/model_executor/pipelines/pipeline_hunyuandit.py
+++ b/xfuser/model_executor/pipelines/pipeline_hunyuandit.py
@@ -493,7 +493,7 @@ def __call__(
 
     def _init_sync_pipeline(self, latents: torch.Tensor, image_rotary_emb):
         latents = super()._init_sync_pipeline(latents)
-        image_rotary_emb = (
+        image_rotary_emb = [
             torch.cat(
                 [
                     image_rotary_emb[0][start_token_idx:end_token_idx, ...]
@@ -508,7 +508,7 @@ def _init_sync_pipeline(self, latents: torch.Tensor, image_rotary_emb):
                 ],
                 dim=0,
             ),
-        )
+        ]
         return latents, image_rotary_emb
 
     def _init_async_pipeline(
@@ -540,7 +540,7 @@ def _sync_pipeline(
         prompt_attention_mask_2: torch.Tensor,
         add_time_ids: torch.Tensor,
         style: torch.Tensor,
-        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
+        image_rotary_emb: torch.FloatTensor,
         device: torch.device,
         guidance_scale: float,
         guidance_rescale: float,
@@ -677,7 +677,7 @@ def _async_pipeline(
         prompt_attention_mask_2: torch.Tensor,
         add_time_ids: torch.Tensor,
         style: torch.Tensor,
-        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
+        image_rotary_emb: torch.FloatTensor,
         device: torch.device,
         guidance_scale: float,
         guidance_rescale: float,
@@ -860,7 +860,7 @@ def _backbone_forward(
         prompt_attention_mask_2: torch.FloatTensor,
         add_time_ids: torch.Tensor,
         style: torch.Tensor,
-        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
+        image_rotary_emb: torch.FloatTensor,
         t: Union[float, torch.Tensor],
         device: torch.device,
         guidance_scale: float,