Upgrade diffusers to 0.31.dev to support Flux and CogVideoX at the same time #268

Merged: 1 commit on Sep 14, 2024
7 changes: 3 additions & 4 deletions examples/flux_example.py
@@ -63,10 +63,9 @@ def main():
if is_dp_last_group():
for i, image in enumerate(output.images):
image_rank = dp_group_index * dp_batch_size + i
image.save(f"./results/flux_result_{parallel_info}_{image_rank}.png")
print(
f"image {i} saved to ./results/flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
)
image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
image.save(f"./results/{image_name}")
print(f"image {i} saved to ./results/{image_name}")

if get_world_group().rank == get_world_group().world_size - 1:
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
6 changes: 3 additions & 3 deletions setup.py
@@ -26,9 +26,9 @@ def get_cuda_version():
author_email="fangjiarui123@gmail.com",
packages=find_packages(),
install_requires=[
"torch>=2.3.0",
"accelerate==0.33.0",
"diffusers==0.30.2",
"torch>=2.1.0",
"accelerate>=0.33.0",
"diffusers @ git+https://github.com/huggingface/diffusers.git",
"transformers>=4.39.1",
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
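Since both model families now come from the pinned git build of diffusers, a quick post-install sanity check can confirm that the expected attention processors are importable. The snippet below is illustrative only, not part of this PR; it uses the same import paths that appear later in this diff.

# Illustrative check (not part of this PR): a diffusers build new enough for
# this repo exposes both the Flux and the CogVideoX attention processors.
from diffusers.models.attention_processor import (
    CogVideoXAttnProcessor2_0,
    FluxAttnProcessor2_0,
)

print("Flux processor available:", FluxAttnProcessor2_0.__name__)
print("CogVideoX processor available:", CogVideoXAttnProcessor2_0.__name__)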
10 changes: 2 additions & 8 deletions tests/layers/attention_processor_test.py
@@ -14,11 +14,9 @@
AttentionProcessor,
FusedHunyuanAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FluxSingleAttnProcessor2_0,
)
from xfuser.model_executor.layers.attention_processor import (
xFuserHunyuanAttnProcessor2_0,
xFuserFluxSingleAttnProcessor2_0,
)

from xfuser.core.cache_manager.cache_manager import get_cache_manager
@@ -56,10 +54,6 @@ def run_attn_test(rank, world_size, attn_type: str):

_type_dict = {
"HunyuanDiT": (HunyuanAttnProcessor2_0(), xFuserHunyuanAttnProcessor2_0()),
"FluxSingle": (
FluxSingleAttnProcessor2_0(),
xFuserFluxSingleAttnProcessor2_0(),
),
}
processor, parallel_processor = _type_dict[attn_type]

@@ -140,7 +134,7 @@ def run_attn_test(rank, world_size, attn_type: str):
), "Outputs are not close"


@pytest.mark.parametrize("attn_type", ["HunyuanDiT", "FluxSingle"])
@pytest.mark.parametrize("attn_type", ["HunyuanDiT"])
def test_multi_process(attn_type):
world_size = 4 # Number of processes
processes = []
@@ -160,4 +154,4 @@ def test_multi_process(attn_type):


if __name__ == "__main__":
test_multi_process("FluxSingle")
test_multi_process("HunyuanDiT")
11 changes: 11 additions & 0 deletions xfuser/config/config.py
@@ -31,6 +31,17 @@ def check_env():
"with `pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cu121`"
)
try:
import diffusers

if version.parse(diffusers.__version__) <= version.parse("0.30.2"):
raise RuntimeError(
"This project requires diffusers >= 0.31.0, which is not yet on PyPI. Please install it from source, e.g. `pip install git+https://github.com/huggingface/diffusers.git`."
)
except ImportError:
raise ImportError(
"diffusers is not installed. Please install it with `pip install diffusers`"
)


@dataclass
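A note on the version gate added to check_env() above: it relies on PEP 440 ordering, under which a 0.31 development build installed from git compares as newer than 0.30.2, the last PyPI release. The snippet below is a standalone illustration of that comparison, not repository code.

# Standalone illustration (not repository code) of the comparison used in check_env():
# dev builds of 0.31.0 sort after 0.30.2, so a source install passes the gate.
from packaging import version

print(version.parse("0.31.0.dev0") <= version.parse("0.30.2"))  # False -> accepted
print(version.parse("0.30.2") <= version.parse("0.30.2"))       # True  -> RuntimeError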
60 changes: 43 additions & 17 deletions xfuser/model_executor/layers/attention_processor.py
@@ -11,9 +11,9 @@
AttnProcessor2_0,
JointAttnProcessor2_0,
FluxAttnProcessor2_0,
FluxSingleAttnProcessor2_0,
HunyuanAttnProcessor2_0,
)

try:
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0
except ImportError:
@@ -80,14 +80,20 @@ def apply_rotary_emb(

if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(
-1
) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(
-2
) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
raise ValueError(
f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2."
)

out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
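The apply_rotary_emb changes above are pure line-wrapping; as an aside for readers unfamiliar with the interleaved layout, here is a toy sketch of the use_real_unbind_dim=-1 branch, written for illustration only (the shapes and the identity cos/sin values are placeholders, not taken from diffusers).

# Toy sketch (illustration only) of the use_real_unbind_dim=-1 branch:
# the last dimension holds interleaved (real, imag) pairs that are rotated
# elementwise by the cos/sin frequencies.
import torch

x = torch.randn(1, 4, 2, 8)  # [B, S, H, D], D even
cos = torch.ones(8)          # placeholder frequencies (identity rotation)
sin = torch.zeros(8)

x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)    # each [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)  # back to [B, S, H, D]
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

assert out.shape == x.shape
assert torch.allclose(out, x)  # with cos=1, sin=0 the rotation is the identity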

@@ -343,8 +349,9 @@ def __call__(
#! ---------------------------------------- KV CACHE ----------------------------------------

#! ---------------------------------------- ATTENTION ----------------------------------------
if (HAS_LONG_CTX_ATTN
and get_sequence_parallel_world_size() > 1
if (
HAS_LONG_CTX_ATTN
and get_sequence_parallel_world_size() > 1
and not latte_temporal_attention
):
query = query.transpose(1, 2)
@@ -638,7 +645,11 @@ def __call__(
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
batch_size, _, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)

# `sample` projections.
query = attn.to_q(hidden_states)
@@ -675,9 +686,13 @@ def __call__(
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)

num_encoder_hidden_states_tokens = encoder_hidden_states_query_proj.shape[2]
num_query_tokens = query.shape[2]
@@ -808,7 +823,7 @@ def __init__(self):
)
else:
self.hybrid_seq_parallel_attn = None

# NOTE: torch.compile does not work on V100
@torch_compile_disable_if_v100
def __call__(
@@ -1029,12 +1044,18 @@ def __call__(
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)

query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
@@ -1054,9 +1075,13 @@ def __call__(

# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
query[:, :, text_seq_length:] = apply_rotary_emb(
query[:, :, text_seq_length:], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
key[:, :, text_seq_length:] = apply_rotary_emb(
key[:, :, text_seq_length:], image_rotary_emb
)

#! ---------------------------------------- KV CACHE ----------------------------------------
if not self.use_long_ctx_attn_kvcache:
@@ -1082,7 +1107,9 @@ def __call__(
causal=False,
joint_strategy="none",
)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.reshape(
batch_size, -1, attn.heads * head_dim
)
else:
if HAS_FLASH_ATTN:
from flash_attn import flash_attn_func
@@ -1114,7 +1141,6 @@ def __call__(
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
#! ---------------------------------------- ATTENTION ----------------------------------------


# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
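One detail worth noting in the CogVideoX processor touched above: text tokens are concatenated in front of the image tokens (torch.cat([encoder_hidden_states, hidden_states], dim=1)), and rotary embeddings are applied only to the slice starting at text_seq_length. The sketch below is a minimal illustration of that layout with made-up sizes, not repository code.

# Minimal sketch (not repository code) of the token layout used above: text tokens
# come first, and only the trailing image/video tokens receive rotary embeddings.
import torch

text_seq_length, image_seq_length, dim = 4, 6, 8
encoder_hidden_states = torch.randn(1, text_seq_length, dim)  # text tokens
hidden_states = torch.randn(1, image_seq_length, dim)         # image/video tokens

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
assert hidden_states.shape[1] == text_seq_length + image_seq_length

image_part = hidden_states[:, text_seq_length:]  # the slice that would receive RoPE
assert image_part.shape[1] == image_seq_length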