diff --git a/.gitignore b/.gitignore index a195c1e3..5388a1d9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ profile/ xfuser.egg-info/ dist/* latte_output.mp4 -*.sh \ No newline at end of file +*.sh +cache/ \ No newline at end of file diff --git a/examples/run.sh b/examples/run.sh index 463feec3..212c9f89 100644 --- a/examples/run.sh +++ b/examples/run.sh @@ -19,7 +19,7 @@ export PYTHONPATH=$PWD:$PYTHONPATH # or you can simply use the model's ID on Hugging Face, # which will then be downloaded to the default cache path on Hugging Face. -export MODEL_TYPE="CogVideoX" +export MODEL_TYPE="Pixart-alpha" # Configuration for different model types # script, model_id, inference_step declare -A MODEL_CONFIGS=( @@ -53,24 +53,42 @@ if [ "$MODEL_TYPE" = "Flux" ]; then N_GPUS=8 PARALLEL_ARGS="--ulysses_degree $N_GPUS" CFG_ARGS="" +FAST_ATTN_ARGS="" # CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree. elif [ "$MODEL_TYPE" = "CogVideoX" ]; then N_GPUS=4 PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1" CFG_ARGS="--use_cfg_parallel" +FAST_ATTN_ARGS="" # HunyuanDiT asserts sp_degree == ulysses_degree*ring_degree <= 2, or the output will be incorrect. elif [ "$MODEL_TYPE" = "HunyuanDiT" ]; then N_GPUS=8 PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1" CFG_ARGS="--use_cfg_parallel" +FAST_ATTN_ARGS="" + +# Pixart-alpha can use DiTFastAttn to compression attention module, but DiTFastAttn can only use with data parallel +elif [ "$MODEL_TYPE" = "Pixart-alpha" ]; then +N_GPUS=4 +PARALLEL_ARGS="--data_parallel_degree $N_GPUS" +CFG_ARGS="" +FAST_ATTN_ARGS="--use_fast_attn --window_size 512 --n_calib 4 --threshold 0.15 --use_cache --coco_path /data/mscoco/annotations/captions_val2014.json" + +# Pixart-sigma can use DiTFastAttn to compression attention module, but DiTFastAttn can only use with data parallel +elif [ "$MODEL_TYPE" = "Pixart-sigma" ]; then +N_GPUS=4 +PARALLEL_ARGS="--data_parallel_degree $N_GPUS" +CFG_ARGS="" +FAST_ATTN_ARGS="--use_fast_attn --window_size 512 --n_calib 4 --threshold 0.15 --use_cache --coco_path /data/mscoco/annotations/captions_val2014.json" else # On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch) N_GPUS=8 PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1" CFG_ARGS="--use_cfg_parallel" +FAST_ATTN_ARGS="" fi @@ -95,5 +113,6 @@ $OUTPUT_ARGS \ --warmup_steps 0 \ --prompt "A small dog" \ $CFG_ARGS \ +$FAST_ATTN_ARGS \ $PARALLLEL_VAE \ $COMPILE_FLAG diff --git a/xfuser/config/args.py b/xfuser/config/args.py index 7ca8b551..2ef59a02 100644 --- a/xfuser/config/args.py +++ b/xfuser/config/args.py @@ -11,6 +11,7 @@ from xfuser.core.distributed import init_distributed_environment from xfuser.config.config import ( EngineConfig, + FastAttnConfig, ParallelConfig, TensorParallelConfig, PipeFusionParallelConfig, @@ -94,6 +95,13 @@ class xFuserArgs: seed: int = 42 output_type: str = "pil" enable_sequential_cpu_offload: bool = False + # DiTFastAttn arguments + use_fast_attn: bool = False + n_calib: int = 8 + threshold: float = 0.5 + window_size: int = 64 + coco_path: Optional[str] = None + use_cache: bool = False @staticmethod def add_cli_args(parser: FlexibleArgumentParser): @@ -240,6 +248,43 @@ def add_cli_args(parser: FlexibleArgumentParser): help="Offloading the weights to the CPU.", ) + # DiTFastAttn arguments + fast_attn_group = parser.add_argument_group("DiTFastAttn Options") + fast_attn_group.add_argument( + "--use_fast_attn", + action="store_true", + 
help="Use DiTFastAttn to accelerate single inference. Only data parallelism can be used with DITFastAttn.", + ) + fast_attn_group.add_argument( + "--n_calib", + type=int, + default=8, + help="Number of prompts for compression method seletion.", + ) + fast_attn_group.add_argument( + "--threshold", + type=float, + default=0.5, + help="Threshold for selecting attention compression method.", + ) + fast_attn_group.add_argument( + "--window_size", + type=int, + default=64, + help="Size of window attention.", + ) + fast_attn_group.add_argument( + "--coco_path", + type=str, + default=None, + help="Path of MS COCO annotation json file.", + ) + fast_attn_group.add_argument( + "--use_cache", + action="store_true", + help="Use cache config for attention compression.", + ) + return parser @classmethod @@ -294,10 +339,21 @@ def create_config( ), ) + fast_attn_config = FastAttnConfig( + use_fast_attn=self.use_fast_attn, + n_step=self.num_inference_steps, + n_calib=self.n_calib, + threshold=self.threshold, + window_size=self.window_size, + coco_path=self.coco_path, + use_cache=self.use_cache, + ) + engine_config = EngineConfig( model_config=model_config, runtime_config=runtime_config, parallel_config=parallel_config, + fast_attn_config=fast_attn_config, ) input_config = InputConfig( diff --git a/xfuser/config/config.py b/xfuser/config/config.py index 5de13653..44944912 100644 --- a/xfuser/config/config.py +++ b/xfuser/config/config.py @@ -66,6 +66,21 @@ def __post_init__(self): check_env() +@dataclass +class FastAttnConfig: + use_fast_attn: bool = False + n_step: int = 20 + n_calib: int = 8 + threshold: float = 0.5 + window_size: int = 64 + coco_path: Optional[str] = None + use_cache: bool = False + + def __post_init__(self): + assert self.n_calib > 0, "n_calib must be greater than 0" + assert self.threshold > 0.0, "threshold must be greater than 0" + + @dataclass class DataParallelConfig: dp_degree: int = 1 @@ -217,6 +232,12 @@ class EngineConfig: model_config: ModelConfig runtime_config: RuntimeConfig parallel_config: ParallelConfig + fast_attn_config: FastAttnConfig + + def __post_init__(self): + world_size = dist.get_world_size() + if self.fast_attn_config.use_fast_attn: + assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn" def to_dict(self): """Return the configs as a dictionary, for use in **kwargs.""" diff --git a/xfuser/core/fast_attention/__init__.py b/xfuser/core/fast_attention/__init__.py new file mode 100644 index 00000000..55b182fa --- /dev/null +++ b/xfuser/core/fast_attention/__init__.py @@ -0,0 +1,37 @@ +from .fast_attn_state import ( + get_fast_attn_state, + get_fast_attn_enable, + get_fast_attn_step, + get_fast_attn_calib, + get_fast_attn_threshold, + get_fast_attn_window_size, + get_fast_attn_coco_path, + get_fast_attn_use_cache, + get_fast_attn_config_file, + get_fast_attn_layer_name, + initialize_fast_attn_state, +) + +from .attn_layer import ( + FastAttnMethod, + xFuserFastAttention, +) + +from .utils import fast_attention_compression + +__all__ = [ + "get_fast_attn_state", + "get_fast_attn_enable", + "get_fast_attn_step", + "get_fast_attn_calib", + "get_fast_attn_threshold", + "get_fast_attn_window_size", + "get_fast_attn_coco_path", + "get_fast_attn_use_cache", + "get_fast_attn_config_file", + "get_fast_attn_layer_name", + "initialize_fast_attn_state", + "xFuserFastAttention", + "FastAttnMethod", + "fast_attention_compression", +] diff --git a/xfuser/core/fast_attention/attn_layer.py 
b/xfuser/core/fast_attention/attn_layer.py new file mode 100644 index 00000000..e82c4bfa --- /dev/null +++ b/xfuser/core/fast_attention/attn_layer.py @@ -0,0 +1,220 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/thu-nics/DiTFastAttn/blob/main/modules/fast_attn_processor.py +# Copyright (c) 2024 NICS-EFC Lab of Tsinghua University. + +import torch +from diffusers.models.attention_processor import Attention +from typing import Optional +import torch.nn.functional as F +import flash_attn +from enum import Flag, auto +from .fast_attn_state import get_fast_attn_window_size + + +class FastAttnMethod(Flag): + FULL_ATTN = auto() + RESIDUAL_WINDOW_ATTN = auto() + OUTPUT_SHARE = auto() + CFG_SHARE = auto() + RESIDUAL_WINDOW_ATTN_CFG_SHARE = RESIDUAL_WINDOW_ATTN | CFG_SHARE + FULL_ATTN_CFG_SHARE = FULL_ATTN | CFG_SHARE + + def has(self, method: "FastAttnMethod"): + return bool(self & method) + + +class xFuserFastAttention: + window_size: list[int] = [-1, -1] + steps_method: list[FastAttnMethod] = [] + cond_first: bool = False + need_compute_residual: list[bool] = [] + need_cache_output: bool = False + + def __init__( + self, + steps_method: list[FastAttnMethod] = [], + cond_first: bool = False, + ): + window_size = get_fast_attn_window_size() + self.window_size = [window_size, window_size] + self.steps_method = steps_method + # CFG order flag (conditional first or unconditional first) + self.cond_first = cond_first + self.need_compute_residual = self.compute_need_compute_residual() + self.need_cache_output = True + + def set_methods( + self, + steps_method: list[FastAttnMethod], + selecting: bool = False, + ): + self.steps_method = steps_method + if selecting: + if len(self.need_compute_residual) != len(self.steps_method): + self.need_compute_residual = [False] * len(self.steps_method) + else: + self.need_compute_residual = self.compute_need_compute_residual() + + def compute_need_compute_residual(self): + """Check at which timesteps do we need to compute the full-window residual of this attention module""" + need_compute_residual = [] + for i, method in enumerate(self.steps_method): + need = False + if method.has(FastAttnMethod.FULL_ATTN): + for j in range(i + 1, len(self.steps_method)): + if self.steps_method[j].has(FastAttnMethod.RESIDUAL_WINDOW_ATTN): + # If encountered a step that conduct WA-RS, + # this step needs the residual computation + need = True + if self.steps_method[j].has(FastAttnMethod.FULL_ATTN): + # If encountered another step using the `full-attn` strategy, + # this step doesn't need the residual computation + break + need_compute_residual.append(need) + return need_compute_residual + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): + + # Before calculating the attention, prepare the related parameters + method = self.steps_method[attn.stepi] if attn.stepi < len(self.steps_method) else FastAttnMethod.FULL_ATTN + need_compute_residual = self.need_compute_residual[attn.stepi] if attn.stepi < len(self.need_compute_residual) else False + + # Run the forward method according to the selected strategy + residual = hidden_states + if method.has(FastAttnMethod.OUTPUT_SHARE): + hidden_states = attn.cached_output + else: + if method.has(FastAttnMethod.CFG_SHARE): + # Directly use the unconditional branch's attention output + # as the conditional branch's attention output 
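Since several strategies meet in this processor, a compact illustration of how the `FastAttnMethod` flags compose may help. This is a sketch only, assuming the `xfuser` package from this patch (and its `flash_attn` dependency) is importable; the four-step schedule is invented for illustration:

```python
from xfuser.core.fast_attention import FastAttnMethod

# Composed members carry every flag they are built from.
m = FastAttnMethod.RESIDUAL_WINDOW_ATTN_CFG_SHARE
assert m.has(FastAttnMethod.RESIDUAL_WINDOW_ATTN)
assert m.has(FastAttnMethod.CFG_SHARE)
assert not m.has(FastAttnMethod.FULL_ATTN)

# Worked example of the residual bookkeeping: with the invented schedule below,
# compute_need_compute_residual() returns [True, False, False, False]. Only step 0
# must cache its full-vs-window residual, because a RESIDUAL_WINDOW_ATTN step
# follows it before the next FULL_ATTN step.
schedule = [
    FastAttnMethod.FULL_ATTN,
    FastAttnMethod.OUTPUT_SHARE,
    FastAttnMethod.RESIDUAL_WINDOW_ATTN,
    FastAttnMethod.FULL_ATTN,
]
```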
+ + batch_size = hidden_states.shape[0] + if self.cond_first: + hidden_states = hidden_states[: batch_size // 2] + else: + hidden_states = hidden_states[batch_size // 2 :] + if encoder_hidden_states is not None: + if self.cond_first: + encoder_hidden_states = encoder_hidden_states[: batch_size // 2] + else: + encoder_hidden_states = encoder_hidden_states[batch_size // 2 :] + if attention_mask is not None: + if self.cond_first: + attention_mask = attention_mask[: batch_size // 2] + else: + attention_mask = attention_mask[batch_size // 2 :] + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + if attention_mask is not None: + assert ( + method.has(FastAttnMethod.RESIDUAL_WINDOW_ATTN) == False + ), "Attention mask is not supported in windowed attention" + + hidden_states = F.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ).transpose(1, 2) + elif method.has(FastAttnMethod.FULL_ATTN): + all_hidden_states = flash_attn.flash_attn_func(query, key, value) + if need_compute_residual: + # Compute the full-window attention residual + w_hidden_states = flash_attn.flash_attn_func(query, key, value, window_size=self.window_size) + window_residual = all_hidden_states - w_hidden_states + if method.has(FastAttnMethod.CFG_SHARE): + window_residual = torch.cat([window_residual, window_residual], dim=0) + # Save the residual for usage in follow-up steps + attn.cached_residual = window_residual + hidden_states = all_hidden_states + elif method.has(FastAttnMethod.RESIDUAL_WINDOW_ATTN): + w_hidden_states = flash_attn.flash_attn_func(query, key, value, window_size=self.window_size) + hidden_states = w_hidden_states + attn.cached_residual[:batch_size].view_as(w_hidden_states) + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + 
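The `FULL_ATTN` / `RESIDUAL_WINDOW_ATTN` branches above implement DiTFastAttn's window attention with residual sharing (WA-RS): a full-attention step caches the difference between full and windowed attention, and later windowed steps add that stale residual back. The stand-alone sketch below reproduces the arithmetic with plain SDPA instead of `flash_attn`; the shapes and window width are illustrative:

```python
import torch
import torch.nn.functional as F

def windowed_attention(q, k, v, w):
    # Dense stand-in for flash_attn's sliding-window attention: each query only
    # attends to keys within +/- w positions of itself.
    seq_len = q.shape[-2]
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() <= w  # (seq, seq) boolean band
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))  # (batch, heads, seq, dim)

full = F.scaled_dot_product_attention(q, k, v)         # a FULL_ATTN step
residual = full - windowed_attention(q, k, v, w=32)    # cached as attn.cached_residual

# A later RESIDUAL_WINDOW_ATTN step only pays for windowed attention and adds the
# (stale) residual saved at the last full-attention step back in:
approx = windowed_attention(q, k, v, w=32) + residual
```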
+ if method.has(FastAttnMethod.CFG_SHARE): + hidden_states = torch.cat([hidden_states, hidden_states], dim=0) + + if self.need_cache_output: + attn.cached_output = hidden_states + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + # After been call once, add the timestep index of this attention module by 1 + attn.stepi += 1 + + return hidden_states + + +# TODO: Implement classes to support DiTFastAttn in different diffusion models +class xFuserJointFastAttention(xFuserFastAttention): + pass + + +class xFuserFluxFastAttention(xFuserFastAttention): + pass + + +class xFuserHunyuanFastAttention(xFuserFastAttention): + pass diff --git a/xfuser/core/fast_attention/fast_attn_state.py b/xfuser/core/fast_attention/fast_attn_state.py new file mode 100644 index 00000000..91823c43 --- /dev/null +++ b/xfuser/core/fast_attention/fast_attn_state.py @@ -0,0 +1,111 @@ +from typing import Optional +from diffusers import DiffusionPipeline +from xfuser.config.config import ( + ParallelConfig, + RuntimeConfig, + InputConfig, + FastAttnConfig, + EngineConfig, +) +from xfuser.logger import init_logger + +logger = init_logger(__name__) + + +class FastAttnState: + enable: bool = False + n_step: int = 20 + n_calib: int = 8 + threshold: float = 0.5 + window_size: int = 64 + coco_path: Optional[str] = None + use_cache: bool = False + config_file: str + layer_name: str + + def __init__(self, pipe: DiffusionPipeline, config: FastAttnConfig): + self.enable = config.use_fast_attn + if self.enable: + self.n_step = config.n_step + self.n_calib = config.n_calib + self.threshold = config.threshold + self.window_size = config.window_size + self.coco_path = config.coco_path + self.use_cache = config.use_cache + self.config_file = self.config_file_path(pipe, config) + self.layer_name = self.attn_name_to_wrap(pipe) + + def config_file_path(self, pipe: DiffusionPipeline, config: FastAttnConfig): + """Return the config file path.""" + return f"cache/{pipe.config._name_or_path.replace('/', '_')}_{config.n_step}_{config.n_calib}_{config.threshold}_{config.window_size}.json" + + def attn_name_to_wrap(self, pipe: DiffusionPipeline): + """Return the attr name of attention layer to wrap.""" + names = ["attn1", "attn"] # names of self attention layer + assert hasattr(pipe, "transformer"), "transformer is not found in pipeline." + assert hasattr(pipe.transformer, "transformer_blocks"), "transformer_blocks is not found in pipeline." 
+ block = pipe.transformer.transformer_blocks[0] + for name in names: + if hasattr(block, name): + return name + raise AttributeError(f"Attention layer name is not found in {names}.") + + +_FASTATTN: Optional[FastAttnState] = None + + +def get_fast_attn_state() -> FastAttnState: + # assert _FASTATTN is not None, "FastAttn state is not initialized" + return _FASTATTN + + +def get_fast_attn_enable() -> bool: + """Return whether fast attention is enabled.""" + return get_fast_attn_state().enable + + +def get_fast_attn_step() -> int: + """Return the fast attention step.""" + return get_fast_attn_state().n_step + + +def get_fast_attn_calib() -> int: + """Return the fast attention calibration.""" + return get_fast_attn_state().n_calib + + +def get_fast_attn_threshold() -> float: + """Return the fast attention threshold.""" + return get_fast_attn_state().threshold + + +def get_fast_attn_window_size() -> int: + """Return the fast attention window size.""" + return get_fast_attn_state().window_size + + +def get_fast_attn_coco_path() -> Optional[str]: + """Return the fast attention coco path.""" + return get_fast_attn_state().coco_path + + +def get_fast_attn_use_cache() -> bool: + """Return the fast attention use_cache.""" + return get_fast_attn_state().use_cache + + +def get_fast_attn_config_file() -> str: + """Return the fast attention config file.""" + return get_fast_attn_state().config_file + + +def get_fast_attn_layer_name() -> str: + """Return the fast attention layer name.""" + return get_fast_attn_state().layer_name + + +def initialize_fast_attn_state(pipeline: DiffusionPipeline, single_config: FastAttnConfig): + global _FASTATTN + if _FASTATTN is not None: + logger.warning("FastAttn state is already initialized, reinitializing with pipeline...") + _FASTATTN = FastAttnState(pipe=pipeline, config=single_config) diff --git a/xfuser/core/fast_attention/utils.py b/xfuser/core/fast_attention/utils.py new file mode 100644 index 00000000..030a80ea --- /dev/null +++ b/xfuser/core/fast_attention/utils.py @@ -0,0 +1,232 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/thu-nics/DiTFastAttn/blob/main/dit_fast_attention.py +# Copyright (c) 2024 NICS-EFC Lab of Tsinghua University. 
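As a usage note for the state module above: the sketch below initializes the global FastAttn state by hand and inspects the derived values. The model ID and COCO path are placeholders, and `diffusers` plus `flash_attn` are assumed to be installed; in normal operation the pipeline wrapper performs this call itself through `initialize_fast_attn_state` during construction.

```python
from diffusers import PixArtAlphaPipeline
from xfuser.config.config import FastAttnConfig
from xfuser.core.fast_attention import (
    initialize_fast_attn_state,
    get_fast_attn_config_file,
    get_fast_attn_layer_name,
)

pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512")
cfg = FastAttnConfig(
    use_fast_attn=True,
    n_step=20,
    n_calib=4,
    threshold=0.15,
    window_size=512,
    coco_path="/data/mscoco/annotations/captions_val2014.json",
    use_cache=True,
)
initialize_fast_attn_state(pipeline=pipe, single_config=cfg)

# The cache file name is derived from the model ID and the compression
# hyper-parameters, e.g. "cache/PixArt-alpha_PixArt-XL-2-512x512_20_4_0.15_512.json".
print(get_fast_attn_config_file())
# PixArt transformer blocks expose their self-attention as `attn1`.
print(get_fast_attn_layer_name())
```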
+ +import torch +from xfuser.core.distributed import ( + get_dp_group, + get_data_parallel_rank, +) +from diffusers import DiffusionPipeline +from diffusers.models.transformers.transformer_2d import Transformer2DModel +from xfuser.model_executor.layers.attention_processor import xFuserAttentionBaseWrapper +from collections import Counter +import os +import json +import numpy as np + + +from .fast_attn_state import ( + get_fast_attn_step, + get_fast_attn_calib, + get_fast_attn_threshold, + get_fast_attn_coco_path, + get_fast_attn_use_cache, + get_fast_attn_config_file, + get_fast_attn_layer_name, +) + +from .attn_layer import ( + xFuserFastAttention, + FastAttnMethod, +) + +from xfuser.logger import init_logger + +logger = init_logger(__name__) + + +def save_config_file(step_methods, file_path): + folder = os.path.dirname(file_path) + if not os.path.exists(folder): + os.makedirs(folder) + format_data = { + f"block{blocki}": {f"step{stepi}": method.name for stepi, method in enumerate(methods)} + for blocki, methods in enumerate(step_methods) + } + with open(file_path, "w") as file: + json.dump(format_data, file, indent=2) + + +def load_config_file(file_path): + with open(file_path, "r") as file: + format_data = json.load(file) + steps_methods = [[FastAttnMethod[method] for method in format_method.values()] for format_method in format_data.values()] + return steps_methods + + +def compression_loss(a, b): + ls = [] + if a.__class__.__name__ == "Transformer2DModelOutput": + a = [a.sample] + b = [b.sample] + weight = torch.tensor(0.0) + for ai, bi in zip(a, b): + if isinstance(ai, torch.Tensor): + weight += ai.numel() + diff = (ai - bi) / (torch.max(ai, bi) + 1e-6) + loss = diff.abs().clip(0, 10).mean() + ls.append(loss) + weight_sum = get_dp_group().all_reduce(weight.clone().to(ai.device)) + local_loss = (weight / weight_sum) * (sum(ls) / len(ls)) + global_loss = get_dp_group().all_reduce(local_loss.clone().to(ai.device)).item() + return global_loss + + +def transformer_forward_pre_hook(m: Transformer2DModel, args, kwargs): + attn_name = get_fast_attn_layer_name() + now_stepi = getattr(m.transformer_blocks[0], attn_name).stepi + # batch_size = get_fast_attn_calib() + # dp_degree = + + for blocki, block in enumerate(m.transformer_blocks): + # Set `need_compute_residual` to False to avoid the process of trying different + # compression strategies to override the saved residual. 
+ fast_attn = getattr(block, attn_name).processor.fast_attn + fast_attn.need_compute_residual[now_stepi] = False + fast_attn.need_cache_output = False + raw_outs = m.forward(*args, **kwargs) + for blocki, block in enumerate(m.transformer_blocks): + if now_stepi == 0: + continue + fast_attn = getattr(block, attn_name).processor.fast_attn + method_candidates = [ + FastAttnMethod.OUTPUT_SHARE, + FastAttnMethod.RESIDUAL_WINDOW_ATTN_CFG_SHARE, + FastAttnMethod.RESIDUAL_WINDOW_ATTN, + FastAttnMethod.FULL_ATTN_CFG_SHARE, + ] + selected_method = FastAttnMethod.FULL_ATTN + for method in method_candidates: + # Try compress this attention using `method` + fast_attn.steps_method[now_stepi] = method + + # Set the timestep index of every layer back to now_stepi + # (which are increased by one in every forward) + for _block in m.transformer_blocks: + for layer in _block.children(): + if isinstance(layer, xFuserAttentionBaseWrapper): + layer.stepi = now_stepi + + # Compute the overall transformer output + outs = m.forward(*args, **kwargs) + + loss = compression_loss(raw_outs, outs) + threshold = m.loss_thresholds[now_stepi][blocki] + + if loss < threshold: + selected_method = method + break + + fast_attn.steps_method[now_stepi] = selected_method + del loss, outs + del raw_outs + + # Set the timestep index of every layer back to now_stepi + # (which are increased by one in every forward) + for _block in m.transformer_blocks: + for layer in _block.children(): + if isinstance(layer, xFuserAttentionBaseWrapper): + layer.stepi = now_stepi + + for blocki, block in enumerate(m.transformer_blocks): + # During the compression plan decision process, + # we set the `need_compute_residual` property of all attention modules to `True`, + # so that all full attention modules will save its residual for convenience. + # The residual will be saved in the follow-up forward call. 
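The net effect of this pre-hook can be stated compactly: timestep 0 always keeps full attention; for every later (timestep, block) pair the candidates are tried from most to least aggressive, and the first one whose `compression_loss` against the uncompressed output stays below that block's threshold (set up in `select_methods` below as `(block_index + 1) / num_blocks * threshold`) is kept, with `FULL_ATTN` as the fallback. A paraphrase, not code from this patch:

```python
from xfuser.core.fast_attention import FastAttnMethod

# Candidate order mirrors `method_candidates` in transformer_forward_pre_hook.
CANDIDATES = [
    FastAttnMethod.OUTPUT_SHARE,                    # cheapest: reuse last step's output
    FastAttnMethod.RESIDUAL_WINDOW_ATTN_CFG_SHARE,  # windowed attention + CFG sharing
    FastAttnMethod.RESIDUAL_WINDOW_ATTN,            # windowed attention only
    FastAttnMethod.FULL_ATTN_CFG_SHARE,             # full attention, CFG sharing
]

def pick_method(loss_of, threshold: float) -> FastAttnMethod:
    """loss_of(method) -> compression_loss of the whole transformer when this
    block/step uses `method`; threshold is loss_thresholds[step][block]."""
    for method in CANDIDATES:
        if loss_of(method) < threshold:
            return method
    return FastAttnMethod.FULL_ATTN
```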
+ fast_attn = getattr(block, attn_name).processor.fast_attn + fast_attn.need_compute_residual[now_stepi] = True + fast_attn.need_cache_output = True + + +def select_methods(pipe: DiffusionPipeline): + blocks = pipe.transformer.transformer_blocks + transformer: Transformer2DModel = pipe.transformer + attn_name = get_fast_attn_layer_name() + n_steps = get_fast_attn_step() + # reset all processors + for block in blocks: + fast_attn: xFuserFastAttention = getattr(block, attn_name).processor.fast_attn + fast_attn.set_methods( + [FastAttnMethod.FULL_ATTN] * n_steps, + selecting=True, + ) + + # Setup loss threshold for each timestep and layer + loss_thresholds = [] + for step_i in range(n_steps): + sub_list = [] + for blocki in range(len(blocks)): + threshold_i = (blocki + 1) / len(blocks) * get_fast_attn_threshold() + sub_list.append(threshold_i) + loss_thresholds.append(sub_list) + + # calibration + hook = transformer.register_forward_pre_hook(transformer_forward_pre_hook, with_kwargs=True) + transformer.loss_thresholds = loss_thresholds + + seed = 3 + guidance_scale = 4.5 + if not os.path.exists(get_fast_attn_coco_path()): + raise FileNotFoundError(f"File {get_fast_attn_coco_path()} not found") + with open(get_fast_attn_coco_path(), "r") as file: + mscoco_anno = json.load(file) + np.random.seed(seed) + slice_ = np.random.choice(mscoco_anno["annotations"], get_fast_attn_calib()) + calib_x = [d["caption"] for d in slice_] + pipe( + prompt=calib_x, + num_inference_steps=n_steps, + generator=torch.manual_seed(seed), + output_type="latent", + negative_prompt="", + return_dict=False, + guidance_scale=guidance_scale, + ) + + hook.remove() + del transformer.loss_thresholds + + blocks_methods = [getattr(block, attn_name).processor.fast_attn.steps_method for block in blocks] + return blocks_methods + + +def set_methods( + pipe: DiffusionPipeline, + blocks_methods: list, +): + attn_name = get_fast_attn_layer_name() + blocks = pipe.transformer.transformer_blocks + for blocki, block in enumerate(blocks): + getattr(block, attn_name).processor.fast_attn.set_methods(blocks_methods[blocki]) + + +def statistics(pipe: DiffusionPipeline): + attn_name = get_fast_attn_layer_name() + blocks = pipe.transformer.transformer_blocks + counts = Counter([method for block in blocks for method in getattr(block, attn_name).processor.fast_attn.steps_method]) + total = sum(counts.values()) + for k, v in counts.items(): + logger.info(f"{attn_name} {k} {v/total}") + + +def fast_attention_compression(pipe: DiffusionPipeline): + config_file = get_fast_attn_config_file() + logger.info(f"config file is {config_file}") + + if get_fast_attn_use_cache() and os.path.exists(config_file): + logger.info(f"load config file {config_file} as DiTFastAttn compression methods.") + blocks_methods = load_config_file(config_file) + else: + if get_fast_attn_use_cache(): + logger.warning(f"config file {config_file} not found.") + logger.info("start to select DiTFastAttn compression methods.") + blocks_methods = select_methods(pipe) + if get_data_parallel_rank() == 0: + save_config_file(blocks_methods, config_file) + logger.info(f"save DiTFastAttn compression methods to {config_file}") + + set_methods(pipe, blocks_methods) + + statistics(pipe) diff --git a/xfuser/model_executor/base_wrapper.py b/xfuser/model_executor/base_wrapper.py index 725a227c..1a04e9a6 100644 --- a/xfuser/model_executor/base_wrapper.py +++ b/xfuser/model_executor/base_wrapper.py @@ -9,6 +9,7 @@ get_tensor_model_parallel_world_size, ) from xfuser.core.distributed.runtime_state 
import get_runtime_state +from xfuser.core.fast_attention import get_fast_attn_enable class xFuserBaseWrapper(metaclass=ABCMeta): @@ -40,6 +41,7 @@ def check_condition_fn(self, *args, **kwargs): and get_classifier_free_guidance_world_size() == 1 and get_sequence_parallel_world_size() == 1 and get_tensor_model_parallel_world_size() == 1 + and get_fast_attn_enable() == False ): return func(self, *args, **kwargs) if not get_runtime_state().is_ready(): diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index cb08f953..964ac632 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -24,6 +24,10 @@ get_sequence_parallel_rank, get_sp_group, ) +from xfuser.core.fast_attention import ( + xFuserFastAttention, + get_fast_attn_enable, +) from xfuser.core.cache_manager.cache_manager import get_cache_manager from xfuser.core.distributed.runtime_state import get_runtime_state @@ -246,6 +250,9 @@ def __init__(self): else: self.hybrid_seq_parallel_attn = None + if get_fast_attn_enable(): + self.fast_attn = xFuserFastAttention() + def __call__( self, attn: Attention, @@ -261,6 +268,11 @@ def __call__( deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) + #! ---------------------------------------- Fast Attention ---------------------------------------- + if get_fast_attn_enable(): + return self.fast_attn(attn, hidden_states, encoder_hidden_states, attention_mask, temb, *args, **kwargs) + #! 
---------------------------------------- Fast Attention ---------------------------------------- + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -435,6 +447,9 @@ def __init__(self): use_kv_cache=self.use_long_ctx_attn_kvcache, ) + if get_fast_attn_enable(): + self.fast_attn = xFuserFastAttention() + def __call__( self, attn: Attention, diff --git a/xfuser/model_executor/models/transformers/base_transformer.py b/xfuser/model_executor/models/transformers/base_transformer.py index 020c463a..d3a14753 100644 --- a/xfuser/model_executor/models/transformers/base_transformer.py +++ b/xfuser/model_executor/models/transformers/base_transformer.py @@ -11,6 +11,7 @@ get_sequence_parallel_world_size, get_tensor_model_parallel_world_size, ) +from xfuser.core.fast_attention import get_fast_attn_enable from xfuser.core.distributed.runtime_state import get_runtime_state from xfuser.logger import init_logger from xfuser.model_executor.models import xFuserModelBaseWrapper @@ -46,6 +47,7 @@ def _convert_transformer_for_parallel( get_pipeline_parallel_world_size() == 1 and get_sequence_parallel_world_size() == 1 and get_tensor_model_parallel_world_size() == 1 + and get_fast_attn_enable() == False ): return transformer else: diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py index 18a2e204..34c4310c 100644 --- a/xfuser/model_executor/pipelines/base_pipeline.py +++ b/xfuser/model_executor/pipelines/base_pipeline.py @@ -30,6 +30,11 @@ is_dp_last_group, get_sequence_parallel_rank, ) +from xfuser.core.fast_attention import ( + get_fast_attn_enable, + initialize_fast_attn_state, + fast_attention_compression, +) from xfuser.model_executor.base_wrapper import xFuserBaseWrapper from xfuser.envs import PACKAGES_CHECKER @@ -38,6 +43,7 @@ from xfuser.model_executor.schedulers import * from xfuser.model_executor.models.transformers import * +from xfuser.model_executor.layers.attention_processor import * try: import os @@ -61,6 +67,7 @@ def __init__( ): self.module: DiffusionPipeline self._init_runtime_state(pipeline=pipeline, engine_config=engine_config) + self._init_fast_attn_state(pipeline=pipeline, engine_config=engine_config) # backbone transformer = getattr(pipeline, "transformer", None) @@ -109,6 +116,30 @@ def to(self, *args, **kwargs): self.module = self.module.to(*args, **kwargs) return self + @staticmethod + def enable_fast_attn(func): + @wraps(func) + def fast_attn_fn(self, *args, **kwargs): + if get_fast_attn_enable(): + for block in self.module.transformer.transformer_blocks: + for layer in block.children(): + if isinstance(layer, xFuserAttentionBaseWrapper): + layer.stepi = 0 + layer.cached_residual = None + layer.cached_output = None + out = func(self, *args, **kwargs) + for block in self.module.transformer.transformer_blocks: + for layer in block.children(): + if isinstance(layer, xFuserAttentionBaseWrapper): + layer.stepi = 0 + layer.cached_residual = None + layer.cached_output = None + return out + else: + return func(self, *args, **kwargs) + + return fast_attn_fn + @staticmethod def enable_data_parallel(func): @wraps(func) @@ -145,6 +176,7 @@ def check_naive_forward_fn(self, *args, **kwargs): and get_classifier_free_guidance_world_size() == 1 and get_sequence_parallel_world_size() == 1 and get_tensor_model_parallel_world_size() == 1 + and get_fast_attn_enable() == False ): return self.module(*args, **kwargs) else: @@ -192,6 +224,10 @@ def forward(self): def prepare_run( 
self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1 ): + if get_fast_attn_enable(): + # set compression methods for DiTFastAttn + fast_attention_compression(self) + prompt = [""] * input_config.batch_size if input_config.batch_size > 1 else "" warmup_steps = get_runtime_state().runtime_config.warmup_steps get_runtime_state().runtime_config.warmup_steps = sync_steps @@ -228,6 +264,11 @@ def _init_runtime_state( ): initialize_runtime_state(pipeline=pipeline, engine_config=engine_config) + def _init_fast_attn_state( + self, pipeline: DiffusionPipeline, engine_config: EngineConfig + ): + initialize_fast_attn_state(pipeline=pipeline, single_config=engine_config.fast_attn_config) + def _convert_transformer_backbone( self, transformer: nn.Module, enable_torch_compile: bool, enable_onediff: bool ): @@ -236,6 +277,7 @@ def _convert_transformer_backbone( and get_sequence_parallel_world_size() == 1 and get_classifier_free_guidance_world_size() == 1 and get_tensor_model_parallel_world_size() == 1 + and get_fast_attn_enable() == False ): logger.info( "Transformer backbone found, but model parallelism is not enabled, " diff --git a/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py b/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py index f68faa94..9355d592 100644 --- a/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py +++ b/xfuser/model_executor/pipelines/pipeline_pixart_alpha.py @@ -47,6 +47,7 @@ def from_pretrained( return cls(pipeline, engine_config) @torch.no_grad() + @xFuserPipelineBaseWrapper.enable_fast_attn @xFuserPipelineBaseWrapper.enable_data_parallel @xFuserPipelineBaseWrapper.check_to_use_naive_forward def __call__( diff --git a/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py b/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py index c306e86f..a864124b 100644 --- a/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py +++ b/xfuser/model_executor/pipelines/pipeline_pixart_sigma.py @@ -47,6 +47,7 @@ def from_pretrained( return cls(pipeline, engine_config) @torch.no_grad() + @xFuserPipelineBaseWrapper.enable_fast_attn @xFuserPipelineBaseWrapper.enable_data_parallel @xFuserPipelineBaseWrapper.check_to_use_naive_forward def __call__(
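End to end, the feature is driven by the new CLI flags: with `--use_fast_attn`, `prepare_run` runs the calibration pass (or loads the cached JSON plan when `--use_cache` finds one under `cache/`), after which every wrapped self-attention layer follows its per-step compression schedule. The sketch below spells out roughly what `examples/run.sh` does for Pixart-alpha in Python; class and argument names follow the existing PixArt example, while the model ID and COCO path are placeholders:

```python
# Launch with: torchrun --nproc_per_node=4 <this_script>.py
import os
import torch
from xfuser import xFuserArgs, xFuserPixArtAlphaPipeline
from xfuser.config import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="xFuser + DiTFastAttn")
args = xFuserArgs.add_cli_args(parser).parse_args([
    "--model", "PixArt-alpha/PixArt-XL-2-512x512",
    "--data_parallel_degree", "4",   # DiTFastAttn requires pure data parallelism
    "--num_inference_steps", "20",
    "--prompt", "A small dog",
    "--use_fast_attn", "--window_size", "512", "--n_calib", "4", "--threshold", "0.15",
    "--use_cache", "--coco_path", "/data/mscoco/annotations/captions_val2014.json",
])
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()

local_rank = int(os.environ.get("LOCAL_RANK", 0))
pipe = xFuserPixArtAlphaPipeline.from_pretrained(
    pretrained_model_name_or_path=engine_config.model_config.model,
    engine_config=engine_config,
    torch_dtype=torch.float16,
).to(f"cuda:{local_rank}")

# Runs the DiTFastAttn calibration (or loads the cached compression plan),
# then warms up the pipeline as usual.
pipe.prepare_run(input_config)

output = pipe(
    height=input_config.height,
    width=input_config.width,
    prompt=input_config.prompt,
    num_inference_steps=input_config.num_inference_steps,
    output_type=input_config.output_type,
)
```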