Skip to content

Commit

Permalink
add warning for old flash attn (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi authored Dec 18, 2024
1 parent b5fb784 commit 46c0d54
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_cuda_version():
"imageio",
"imageio-ffmpeg",
"optimum-quanto",
"flash_attn>=2.7.0" # flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops
"flash_attn>=2.6.3" # 2.6.3 is the minimum supported; flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops (fewer graph breaks under torch.compile)
],
extras_require={
"diffusers": [
Expand Down
8 changes: 8 additions & 0 deletions xfuser/model_executor/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABCMeta, abstractmethod
from functools import wraps
from packaging import version
from typing import Callable, Dict, List, Optional, Tuple, Union
import sys
import torch
import torch.distributed
import torch.nn as nn
Expand Down Expand Up @@ -299,6 +301,12 @@ def _convert_transformer_backbone(
if enable_torch_compile or enable_onediff:
if getattr(transformer, "forward") is not None:
if enable_torch_compile:
if "flash_attn" in sys.modules:
import flash_attn
if version.parse(flash_attn.__version__) < version.parse("2.7.0") or version.parse(torch.__version__) < version.parse("2.4.0"):
logger.warning(
"flash-attn or torch version is too old, performance with torch.compile may be suboptimal due to too many graph breaks"
)
optimized_transformer_forward = torch.compile(
getattr(transformer, "forward")
)
Expand Down

0 comments on commit 46c0d54

Please sign in to comment.