def get_cameractrl_effect(self, hidden_states: torch.Tensor):
    """Return the camera-control multiplier to apply for ``hidden_states``.

    Args:
        hidden_states: 4-D tensor unpacked as (batch, channel, height, width).
            Only its shape, dtype and device are read.

    Returns:
        * ``1.0`` when ``self.raw_cameractrl_effect`` is ``None``.
        * ``self.raw_cameractrl_effect`` unchanged when it is not a tensor
          (expected to be a plain float).
        * Otherwise a tensor laid out as (h*w, frames, channels), cached on
          ``self.temp_cameractrl_effect`` and sliced along dim 1 by
          ``self.sub_idxs`` when that attribute is set.

    Side effects:
        Updates ``self.temp_cameractrl_effect`` and
        ``self.prev_cameractrl_hidden_states_batch`` whenever the cached
        tensor has to be (re)computed.
    """
    raw_effect = self.raw_cameractrl_effect
    # No camera-control effect configured: use the neutral multiplier.
    if raw_effect is None:
        return 1.0
    # isinstance (not an exact type comparison) so Tensor subclasses such as
    # nn.Parameter are treated as tensors instead of being returned as-is.
    if not isinstance(raw_effect, torch.Tensor):
        return raw_effect

    shape = hidden_states.shape
    batch, _, height, width = shape

    # Identity test: comparing a tensor to None with != is not a reliable
    # "is it set" check.
    effect = self.temp_cameractrl_effect
    if effect is not None and batch != self.prev_cameractrl_hidden_states_batch:
        # Cache was built for a different batch size; drop it and rebuild.
        effect = None
        self.temp_cameractrl_effect = None

    if effect is None:
        self.prev_cameractrl_hidden_states_batch = batch
        # Build the mask from the camera-control effect itself. The previous
        # code passed self.raw_scale_mask here, so the tensor effect that was
        # validated above was never actually used.
        mask = prepare_mask_batch(
            raw_effect, shape=(self.full_length, 1, height, width)
        )
        mask = repeat_to_batch_size(mask, self.full_length)
        # If the mask still does not cover full_length frames, broadcast it.
        if self.full_length != mask.shape[0]:
            mask = broadcast_image_to(mask, self.full_length, 1)
        # Same transformation as applied to hidden_states:
        # (b, c, h, w) -> (b, h*w, c), then to attention-K layout (h*w, b, c).
        mask_batch, mask_channel, mask_height, mask_width = mask.shape
        mask = mask.permute(0, 2, 3, 1).reshape(
            mask_batch, mask_height * mask_width, mask_channel
        )
        mask = mask.permute(1, 0, 2)
        # Repeat along dim 0 when hidden_states packs several batches of
        # video_length frames.
        batched_number = shape[0] // self.video_length
        if batched_number > 1:
            mask = torch.cat([mask] * batched_number, dim=0)
        # Cache on the dtype/device of hidden_states.
        effect = mask.to(dtype=hidden_states.dtype, device=hidden_states.device)
        self.temp_cameractrl_effect = effect

    # Return only the requested sub-indexes (dim 1), if any.
    if self.sub_idxs is not None:
        return effect[:, self.sub_idxs, :]
    return effect
mm_kwargs["camera_feature"] + hidden_states = (self.qkv_merge(hidden_states + camera_feature) + hidden_states) * cameractrl_effect + hidden_states * (1. - cameractrl_effect) # hidden_states = super().forward( # hidden_states,