remove dependency on flash_attn #410

Merged: 3 commits, Dec 26, 2024

setup.py (3 additions, 3 deletions)

@@ -32,19 +32,19 @@ def get_cuda_version():
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
"distvae",
"yunchang>=0.3.0",
"yunchang>=0.6.0",
"pytest",
"flask",
"opencv-python",
"imageio",
"imageio-ffmpeg",
"optimum-quanto",
"flash_attn>=2.6.3",
"ray"
],
extras_require={
"diffusers": [
"diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"diffusers>=0.32.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"flash_attn>=2.6.3",
]
},
url="https://github.com/xdit-project/xDiT.",
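
With this change a plain install no longer pulls in flash_attn; only the diffusers extra still pins it. A minimal sketch (hypothetical helper, not part of the PR) of probing for the now-optional package at runtime:

# Sketch: detect whether the optional flash_attn package is importable.
import importlib.util

def has_flash_attn() -> bool:
    """Return True if flash_attn can be imported in this environment."""
    return importlib.util.find_spec("flash_attn") is not None

if __name__ == "__main__":
    print("flash_attn available:", has_flash_attn())
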
xfuser/config/config.py (0 additions, 4 deletions)

@@ -130,10 +130,6 @@ def __post_init__(self):
f"sp_degree is {self.sp_degree}, please set it "
f"to 1 or install 'yunchang' to use it"
)
if not HAS_FLASH_ATTN and self.ring_degree > 1:
raise ValueError(
f"Flash attention not found. Ring attention not available. Please set ring_degree to 1"
)


@dataclass
xfuser/core/fast_attention/attn_layer.py (7 additions, 1 deletion)

@@ -7,7 +7,12 @@
 from diffusers.models.attention_processor import Attention
 from typing import Optional
 import torch.nn.functional as F
-import flash_attn
+
+try:
+    import flash_attn
+except ImportError:
+    flash_attn = None
+
 from enum import Flag, auto
 from .fast_attn_state import get_fast_attn_window_size

@@ -165,6 +170,7 @@ def __call__(
                 is_causal=False,
             ).transpose(1, 2)
         elif method.has(FastAttnMethod.FULL_ATTN):
+            assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
             all_hidden_states = flash_attn.flash_attn_func(query, key, value)
             if need_compute_residual:
                 # Compute the full-window attention residual
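
Only the FULL_ATTN branch touches flash_attn; the windowed branch above it already runs through F.scaled_dot_product_attention, so a missing package now surfaces as an assertion in that one code path instead of an import error at module load. The same guarded-import idea, sketched here with an SDPA fallback rather than an assert (illustrative only, not what the PR does):

import torch.nn.functional as F

try:
    import flash_attn
except ImportError:
    flash_attn = None

def full_attention(query, key, value):
    # query/key/value are (batch, seq_len, num_heads, head_dim), the layout flash_attn expects.
    if flash_attn is not None:
        return flash_attn.flash_attn_func(query, key, value)
    # SDPA wants (batch, num_heads, seq_len, head_dim); transpose in and back out.
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
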
xfuser/core/long_ctx_attention/__init__.py (0 additions, 2 deletions)

@@ -1,7 +1,5 @@
 from .hybrid import xFuserLongContextAttention
-from .ulysses import xFuserUlyssesAttention

 __all__ = [
     "xFuserLongContextAttention",
-    "xFuserUlyssesAttention",
 ]

xfuser/core/long_ctx_attention/hybrid/attn_layer.py (7 additions, 0 deletions)

@@ -3,6 +3,11 @@

 import torch.distributed
 from yunchang import LongContextAttention
+try:
+    from yunchang.kernels import AttnType
+except ImportError:
+    raise ImportError("Please install yunchang 0.6.0 or later")
+
 from yunchang.comm.all_to_all import SeqAllToAll4D

 from xfuser.logger import init_logger

@@ -21,6 +26,7 @@ def __init__(
         ring_impl_type: str = "basic",
         use_pack_qkv: bool = False,
         use_kv_cache: bool = False,
+        attn_type: AttnType = AttnType.FA,
     ) -> None:
         """
         Arguments:

@@ -35,6 +41,7 @@
             gather_idx=gather_idx,
             ring_impl_type=ring_impl_type,
             use_pack_qkv=use_pack_qkv,
+            attn_type = attn_type,
         )
         self.use_kv_cache = use_kv_cache
         if (
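
The new attn_type argument is forwarded to yunchang's LongContextAttention, which is what lets the hybrid sequence-parallel attention run without flash_attn. A usage sketch, assuming yunchang's AttnType also exposes a TORCH member for its PyTorch SDPA backend and that the remaining constructor arguments keep their defaults (the default here stays AttnType.FA):

from yunchang.kernels import AttnType
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

# Assumes the distributed environment and sequence-parallel process groups
# are already initialized; pick a non-flash backend so long-context
# attention works without flash_attn installed.
attn = xFuserLongContextAttention(
    ring_impl_type="basic",
    use_kv_cache=False,
    attn_type=AttnType.TORCH,  # AttnType.TORCH is assumed; AttnType.FA needs flash_attn
)
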
xfuser/core/long_ctx_attention/ring/ring_flash_attn.py (8 additions, 2 deletions)

@@ -1,11 +1,16 @@
 import torch
-import flash_attn
-from flash_attn.flash_attn_interface import _flash_attn_forward

 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from xfuser.core.cache_manager.cache_manager import get_cache_manager
 from yunchang.ring.utils import RingComm, update_out_and_lse
 from yunchang.ring.ring_flash_attn import RingFlashAttnFunc

+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import _flash_attn_forward
+except ImportError:
+    flash_attn = None
+    _flash_attn_forward = None
+
 def xdit_ring_flash_attn_forward(
     process_group,

@@ -80,6 +85,7 @@ def xdit_ring_flash_attn_forward(
             key, value = k, v

         if not causal or step <= comm.rank:
+            assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
             if flash_attn.__version__ <= "2.6.3":
                 block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
                     q,
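
One caveat about the branch on flash_attn.__version__ <= "2.6.3": comparing version strings lexicographically can misorder releases (for example, "2.10.0" compares as smaller than "2.6.3"). A more robust check, sketched with packaging.version (packaging is assumed to be available; this is not part of the PR):

from packaging.version import Version

def uses_old_flash_attn_api(version_str: str) -> bool:
    # flash_attn releases up to and including 2.6.3 use the older _flash_attn_forward signature.
    return Version(version_str) <= Version("2.6.3")

assert uses_old_flash_attn_api("2.6.3")
assert not uses_old_flash_attn_api("2.10.0")  # a plain string compare would get this wrong
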
xfuser/core/long_ctx_attention/ulysses/__init__.py (0 additions, 5 deletions)

This file was deleted.

xfuser/core/long_ctx_attention/ulysses/attn_layer.py (0 additions, 168 deletions)

This file was deleted.