Skip to content

Commit

Permalink
remove dependency on flash_attn (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear authored Dec 26, 2024
1 parent 167351a commit 92187b8
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 213 deletions.
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ def get_cuda_version():
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
"distvae",
"yunchang>=0.3.0",
"yunchang>=0.6.0",
"pytest",
"flask",
"opencv-python",
"imageio",
"imageio-ffmpeg",
"optimum-quanto",
"flash_attn>=2.6.3",
"ray"
],
extras_require={
"diffusers": [
"diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"diffusers>=0.32.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"flash_attn>=2.6.3",
]
},
url="https://github.com/xdit-project/xDiT.",
Expand Down
4 changes: 0 additions & 4 deletions xfuser/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ def __post_init__(self):
f"sp_degree is {self.sp_degree}, please set it "
f"to 1 or install 'yunchang' to use it"
)
if not HAS_FLASH_ATTN and self.ring_degree > 1:
raise ValueError(
f"Flash attention not found. Ring attention not available. Please set ring_degree to 1"
)


@dataclass
Expand Down
8 changes: 7 additions & 1 deletion xfuser/core/fast_attention/attn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from diffusers.models.attention_processor import Attention
from typing import Optional
import torch.nn.functional as F
import flash_attn

try:
import flash_attn
except ImportError:
flash_attn = None

from enum import Flag, auto
from .fast_attn_state import get_fast_attn_window_size

Expand Down Expand Up @@ -165,6 +170,7 @@ def __call__(
is_causal=False,
).transpose(1, 2)
elif method.has(FastAttnMethod.FULL_ATTN):
assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
all_hidden_states = flash_attn.flash_attn_func(query, key, value)
if need_compute_residual:
# Compute the full-window attention residual
Expand Down
2 changes: 0 additions & 2 deletions xfuser/core/long_ctx_attention/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .hybrid import xFuserLongContextAttention
from .ulysses import xFuserUlyssesAttention

__all__ = [
"xFuserLongContextAttention",
"xFuserUlyssesAttention",
]
7 changes: 7 additions & 0 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

import torch.distributed
from yunchang import LongContextAttention
try:
from yunchang.kernels import AttnType
except ImportError:
raise ImportError("Please install yunchang 0.6.0 or later")

from yunchang.comm.all_to_all import SeqAllToAll4D

from xfuser.logger import init_logger
Expand All @@ -21,6 +26,7 @@ def __init__(
ring_impl_type: str = "basic",
use_pack_qkv: bool = False,
use_kv_cache: bool = False,
attn_type: AttnType = AttnType.FA,
) -> None:
"""
Arguments:
Expand All @@ -35,6 +41,7 @@ def __init__(
gather_idx=gather_idx,
ring_impl_type=ring_impl_type,
use_pack_qkv=use_pack_qkv,
attn_type = attn_type,
)
self.use_kv_cache = use_kv_cache
if (
Expand Down
10 changes: 8 additions & 2 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import torch
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward

from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from yunchang.ring.utils import RingComm, update_out_and_lse
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc

try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
except ImportError:
flash_attn = None
_flash_attn_forward = None

def xdit_ring_flash_attn_forward(
process_group,
Expand Down Expand Up @@ -80,6 +85,7 @@ def xdit_ring_flash_attn_forward(
key, value = k, v

if not causal or step <= comm.rank:
assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
if flash_attn.__version__ <= "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
q,
Expand Down
5 changes: 0 additions & 5 deletions xfuser/core/long_ctx_attention/ulysses/__init__.py

This file was deleted.

168 changes: 0 additions & 168 deletions xfuser/core/long_ctx_attention/ulysses/attn_layer.py

This file was deleted.

Loading

0 comments on commit 92187b8

Please sign in to comment.