Support optimized USP in Flux (#368)
xibosun authored Nov 28, 2024
1 parent 403f4e5 commit ca94011
Showing 6 changed files with 35 additions and 21 deletions.
11 changes: 6 additions & 5 deletions examples/flux_usp_example.py
@@ -1,5 +1,8 @@
+# Flux inference with USP
+# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py
+
 import functools
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional
 
 import logging
 import time
@@ -25,7 +28,7 @@
     get_pipeline_parallel_world_size,
 )
 
-from xfuser.model_executor.layers.attention_processor_usp import xFuserFluxAttnProcessor2_0USP
+from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0
 
 def parallelize_transformer(pipe: DiffusionPipeline):
     transformer = pipe.transformer
@@ -40,8 +43,6 @@ def new_forward(
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
-        controlnet_block_samples: Optional[List[torch.Tensor]] = None,
-        controlnet_single_block_samples: Optional[List[torch.Tensor]] = None,
         **kwargs,
     ):
         if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
@@ -54,7 +55,7 @@ def new_forward(
         txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
 
         for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
-            block.attn.processor = xFuserFluxAttnProcessor2_0USP()
+            block.attn.processor = xFuserFluxAttnProcessor2_0()
 
         output = original_forward(
             hidden_states,
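For reference, the patching pattern this example uses can be reduced to a few lines. The sketch below is illustrative only: TinyTransformer and the fixed world_size/rank are hypothetical stand-ins, whereas the real example patches pipe.transformer.forward on a diffusers FluxPipeline and reads the rank from xfuser's runtime.

import functools
import torch

class TinyTransformer(torch.nn.Module):  # hypothetical stand-in for pipe.transformer
    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        return hidden_states * 2

def parallelize_transformer(transformer: torch.nn.Module, world_size: int, rank: int) -> None:
    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(hidden_states: torch.Tensor, **kwargs):
        # Keep only this rank's shard of the sequence dimension, then
        # run the unmodified forward on the shard.
        hidden_states = torch.chunk(hidden_states, world_size, dim=-2)[rank]
        return original_forward(hidden_states, **kwargs)

    transformer.forward = new_forward

t = TinyTransformer()
parallelize_transformer(t, world_size=2, rank=0)
print(t(torch.ones(1, 8, 4)).shape)  # torch.Size([1, 4, 4]): the local shard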
2 changes: 1 addition & 1 deletion setup.py
@@ -28,7 +28,7 @@ def get_cuda_version():
     install_requires=[
         "torch>=2.1.0",
         "accelerate>=0.33.0",
-        "diffusers==0.31", # NOTE: diffusers>=0.31.0 is necessary for CogVideoX and Flux
+        "diffusers@git+https://github.com/huggingface/diffusers", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
         "transformers>=4.39.1",
         "sentencepiece>=0.1.99",
         "beautifulsoup4>=4.12.3",
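Because the dependency now tracks diffusers' main branch rather than a released version, a runtime guard can catch stale installs early. A minimal sketch, with the 0.32.0.dev0 floor taken from the NOTE in the diff above and packaging assumed available:

from importlib.metadata import version
from packaging.version import Version

# Fails on diffusers 0.31.x; dev builds from git main report >= 0.32.0.dev0.
assert Version(version("diffusers")) >= Version("0.32.0.dev0"), (
    "CogVideoX/Flux support here needs diffusers installed from git main"
)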
15 changes: 12 additions & 3 deletions xfuser/model_executor/layers/attention_processor.py
@@ -19,8 +19,7 @@
 
 from xfuser.core.distributed import (
     get_sequence_parallel_world_size,
-    get_sequence_parallel_rank,
-    get_sp_group,
+    get_pipeline_parallel_world_size
 )
 from xfuser.core.fast_attention import (
     xFuserFastAttention,
@@ -34,6 +33,9 @@
 from xfuser.logger import init_logger
 from xfuser.envs import PACKAGES_CHECKER
 
+if torch.__version__ >= '2.5.0':
+    from xfuser.model_executor.layers.usp import USP
+
 logger = init_logger(__name__)
 
 env_info = PACKAGES_CHECKER.get_packages_info()
@@ -687,7 +689,14 @@ def __call__(
         #! ---------------------------------------- KV CACHE ----------------------------------------
 
         #! ---------------------------------------- ATTENTION ----------------------------------------
-        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+        if get_pipeline_parallel_world_size() == 1 and torch.__version__ >= '2.5.0' and get_runtime_state().split_text_embed_in_sp:
+            hidden_states = USP(
+                query, key, value, dropout_p=0.0, is_causal=False
+            )
+            hidden_states = hidden_states.transpose(1, 2).reshape(
+                batch_size, -1, attn.heads * head_dim
+            )
+        elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
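The version gate above only enables the new path on torch >= 2.5.0, where xfuser's USP helper (xfuser/model_executor/layers/usp.py) is importable. USP stands for Unified Sequence Parallelism, which combines Ulysses-style all-to-all with ring attention. As a rough illustration of the Ulysses half only, and not xfuser's actual implementation, the sketch below trades local heads for the full sequence with one all-to-all, runs ordinary SDPA, and trades back; it assumes an initialized NCCL process group and a head count divisible by the world size.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def _all_to_all_4d(x: torch.Tensor, scatter_dim: int, gather_dim: int, group=None) -> torch.Tensor:
    # Split x into world_size chunks along scatter_dim, exchange one chunk
    # with every rank, and concatenate what arrives along gather_dim.
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)

def ulysses_attention(query, key, value, group=None):
    # query/key/value: [batch, heads, seq_len // P, head_dim] on each of P ranks.
    q = _all_to_all_4d(query, scatter_dim=1, gather_dim=2, group=group)  # -> [B, H//P, S, D]
    k = _all_to_all_4d(key, scatter_dim=1, gather_dim=2, group=group)
    v = _all_to_all_4d(value, scatter_dim=1, gather_dim=2, group=group)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # Inverse exchange: back to [batch, heads, seq_len // P, head_dim].
    return _all_to_all_4d(out, scatter_dim=2, gather_dim=1, group=group)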
7 changes: 4 additions & 3 deletions xfuser/model_executor/pipelines/pipeline_cogvideox.py
@@ -370,9 +370,10 @@ def _init_sync_pipeline(
         latents = super()._init_video_sync_pipeline(latents)
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
+                prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         if image_rotary_emb is not None:
             assert latents_frames is not None
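All three pipelines touched by this commit replace the hard assert with the same shard-or-fall-back move; pipeline_flux.py and the last file below repeat it for their own tensors. A hypothetical helper distilling the pattern, with made-up shapes for illustration:

import torch

def shard_text_embed(prompt_embeds: torch.Tensor, sp_size: int, sp_rank: int):
    # Shard the text tokens across sequence-parallel ranks when the length
    # divides evenly; otherwise keep them replicated and report the fallback
    # (the real code flips the runtime state's split_text_embed_in_sp to False).
    if prompt_embeds.shape[-2] % sp_size == 0:
        return torch.chunk(prompt_embeds, sp_size, dim=-2)[sp_rank], True
    return prompt_embeds, False

embeds = torch.randn(1, 226, 4096)           # 226 text tokens, as in CogVideoX
shard, did_split = shard_text_embed(embeds, sp_size=4, sp_rank=0)
print(shard.shape, did_split)                # 226 % 4 != 0 -> replicated, False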
14 changes: 8 additions & 6 deletions xfuser/model_executor/pipelines/pipeline_flux.py
@@ -399,14 +399,16 @@ def _init_sync_pipeline(
         latent_image_ids = torch.cat(latent_image_ids_list, dim=-2)
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
+                prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert text_ids.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {text_ids.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if text_ids.shape[-2] % get_sequence_parallel_world_size() == 0:
+                text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         return latents, latent_image_ids, prompt_embeds, text_ids
7 changes: 4 additions & 3 deletions
@@ -422,9 +422,10 @@ def _init_sync_pipeline(self, latents: torch.Tensor, prompt_embeds: torch.Tensor
         latents = torch.cat(latents_list, dim=-2)
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
+                prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         return latents, prompt_embeds
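Taken together, the two halves of the change fit: _init_sync_pipeline no longer aborts when the text length is not divisible by the sequence-parallel degree but simply clears split_text_embed_in_sp, and the attention processor consults that same flag (along with the torch >= 2.5.0 and pipeline-parallel-degree checks) before taking the optimized USP path, falling back to the existing long-context attention branch otherwise.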
