Merge branch 'main' into 1011
feifeibear committed Oct 11, 2024
2 parents 80a3331 + f895f3b commit 39698f3
Showing 9 changed files with 445 additions and 113 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -114,7 +114,7 @@ Furthermore, xDiT incorporates optimization techniques from [DiTFastAttn](https:
| [🎬 CogVideoX](https://huggingface.co/THUDM/CogVideoX-2b) | ✔️ | ✔️ ||
| [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) || ✔️ ||
| [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ |
| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | |
| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ |
| [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ |
| [🟢 PixArt-alpha](https://huggingface.co/PixArt-alpha/PixArt-alpha) | ✔️ | ✔️ | ✔️ |
| [🟠 Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | ✔️ | ✔️ | ✔️ |
10 changes: 5 additions & 5 deletions xfuser/core/distributed/runtime_state.py
@@ -92,6 +92,7 @@ class DiTRuntimeState(RuntimeState):
pp_patches_token_start_idx_local: Optional[List[int]]
pp_patches_token_start_end_idx_global: Optional[List[List[int]]]
pp_patches_token_num: Optional[List[int]]
max_condition_sequence_length: int

def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
super().__init__(config)
@@ -126,10 +127,12 @@ def set_input_parameters(
batch_size: Optional[int] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
max_condition_sequence_length: Optional[int] = None,
):
self.input_config.num_inference_steps = (
num_inference_steps or self.input_config.num_inference_steps
)
self.max_condition_sequence_length = max_condition_sequence_length
if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
self.runtime_config.warmup_steps = self.input_config.num_inference_steps
if seed is not None and seed != self.input_config.seed:
@@ -140,7 +143,6 @@ def set_input_parameters(
or (height and self.input_config.height != height)
or (width and self.input_config.width != width)
or (batch_size and self.input_config.batch_size != batch_size)
or not self.ready
):
self._input_size_change(height, width, batch_size)

@@ -450,10 +452,8 @@ def _calc_cogvideox_patches_metadata(self):
]
pp_patches_token_start_end_idx_global = [
[
(latents_width // patch_size)
* (start_idx // patch_size),
(latents_width // patch_size)
* (end_idx // patch_size),
(latents_width // patch_size) * (start_idx // patch_size),
(latents_width // patch_size) * (end_idx // patch_size),
]
for start_idx, end_idx in pp_patches_start_end_idx_global
]
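The new `max_condition_sequence_length` field appears to record the padded text (conditioning) sequence length so that attention processors can later split the joint text+image sequence (see the attention_processor.py change below). A minimal, hypothetical caller sketch, assuming the keyword names visible in this diff; `prepare_run` and the concrete length are illustrative, not part of this commit:

```python
from xfuser.core.distributed.runtime_state import get_runtime_state

def prepare_run(height, width, batch_size, num_inference_steps, seed, text_seq_len):
    # text_seq_len: padded text-encoder length, e.g. 512 for a FLUX.1-dev style prompt
    # (illustrative value; use whatever the pipeline actually pads to).
    get_runtime_state().set_input_parameters(
        height=height,
        width=width,
        batch_size=batch_size,
        num_inference_steps=num_inference_steps,
        seed=seed,
        max_condition_sequence_length=text_seq_len,
    )
```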
8 changes: 5 additions & 3 deletions xfuser/model_executor/layers/attention_processor.py
@@ -125,6 +125,7 @@ def __init__(
assert (to_k.bias is None) == (to_v.bias is None)
assert to_k.weight.shape == to_v.weight.shape


class xFuserAttentionProcessorRegister:
_XFUSER_ATTENTION_PROCESSOR_MAPPING = {}

@@ -698,8 +699,10 @@ def __call__(
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
else:
num_encoder_hidden_states_tokens = 0
num_query_tokens = query.shape[2]
num_encoder_hidden_states_tokens = (
get_runtime_state().max_condition_sequence_length
)
num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens

if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
@@ -1158,7 +1161,6 @@ def __call__(
# dropout
hidden_states = attn.to_out[1](hidden_states)


encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, latent_seq_length], dim=1
)
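In the branch without a separate `encoder_hidden_states` (the single-stream path, where text and image tokens are already concatenated), the processor now reads the recorded text length from the runtime state instead of assuming zero, so `num_query_tokens` counts only the image/latent tokens. A tiny sketch of that bookkeeping; the shapes (24 heads, head dim 128, 4096 latent tokens, 512 text tokens) are assumptions for illustration:

```python
import torch

text_len = 512  # stands in for get_runtime_state().max_condition_sequence_length
# (batch, heads, text + image tokens, head_dim); shapes assumed for illustration
query = torch.empty(1, 24, text_len + 4096, 128)

num_encoder_hidden_states_tokens = text_len
num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens
assert num_query_tokens == 4096  # only the latent/image tokens remain
```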
111 changes: 78 additions & 33 deletions xfuser/model_executor/models/transformers/base_transformer.py
@@ -1,4 +1,5 @@
from abc import abstractmethod, ABCMeta
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch.nn as nn
@@ -19,6 +20,11 @@
logger = init_logger(__name__)


class StageInfo:
def __init__(self):
self.after_flags: Dict[str, bool] = {}


class xFuserTransformerBaseWrapper(xFuserModelBaseWrapper, metaclass=ABCMeta):
# transformer: original transformer model (for example Transformer2DModel)
def __init__(
@@ -27,12 +33,15 @@ def __init__(
submodule_classes_to_wrap: List[Type] = [],
submodule_name_to_wrap: List = [],
submodule_addition_args: Dict = {},
transformer_blocks_name: List[str] = ["transformer_blocks"],
):
self.stage_info = None
transformer = self._convert_transformer_for_parallel(
transformer,
submodule_classes_to_wrap=submodule_classes_to_wrap,
submodule_name_to_wrap=submodule_name_to_wrap,
submodule_addition_args=submodule_addition_args,
transformer_blocks_name=transformer_blocks_name,
)
super().__init__(module=transformer)

@@ -42,6 +51,7 @@ def _convert_transformer_for_parallel(
submodule_classes_to_wrap: List[Type] = [],
submodule_name_to_wrap: List = [],
submodule_addition_args: Dict = {},
transformer_blocks_name: List[str] = [],
) -> nn.Module:
if (
get_pipeline_parallel_world_size() == 1
@@ -51,7 +61,9 @@ def _convert_transformer_for_parallel(
):
return transformer
else:
transformer = self._split_transformer_blocks(transformer)
transformer = self._split_transformer_blocks(
transformer, transformer_blocks_name
)
transformer = self._wrap_layers(
model=transformer,
submodule_classes_to_wrap=submodule_classes_to_wrap,
@@ -64,54 +76,87 @@ def _convert_transformer_for_parallel(
def _split_transformer_blocks(
self,
transformer: nn.Module,
blocks_name: List[str] = [],
):
if not hasattr(transformer, "transformer_blocks"):
raise AttributeError(
f"'{transformer.__class__.__name__}' object has no attribute "
f"'transformer_blocks'. To use pipeline parallelism with"
f"object {transformer.__class__.__name__}, please implement "
f"custom _split_transformer_blocks method in derived class"
)
for block_name in blocks_name:
if not hasattr(transformer, block_name):
raise AttributeError(
f"'{transformer.__class__.__name__}' object has no attribute "
f"'{block_name}'."
)

# transformer layer split
attn_layer_num_for_pp = (
get_runtime_state().parallel_config.pp_config.attn_layer_num_for_pp
)
pp_rank = get_pipeline_parallel_rank()
pp_world_size = get_pipeline_parallel_world_size()
blocks_list = {
block_name: getattr(transformer, block_name) for block_name in blocks_name
}
num_blocks_list = [len(blocks) for blocks in blocks_list.values()]
self.blocks_idx = {
name: [sum(num_blocks_list[:i]), sum(num_blocks_list[: i + 1])]
for i, name in enumerate(blocks_name)
}
if attn_layer_num_for_pp is not None:
assert sum(attn_layer_num_for_pp) == len(transformer.transformer_blocks), (
assert sum(attn_layer_num_for_pp) == sum(num_blocks_list), (
"Sum of attn_layer_num_for_pp should be equal to the "
"number of transformer blocks"
"number of all the transformer blocks"
)
if is_pipeline_first_stage():
transformer.transformer_blocks = transformer.transformer_blocks[
: attn_layer_num_for_pp[0]
]
else:
transformer.transformer_blocks = transformer.transformer_blocks[
sum(attn_layer_num_for_pp[: pp_rank - 1]) : sum(
attn_layer_num_for_pp[:pp_rank]
)
]
stage_block_start_idx = sum(attn_layer_num_for_pp[:pp_rank])
stage_block_end_idx = sum(attn_layer_num_for_pp[: pp_rank + 1])

else:
num_blocks_per_stage = (
len(transformer.transformer_blocks) + pp_world_size - 1
sum(num_blocks_list) + pp_world_size - 1
) // pp_world_size
start_idx = pp_rank * num_blocks_per_stage
end_idx = min(
stage_block_start_idx = pp_rank * num_blocks_per_stage
stage_block_end_idx = min(
(pp_rank + 1) * num_blocks_per_stage,
len(transformer.transformer_blocks),
sum(num_blocks_list),
)
transformer.transformer_blocks = transformer.transformer_blocks[
start_idx:end_idx
]
# position embedding
if not is_pipeline_first_stage():
transformer.pos_embed = None
if not is_pipeline_last_stage():
transformer.norm_out = None
transformer.proj_out = None

self.stage_info = StageInfo()
for name, [blocks_start, blocks_end] in zip(
self.blocks_idx.keys(), self.blocks_idx.values()
):
if (
blocks_end <= stage_block_start_idx
or stage_block_end_idx <= blocks_start
):
setattr(transformer, name, nn.ModuleList([]))
self.stage_info.after_flags[name] = False
elif stage_block_start_idx <= blocks_start:
if blocks_end <= stage_block_end_idx:
self.stage_info.after_flags[name] = True
else:
setattr(
transformer,
name,
blocks_list[name][: -(blocks_end - stage_block_end_idx)],
)
self.stage_info.after_flags[name] = False
elif blocks_start < stage_block_start_idx:
if blocks_end <= stage_block_end_idx:
setattr(
transformer,
name,
blocks_list[name][stage_block_start_idx - blocks_start :],
)
self.stage_info.after_flags[name] = True
            else:  # blocks_end > stage_block_end_idx
setattr(
transformer,
name,
blocks_list[name][
stage_block_start_idx
- blocks_start : stage_block_end_idx
- blocks_end
],
)
self.stage_info.after_flags[name] = False

return transformer

@abstractmethod
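The rewritten `_split_transformer_blocks` treats several named block lists (for Flux: `transformer_blocks` followed by `single_transformer_blocks`) as one flat sequence, slices each list down to the blocks a pipeline stage owns, and records in `StageInfo.after_flags` whether that stage holds the tail of a list (and should therefore run the post-processing that follows it). A standalone re-derivation of that arithmetic, for illustration only; `split_blocks` and the FLUX.1-style counts (19 double-stream + 38 single-stream blocks) are assumptions, not code from this commit:

```python
from typing import Dict, Tuple

def split_blocks(
    num_blocks: Dict[str, int],  # e.g. {"transformer_blocks": 19, "single_transformer_blocks": 38}
    stage_start: int,            # first global block index owned by this pipeline stage
    stage_end: int,              # one past the last global block index owned by this stage
) -> Dict[str, Tuple[slice, bool]]:
    """Return, per block list, the local slice this stage keeps and whether the
    stage holds the tail of that list (so the post-processing after it runs here)."""
    result, offset = {}, 0
    for name, n in num_blocks.items():
        blocks_start, blocks_end = offset, offset + n
        offset = blocks_end
        if blocks_end <= stage_start or stage_end <= blocks_start:
            result[name] = (slice(0, 0), False)  # no overlap: stage keeps none of this list
        else:
            lo = max(stage_start, blocks_start) - blocks_start
            hi = min(stage_end, blocks_end) - blocks_start
            result[name] = (slice(lo, hi), blocks_end <= stage_end)
    return result

# Two pipeline stages over 19 + 38 = 57 blocks, ceil(57 / 2) = 29 blocks per stage:
counts = {"transformer_blocks": 19, "single_transformer_blocks": 38}
print(split_blocks(counts, 0, 29))
# {'transformer_blocks': (slice(0, 19, None), True), 'single_transformer_blocks': (slice(0, 10, None), False)}
print(split_blocks(counts, 29, 57))
# {'transformer_blocks': (slice(0, 0, None), False), 'single_transformer_blocks': (slice(10, 38, None), True)}
```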
@@ -1,4 +1,4 @@
from typing import Optional, Dict, Any
from typing import List, Optional, Dict, Any
import torch
import torch.distributed
import torch.nn as nn
@@ -41,6 +41,7 @@ def __init__(
def _split_transformer_blocks(
self,
transformer: nn.Module,
blocks_name: List[str] = [],
):
if not hasattr(transformer, "blocks"):
raise AttributeError(
52 changes: 37 additions & 15 deletions xfuser/model_executor/models/transformers/transformer_flux.py
@@ -13,7 +13,11 @@
unscale_lora_layers,
)

from xfuser.core.distributed.parallel_state import get_tensor_model_parallel_world_size
from xfuser.core.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
is_pipeline_first_stage,
is_pipeline_last_stage,
)
from xfuser.core.distributed.runtime_state import get_runtime_state
from xfuser.logger import init_logger
from xfuser.model_executor.models.transformers.register import (
@@ -39,6 +43,7 @@ def __init__(
[FeedForward] if get_tensor_model_parallel_world_size() > 1 else []
),
submodule_name_to_wrap=["attn"],
transformer_blocks_name=["transformer_blocks", "single_transformer_blocks"],
)
self.encoder_hidden_states_cache = [
None for _ in range(len(self.transformer_blocks))
@@ -92,11 +97,16 @@ def forward(
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
if (
joint_attention_kwargs is not None
and joint_attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)

if is_pipeline_first_stage():
hidden_states = self.x_embedder(hidden_states)

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
@@ -108,7 +118,8 @@ def forward(
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if is_pipeline_first_stage():
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

if txt_ids.ndim == 3:
logger.warning(
@@ -138,14 +149,18 @@ def custom_forward(*inputs):

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
encoder_hidden_states, hidden_states = (
torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
)

else:
@@ -162,6 +177,7 @@ def custom_forward(*inputs):
# interval_control = int(np.ceil(interval_control))
# hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

# if self.stage_info.after_flags["transformer_blocks"]:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

for index_block, block in enumerate(self.single_transformer_blocks):
@@ -176,7 +192,9 @@ def custom_forward(*inputs):

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
@@ -201,10 +219,14 @@ def custom_forward(*inputs):
# + controlnet_single_block_samples[index_block // interval_control]
# )

encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if self.stage_info.after_flags["single_transformer_blocks"]:
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states), None
else:
output = hidden_states, encoder_hidden_states

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
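With the gating above, only the stage that holds the tail of `single_transformer_blocks` applies `norm_out`/`proj_out` and returns the image prediction; earlier stages hand both streams to the next stage. A heavily simplified sketch of that control flow; `TinyStage`, the dummy modules, the assumed hidden size 3072, and the token counts are illustrative only, and the real `norm_out` also consumes `temb`:

```python
import torch
import torch.nn as nn

class TinyStage(nn.Module):
    """Dummy stand-in for one pipeline stage of the wrapped Flux transformer."""
    def __init__(self, holds_single_tail: bool):
        super().__init__()
        self.holds_single_tail = holds_single_tail  # ~ stage_info.after_flags["single_transformer_blocks"]
        self.norm_out = nn.Identity()  # stand-in; the real norm_out also takes temb
        self.proj_out = nn.Identity()  # stand-in for the final projection

    def forward(self, hidden_states, encoder_hidden_states):
        # ... the transformer blocks owned by this stage would run here ...
        if self.holds_single_tail:
            # final stage: project to the output and drop the text stream
            hidden_states = self.norm_out(hidden_states)
            return self.proj_out(hidden_states), None
        # intermediate stage: forward both streams to the next stage
        return hidden_states, encoder_hidden_states

img = torch.empty(1, 4096, 3072)  # latent tokens (hidden size 3072 assumed)
txt = torch.empty(1, 512, 3072)   # text tokens
out, ctx = TinyStage(holds_single_tail=False)(img, txt)  # ctx is passed along
out, ctx = TinyStage(holds_single_tail=True)(out, ctx)   # ctx is None, out is the prediction
```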