From e559ae8d2353d4fd1aafd9e0a928b600618ffc64 Mon Sep 17 00:00:00 2001
From: Edna <88869424+Ednaordinary@users.noreply.github.com>
Date: Sun, 29 Dec 2024 06:22:44 -0700
Subject: [PATCH] Add fused QKV projections to HunyuanVideoTransformer3DModel
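
This patch adds an opt-in fused QKV projection path to the HunyuanVideo
transformer:

- FusedHunyuanVideoAttnProcessor2_0, a variant of
  HunyuanVideoAttnProcessor2_0 that reads the single to_qkv projection
  produced by Attention.fuse_projections() while keeping the same QK
  normalization, rotary embedding, and text-stream handling.
- fuse_qkv_projections() / unfuse_qkv_projections() on
  HunyuanVideoTransformer3DModel, copied from UNet2DConditionModel and
  wired to the new processor.
- A guard where the attention output is consumed, so a
  (hidden_states, encoder_hidden_states) tuple returned by the processor
  is unpacked before the gated residual add.

A usage sketch (not part of the patch) follows the diff.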
---
.../transformers/transformer_hunyuan_video.py | 141 ++++++++++++++++++
1 file changed, 141 insertions(+)
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index e3f24d97f3fa..919fd9615d3f 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -136,6 +136,103 @@ def __call__(
return hidden_states, encoder_hidden_states
+class FusedHunyuanVideoAttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedHunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ # 1. QKV projections
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ query = torch.cat(
+ [
+ apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+ query[:, :, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=2,
+ )
+ key = torch.cat(
+ [
+ apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+ key[:, :, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=2,
+ )
+ else:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # 4. Encoder condition QKV projection and normalization
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([query, encoder_query], dim=2)
+ key = torch.cat([key, encoder_key], dim=2)
+ value = torch.cat([value, encoder_value], dim=2)
+
+ # 5. Attention
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # 6. Output projection
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
+ )
+
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if getattr(attn, "to_add_out", None) is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
class HunyuanVideoPatchEmbed(nn.Module):
def __init__(
@@ -214,6 +311,10 @@ def forward(
)
gate_msa, gate_mlp = self.norm_out(temb)
+        # QKV fusion: the processor may return a (hidden_states, encoder_hidden_states) tuple; keep only the hidden states
+ if isinstance(attn_output, tuple):
+ attn_output = attn_output[0]
+
hidden_states = hidden_states + attn_output * gate_msa
ff_output = self.ff(self.norm2(hidden_states))
@@ -604,6 +705,46 @@ def __init__(
self.gradient_checkpointing = False
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanVideoAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+        </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedHunyuanVideoAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+        </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
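
Usage sketch (not part of the patch): a minimal way to exercise the new
fuse_qkv_projections()/unfuse_qkv_projections() API from a pipeline. The
checkpoint id, dtypes, and generation settings below are illustrative
assumptions rather than anything this patch requires.

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel

# Assumed diffusers-format checkpoint; substitute the HunyuanVideo weights you actually use.
model_id = "hunyuanvideo-community/HunyuanVideo"

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
)
pipe.to("cuda")

# Fuse the QKV projections; attention now runs through FusedHunyuanVideoAttnProcessor2_0.
pipe.transformer.fuse_qkv_projections()

video = pipe(
    prompt="A cat walks on the grass, realistic style.",
    num_frames=61,
    num_inference_steps=30,
).frames[0]

# Restore the original attention processors when fused projections are no longer wanted.
pipe.transformer.unfuse_qkv_projections()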