Skip to content

Commit

Permalink
feat: DiTFastAttn for PixArt (xdit-project#297)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZDJeffrey authored and feifeibear committed Oct 25, 2024
1 parent c84d8b7 commit 25d5ede
Show file tree
Hide file tree
Showing 14 changed files with 762 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ profile/
xfuser.egg-info/
dist/*
latte_output.mp4
*.sh
*.sh
cache/
21 changes: 20 additions & 1 deletion examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export PYTHONPATH=$PWD:$PYTHONPATH
# or you can simply use the model's ID on Hugging Face,
# which will then be downloaded to the default cache path on Hugging Face.

export MODEL_TYPE="CogVideoX"
export MODEL_TYPE="Pixart-alpha"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
Expand Down Expand Up @@ -53,24 +53,42 @@ if [ "$MODEL_TYPE" = "Flux" ]; then
N_GPUS=8
PARALLEL_ARGS="--ulysses_degree $N_GPUS"
CFG_ARGS=""
FAST_ATTN_ARGS=""

# CogVideoX asserts sp_degree == ulysses_degree*ring_degree <= 2. Also, do not set the pipefusion degree.
elif [ "$MODEL_TYPE" = "CogVideoX" ]; then
N_GPUS=4
PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 1"
CFG_ARGS="--use_cfg_parallel"
FAST_ATTN_ARGS=""

# HunyuanDiT asserts sp_degree == ulysses_degree*ring_degree <= 2, or the output will be incorrect.
elif [ "$MODEL_TYPE" = "HunyuanDiT" ]; then
N_GPUS=8
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1"
CFG_ARGS="--use_cfg_parallel"
FAST_ATTN_ARGS=""

# Pixart-alpha can use DiTFastAttn to compress the attention module, but DiTFastAttn can only be used with data parallelism
elif [ "$MODEL_TYPE" = "Pixart-alpha" ]; then
N_GPUS=4
PARALLEL_ARGS="--data_parallel_degree $N_GPUS"
CFG_ARGS=""
FAST_ATTN_ARGS="--use_fast_attn --window_size 512 --n_calib 4 --threshold 0.15 --use_cache --coco_path /data/mscoco/annotations/captions_val2014.json"

# Pixart-sigma can use DiTFastAttn to compress the attention module, but DiTFastAttn can only be used with data parallelism
elif [ "$MODEL_TYPE" = "Pixart-sigma" ]; then
N_GPUS=4
PARALLEL_ARGS="--data_parallel_degree $N_GPUS"
CFG_ARGS=""
FAST_ATTN_ARGS="--use_fast_attn --window_size 512 --n_calib 4 --threshold 0.15 --use_cache --coco_path /data/mscoco/annotations/captions_val2014.json"

else
# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
N_GPUS=8
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1"
CFG_ARGS="--use_cfg_parallel"
FAST_ATTN_ARGS=""
fi


Expand All @@ -95,5 +113,6 @@ $OUTPUT_ARGS \
--warmup_steps 0 \
--prompt "A small dog" \
$CFG_ARGS \
$FAST_ATTN_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG
56 changes: 56 additions & 0 deletions xfuser/config/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xfuser.core.distributed import init_distributed_environment
from xfuser.config.config import (
EngineConfig,
FastAttnConfig,
ParallelConfig,
TensorParallelConfig,
PipeFusionParallelConfig,
Expand Down Expand Up @@ -94,6 +95,13 @@ class xFuserArgs:
seed: int = 42
output_type: str = "pil"
enable_sequential_cpu_offload: bool = False
# DiTFastAttn arguments
use_fast_attn: bool = False
n_calib: int = 8
threshold: float = 0.5
window_size: int = 64
coco_path: Optional[str] = None
use_cache: bool = False

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser):
Expand Down Expand Up @@ -240,6 +248,43 @@ def add_cli_args(parser: FlexibleArgumentParser):
help="Offloading the weights to the CPU.",
)

# DiTFastAttn arguments
fast_attn_group = parser.add_argument_group("DiTFastAttn Options")
fast_attn_group.add_argument(
"--use_fast_attn",
action="store_true",
help="Use DiTFastAttn to accelerate single inference. Only data parallelism can be used with DITFastAttn.",
)
fast_attn_group.add_argument(
"--n_calib",
type=int,
default=8,
help="Number of prompts for compression method seletion.",
)
fast_attn_group.add_argument(
"--threshold",
type=float,
default=0.5,
help="Threshold for selecting attention compression method.",
)
fast_attn_group.add_argument(
"--window_size",
type=int,
default=64,
help="Size of window attention.",
)
fast_attn_group.add_argument(
"--coco_path",
type=str,
default=None,
help="Path of MS COCO annotation json file.",
)
fast_attn_group.add_argument(
"--use_cache",
action="store_true",
help="Use cache config for attention compression.",
)

return parser

@classmethod
Expand Down Expand Up @@ -294,10 +339,21 @@ def create_config(
),
)

fast_attn_config = FastAttnConfig(
use_fast_attn=self.use_fast_attn,
n_step=self.num_inference_steps,
n_calib=self.n_calib,
threshold=self.threshold,
window_size=self.window_size,
coco_path=self.coco_path,
use_cache=self.use_cache,
)

engine_config = EngineConfig(
model_config=model_config,
runtime_config=runtime_config,
parallel_config=parallel_config,
fast_attn_config=fast_attn_config,
)

input_config = InputConfig(
Expand Down
21 changes: 21 additions & 0 deletions xfuser/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ def __post_init__(self):
check_env()


@dataclass
class FastAttnConfig:
    """Configuration for DiTFastAttn attention compression.

    Raises:
        ValueError: if ``n_calib`` or ``threshold`` is not positive.
    """

    # Master switch: when False the other fields are ignored by consumers.
    use_fast_attn: bool = False
    # Number of inference (denoising) steps the compression plan covers.
    n_step: int = 20
    # Number of calibration prompts used to select a compression method.
    n_calib: int = 8
    # Similarity threshold for choosing an attention compression method.
    threshold: float = 0.5
    # Window size for windowed attention.
    window_size: int = 64
    # Path to an MS COCO annotation JSON file used for calibration prompts.
    # NOTE(review): presumably required when calibrating — confirm with caller.
    coco_path: Optional[str] = None
    # Reuse a previously computed compression-method cache if available.
    use_cache: bool = False

    def __post_init__(self):
        # Raise instead of assert so validation is not stripped under
        # `python -O` (asserts are removed with optimization enabled).
        if self.n_calib <= 0:
            raise ValueError("n_calib must be greater than 0")
        if self.threshold <= 0.0:
            raise ValueError("threshold must be greater than 0")


@dataclass
class DataParallelConfig:
dp_degree: int = 1
Expand Down Expand Up @@ -217,6 +232,12 @@ class EngineConfig:
model_config: ModelConfig
runtime_config: RuntimeConfig
parallel_config: ParallelConfig
fast_attn_config: FastAttnConfig

def __post_init__(self):
world_size = dist.get_world_size()
if self.fast_attn_config.use_fast_attn:
assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn"

def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs."""
Expand Down
37 changes: 37 additions & 0 deletions xfuser/core/fast_attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Public API of the fast-attention (DiTFastAttn) subpackage.

# Per-process fast-attention state: accessors for the runtime configuration
# (enable flag, step count, calibration size, threshold, window size, COCO
# annotation path, cache usage, config file, layer name) plus the initializer.
from .fast_attn_state import (
    get_fast_attn_state,
    get_fast_attn_enable,
    get_fast_attn_step,
    get_fast_attn_calib,
    get_fast_attn_threshold,
    get_fast_attn_window_size,
    get_fast_attn_coco_path,
    get_fast_attn_use_cache,
    get_fast_attn_config_file,
    get_fast_attn_layer_name,
    initialize_fast_attn_state,
)

# Attention layer implementation and the enum of compression methods.
from .attn_layer import (
    FastAttnMethod,
    xFuserFastAttention,
)

# Entry point that runs the attention-compression procedure.
from .utils import fast_attention_compression

# Explicit public API; wildcard imports pick up exactly these names.
__all__ = [
    "get_fast_attn_state",
    "get_fast_attn_enable",
    "get_fast_attn_step",
    "get_fast_attn_calib",
    "get_fast_attn_threshold",
    "get_fast_attn_window_size",
    "get_fast_attn_coco_path",
    "get_fast_attn_use_cache",
    "get_fast_attn_config_file",
    "get_fast_attn_layer_name",
    "initialize_fast_attn_state",
    "xFuserFastAttention",
    "FastAttnMethod",
    "fast_attention_compression",
]
Loading

0 comments on commit 25d5ede

Please sign in to comment.