Skip to content

Commit

Permalink
upgrade diffusers to 0.31.dev version to support flux and cogvideox s…
Browse files Browse the repository at this point in the history
…imutanously (#268)
  • Loading branch information
feifeibear authored Sep 14, 2024
1 parent da885f4 commit 2e2ad16
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 32 deletions.
7 changes: 3 additions & 4 deletions examples/flux_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,9 @@ def main():
if is_dp_last_group():
for i, image in enumerate(output.images):
image_rank = dp_group_index * dp_batch_size + i
image.save(f"./results/flux_result_{parallel_info}_{image_rank}.png")
print(
f"image {i} saved to ./results/flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
)
image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
image.save(f"./results/{image_name}")
print(f"image {i} saved to ./results/{image_name}")

if get_world_group().rank == get_world_group().world_size - 1:
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def get_cuda_version():
author_email="fangjiarui123@gmail.com",
packages=find_packages(),
install_requires=[
"torch>=2.3.0",
"accelerate==0.33.0",
"diffusers==0.30.2",
"torch>=2.1.0",
"accelerate>=0.33.0",
"diffusers @ git+https://github.com/huggingface/diffusers.git",
"transformers>=4.39.1",
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
Expand Down
10 changes: 2 additions & 8 deletions tests/layers/attention_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
AttentionProcessor,
FusedHunyuanAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FluxSingleAttnProcessor2_0,
)
from xfuser.model_executor.layers.attention_processor import (
xFuserHunyuanAttnProcessor2_0,
xFuserFluxSingleAttnProcessor2_0,
)

from xfuser.core.cache_manager.cache_manager import get_cache_manager
Expand Down Expand Up @@ -56,10 +54,6 @@ def run_attn_test(rank, world_size, attn_type: str):

_type_dict = {
"HunyuanDiT": (HunyuanAttnProcessor2_0(), xFuserHunyuanAttnProcessor2_0()),
"FluxSingle": (
FluxSingleAttnProcessor2_0(),
xFuserFluxSingleAttnProcessor2_0(),
),
}
processor, parallel_processor = _type_dict[attn_type]

Expand Down Expand Up @@ -140,7 +134,7 @@ def run_attn_test(rank, world_size, attn_type: str):
), "Outputs are not close"


@pytest.mark.parametrize("attn_type", ["HunyuanDiT", "FluxSingle"])
@pytest.mark.parametrize("attn_type", ["HunyuanDiT"])
def test_multi_process(attn_type):
world_size = 4 # Number of processes
processes = []
Expand All @@ -160,4 +154,4 @@ def test_multi_process(attn_type):


if __name__ == "__main__":
test_multi_process("FluxSingle")
test_multi_process("HunyuanDiT")
11 changes: 11 additions & 0 deletions xfuser/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def check_env():
"with `pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cu121`"
)
try:
import diffusers

if version.parse(diffusers.__version__) > version.parse("0.30.2"):
raise RuntimeError(
"This project requires diffusers version >= 0.31.0 or above. It is not on pypi. Please install it from source code!"
)
except ImportError:
raise ImportError(
"diffusers is not installed. Please install it with `pip install diffusers`"
)


@dataclass
Expand Down
60 changes: 43 additions & 17 deletions xfuser/model_executor/layers/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
AttnProcessor2_0,
JointAttnProcessor2_0,
FluxAttnProcessor2_0,
FluxSingleAttnProcessor2_0,
HunyuanAttnProcessor2_0,
)

try:
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0
except ImportError:
Expand Down Expand Up @@ -80,14 +80,20 @@ def apply_rotary_emb(

if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(
-1
) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(
-2
) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
raise ValueError(
f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2."
)

out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

Expand Down Expand Up @@ -343,8 +349,9 @@ def __call__(
#! ---------------------------------------- KV CACHE ----------------------------------------

#! ---------------------------------------- ATTENTION ----------------------------------------
if (HAS_LONG_CTX_ATTN
and get_sequence_parallel_world_size() > 1
if (
HAS_LONG_CTX_ATTN
and get_sequence_parallel_world_size() > 1
and not latte_temporal_attention
):
query = query.transpose(1, 2)
Expand Down Expand Up @@ -638,7 +645,11 @@ def __call__(
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
batch_size, _, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)

# `sample` projections.
query = attn.to_q(hidden_states)
Expand Down Expand Up @@ -675,9 +686,13 @@ def __call__(
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)

num_encoder_hidden_states_tokens = encoder_hidden_states_query_proj.shape[2]
num_query_tokens = query.shape[2]
Expand Down Expand Up @@ -808,7 +823,7 @@ def __init__(self):
)
else:
self.hybrid_seq_parallel_attn = None

# NOTE() torch.compile dose not works for V100
@torch_compile_disable_if_v100
def __call__(
Expand Down Expand Up @@ -1029,12 +1044,18 @@ def __call__(
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)

query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
Expand All @@ -1054,9 +1075,13 @@ def __call__(

# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
query[:, :, text_seq_length:] = apply_rotary_emb(
query[:, :, text_seq_length:], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
key[:, :, text_seq_length:] = apply_rotary_emb(
key[:, :, text_seq_length:], image_rotary_emb
)

#! ---------------------------------------- KV CACHE ----------------------------------------
if not self.use_long_ctx_attn_kvcache:
Expand All @@ -1082,7 +1107,9 @@ def __call__(
causal=False,
joint_strategy="none",
)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.reshape(
batch_size, -1, attn.heads * head_dim
)
else:
if HAS_FLASH_ATTN:
from flash_attn import flash_attn_func
Expand Down Expand Up @@ -1114,7 +1141,6 @@ def __call__(
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
#! ---------------------------------------- ATTENTION ----------------------------------------


# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
Expand Down

0 comments on commit 2e2ad16

Please sign in to comment.