
Commit

Revert "CogVideoX support with USP (xdit-project#261)"
This reverts commit 9484590.
Eigensystem authored and feifeibear committed Oct 25, 2024
1 parent 7f174dc commit 69a36d3
Showing 8 changed files with 84 additions and 302 deletions.
40 changes: 23 additions & 17 deletions setup.py
@@ -12,34 +12,40 @@ def get_cuda_version():
except Exception as e:
return 'no_cuda'

def get_install_requires(cuda_version):
if cuda_version == 'cu124':
sys.stderr.write("WARNING: Manual installation required for CUDA 12.4 specific PyTorch version.\n")
sys.stderr.write("Please install PyTorch for CUDA 12.4 using the following command:\n")
sys.stderr.write("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n")

return [
"torch==2.3.0",
"diffusers>=0.30.0",
"transformers>=4.39.1",
"sentencepiece>=0.1.99",
"accelerate==0.33.0",
"beautifulsoup4>=4.12.3",
"distvae",
"yunchang==0.3",
"flash_attn>=2.6.3",
"pytest",
"flask",
]

if __name__ == "__main__":
with open("README.md", "r") as f:
long_description = f.read()
fp = open("xfuser/__version__.py", "r").read()
version = eval(fp.strip().split()[-1])

cuda_version = get_cuda_version()

setup(
name="xfuser",
author="xDiT Team",
author_email="fangjiarui123@gmail.com",
packages=find_packages(),
install_requires=[
"torch>=2.3.0",
"accelerate==0.33.0",
"diffusers>=0.30.0",
"transformers>=4.39.1",
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
"distvae",
"yunchang==0.3",
"pytest",
"flask",
],
extras_require={
"all": [
"flash_attn>=2.6.3",
],
},
install_requires=get_install_requires(cuda_version),
url="https://github.com/xdit-project/xDiT.",
description="xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters",
long_description=long_description,
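The hunk above only shows the tail of get_cuda_version(). As a rough illustration, a helper along these lines (a hypothetical sketch, not necessarily the repository's exact code) could produce the 'cu124' / 'no_cuda' tags that get_install_requires() branches on:

def get_cuda_version():
    # Hypothetical sketch: derive a tag such as 'cu124' from the CUDA runtime
    # PyTorch was built against, falling back to 'no_cuda' on any failure.
    try:
        import torch
        cuda = torch.version.cuda  # e.g. "12.4"; None for CPU-only builds
        if cuda is None:
            return "no_cuda"
        major, minor = cuda.split(".")[:2]
        return f"cu{major}{minor}"
    except Exception:
        return "no_cuda"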
13 changes: 7 additions & 6 deletions xfuser/core/distributed/runtime_state.py
@@ -136,8 +136,7 @@ def set_input_parameters(
self.input_config.seed = seed
set_random_seed(seed)
if (
not self.ready
or (height and self.input_config.height != height)
(height and self.input_config.height != height)
or (width and self.input_config.width != width)
or (batch_size and self.input_config.batch_size != batch_size)
or not self.ready
@@ -164,8 +163,7 @@ def set_video_input_parameters(
self.input_config.seed = seed
set_random_seed(seed)
if (
not self.ready
or (height and self.input_config.height != height)
(height and self.input_config.height != height)
or (width and self.input_config.width != width)
or (num_frames and self.input_config.num_frames != num_frames)
or (batch_size and self.input_config.batch_size != batch_size)
@@ -363,6 +361,7 @@ def _calc_patches_metadata(self):
self.pp_patches_token_num = pp_patches_token_num

def _calc_cogvideox_patches_metadata(self):

num_sp_patches = get_sequence_parallel_world_size()
sp_patch_idx = get_sequence_parallel_rank()
patch_size = self.backbone_patch_size
@@ -451,9 +450,11 @@ def _calc_cogvideox_patches_metadata(self):
pp_patches_token_start_end_idx_global = [
[
(latents_width // patch_size)
* (start_idx // patch_size),
* (start_idx // patch_size)
* latents_frames,
(latents_width // patch_size)
* (end_idx // patch_size),
* (end_idx // patch_size)
* latents_frames,
]
for start_idx, end_idx in pp_patches_start_end_idx_global
]
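For context, the index arithmetic restored in _calc_cogvideox_patches_metadata can be read as follows (an illustrative sketch with made-up names, not the library's API): each row of spatial patches contributes latents_width // patch_size tokens, and the CogVideoX layout repeats that across the temporal dimension, so the global start/end token indices are scaled by latents_frames.

def patch_token_range(start_idx, end_idx, latents_width, patch_size, latents_frames):
    # Tokens contributed by one row of spatial patches.
    tokens_per_row = latents_width // patch_size
    # Reverted formula: scale the per-row offsets by the number of latent frames.
    start_token = tokens_per_row * (start_idx // patch_size) * latents_frames
    end_token = tokens_per_row * (end_idx // patch_size) * latents_frames
    return start_token, end_token

# Example with assumed values: 60-wide latents, patch size 2, rows 0..30, 13 frames
# -> tokens_per_row = 30, token range = (0, 30 * 15 * 13) = (0, 5850)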
6 changes: 3 additions & 3 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -228,16 +228,16 @@ def forward(
key: Tensor,
value: Tensor,
*,
joint_tensor_query,
joint_tensor_key,
joint_tensor_value,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
joint_tensor_query=None,
joint_tensor_key=None,
joint_tensor_value=None,
joint_strategy="front",
) -> Tensor:
"""forward
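The only change in this hunk is where the joint-tensor arguments sit in the keyword-only section of forward() and whether they carry defaults. A toy example (illustrative names, not the library's API) shows the practical difference: without defaults the arguments become mandatory at every call site, whereas None defaults let callers omit them.

def forward_without_defaults(query, key, value, *,
                             joint_tensor_query, joint_tensor_key, joint_tensor_value,
                             joint_strategy="front"):
    ...  # joint tensors must be passed explicitly on every call

def forward_with_defaults(query, key, value, *,
                          joint_tensor_query=None, joint_tensor_key=None,
                          joint_tensor_value=None, joint_strategy="front"):
    ...  # joint tensors may be omitted and default to None

# forward_with_defaults(q, k, v)        # OK
# forward_without_defaults(q, k, v)     # TypeError: missing keyword-only arguments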
221 changes: 14 additions & 207 deletions xfuser/model_executor/layers/attention_processor.py
@@ -1,5 +1,5 @@
import inspect
from typing import Optional, Union, Tuple
from typing import Optional

import torch
from torch import nn
@@ -12,12 +12,9 @@
JointAttnProcessor2_0,
FluxAttnProcessor2_0,
FluxSingleAttnProcessor2_0,
apply_rope,
HunyuanAttnProcessor2_0,
)
try:
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0
except ImportError:
CogVideoXAttnProcessor2_0 = None

from xfuser.core.distributed import (
get_sequence_parallel_world_size,
@@ -45,62 +42,11 @@ def is_v100():
device_name = torch.cuda.get_device_name(torch.cuda.current_device())
return "V100" in device_name


def torch_compile_disable_if_v100(func):
if is_v100():
return torch.compiler.disable(func)
return func


def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)

if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

return x_out.type_as(x)


class xFuserAttentionBaseWrapper(xFuserLayerBaseWrapper):
def __init__(
self,
@@ -702,8 +648,11 @@ def __call__(
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)

#! ---------------------------------------- KV CACHE ----------------------------------------
if not self.use_long_ctx_attn_kvcache:
@@ -875,8 +824,11 @@ def __call__(

# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
query, key = apply_rope(query, key, image_rotary_emb)

#! ---------------------------------------- KV CACHE ----------------------------------------
if not self.use_long_ctx_attn_kvcache:
@@ -980,6 +932,8 @@ def __call__(
image_rotary_emb: Optional[torch.Tensor] = None,
latte_temporal_attention: Optional[bool] = False,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1136,150 +1090,3 @@ def __call__(
hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


if CogVideoXAttnProcessor2_0 is not None:

@xFuserAttentionProcessorRegister.register(CogVideoXAttnProcessor2_0)
class xFuserCogVideoXAttnProcessor2_0(CogVideoXAttnProcessor2_0):
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""

def __init__(self):
super().__init__()
use_long_ctx_attn_kvcache = True
self.use_long_ctx_attn_kvcache = (
HAS_LONG_CTX_ATTN
and use_long_ctx_attn_kvcache
and get_sequence_parallel_world_size() > 1
)
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
from xfuser.core.long_ctx_attention import (
xFuserLongContextAttention,
xFuserUlyssesAttention,
)

if HAS_FLASH_ATTN:
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
use_fa=False,
use_kv_cache=self.use_long_ctx_attn_kvcache,
)
else:
self.hybrid_seq_parallel_attn = None

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

#! ---------------------------------------- KV CACHE ----------------------------------------
if not self.use_long_ctx_attn_kvcache:
key, value = get_cache_manager().update_and_get_kv_cache(
new_kv=[key, value],
layer=attn,
slice_dim=2,
layer_type="attn",
)
#! ---------------------------------------- KV CACHE ----------------------------------------

#! ---------------------------------------- ATTENTION ----------------------------------------
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
hidden_states = self.hybrid_seq_parallel_attn(
attn,
query,
key,
value,
dropout_p=0.0,
causal=False,
joint_strategy="none",
)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
else:
if HAS_FLASH_ATTN:
from flash_attn import flash_attn_func

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
hidden_states = flash_attn_func(
query, key, value, dropout_p=0.0, causal=False
)
hidden_states = hidden_states.reshape(
batch_size, -1, attn.heads * head_dim
)

else:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)

#! ORIGIN
# hidden_states = F.scaled_dot_product_attention(
# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
# )
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
#! ---------------------------------------- ATTENTION ----------------------------------------


# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
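As a usage sketch of the rotary-embedding path that appears throughout this file (shapes and values are illustrative only; the helper is the diffusers one imported earlier in this file's diff):

import torch
from diffusers.models.embeddings import apply_rotary_emb  # same helper imported in the diff

batch, heads, seq_len, head_dim = 1, 2, 16, 64
query = torch.randn(batch, heads, seq_len, head_dim)

# In the use_real form, freqs_cis is a (cos, sin) pair, each of shape [seq_len, head_dim].
angles = torch.arange(seq_len).float()[:, None] * torch.ones(head_dim)[None, :]
cos, sin = angles.cos(), angles.sin()

rotated = apply_rotary_emb(query, (cos, sin))
print(rotated.shape)  # torch.Size([1, 2, 16, 64])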
