flash_attn ImportError breaking model loading (Florence-2-base-ft) #31793
Comments
Hello @Laz4rz, thanks for your report! The Florence-2 model is hosted on the Hub, with code contributed by the authors; it doesn't live within the transformers library itself. Could you post the entirety of your stack trace? I'm not sure I see where the code is failing, and I see in their code that they seem to protect the import of flash attention as well, so I'm curious to see what's happening.
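For context, the kind of guarded import being referred to typically looks something like the sketch below; this is an illustrative pattern, not the exact Florence-2 modeling code.

```python
# Illustrative guarded-import pattern (not the exact Florence-2 code):
# only pull in flash_attn when transformers reports it as usable.
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func  # used by the attention layers when present
```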
If flash_attn is installed:
If it isn't:
Huh, it seems like it entered the conditional statement.
Ok, that was a good direction. Commenting out both checks for flash_attn allows the model to load and run inference correctly. Give me a few minutes and I'll try to pinpoint what exactly is wrong. The check probably only verifies that the lib is installed, not whether it can actually be imported. If we add this check it should work as expected -- but it should still surface some information so that users know flash_attn is not being used due to an import error.
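As a rough illustration of the distinction being discussed (installed vs. actually importable), a minimal sketch might look like the following; the helper name is hypothetical and not part of transformers.

```python
import importlib
import importlib.util
import logging


def can_actually_import(package_name: str) -> bool:
    """Hypothetical helper: True only if the package is installed AND importable."""
    if importlib.util.find_spec(package_name) is None:
        return False  # not installed at all
    try:
        importlib.import_module(package_name)
        return True
    except Exception as err:  # e.g. an undefined-symbol error raised at import time
        logging.warning("%s is installed but failed to import: %s", package_name, err)
        return False
```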
Welp, looks like I was a little too quick -- I've got so many conda envs that I didn't notice I had commented out the import protection in dynamic_module_utils.py line 181 yesterday, so that's where the earlier result came from. This doesn't change the final outcome, though: with proper handling of the flash_attn import, Florence can be used and yields the same result as with flash_attn installed. So there are two different things:

1. The Florence-2 modeling code listing flash_attn as a required import, even though inference works without it.
2. How transformers verifies the imports required by Hub code: currently only torch is checked for being actually importable, while other libraries such as flash_attn are only checked for being installed.
I'm not sure how to handle 1., since that's probably on Microsoft's side if they want to allow running it with or without flash_attn. But I think 2. could benefit from catching and handling all libraries in the same way, so either: a) we add a check for being able to import each required library, not only torch as it is currently. I guess a) would be preferred. The PR for this is: #31798
Hi @Laz4rz, thanks for raising this issue and opening a PR to address it. We definitely don't want to do a), as this will massively increase the time it takes to load the library. It's not obvious to me that we'd want to do b) either. Regarding @LysandreJik's question above, what do you get when you try importing flash_attn directly?
Hey, thanks for taking a look @amyeroberts.
Importing flash_attn yields the error described in Dao-AILab/flash-attention#1027, and from what I've read it is a common flash_attn problem. Upgrading to Python 3.11 fixes the issue; also, Colab, which runs on 3.10 (with the same library versions), has no problem with flash_attn. I tried it on a few different VMs.
The linked issue looks like a problem with the CUDA install and its compatibility with the environment. In other issues which report the same thing, it appears to be related to the installed pytorch, e.g. Dao-AILab/flash-attention#836 and Dao-AILab/flash-attention#919. This would explain why Python 3.10 works in some cases and not others.
flash_attn is an unnecessary import within the modelling file of Florence-2. You can remove it, for example as sketched below.
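The original snippet didn't survive formatting; one commonly shared workaround in the same spirit is to filter flash_attn out of the dynamic import check before loading, roughly like this (a sketch, not the exact code from the comment):

```python
from unittest.mock import patch

from transformers import AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports


def skip_flash_attn_imports(filename):
    """Drop flash_attn from the list of imports transformers tries to verify."""
    return [imp for imp in get_imports(filename) if imp != "flash_attn"]


with patch("transformers.dynamic_module_utils.get_imports", skip_flash_attn_imports):
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-base-ft", trust_remote_code=True
    )
```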
This should solve the issue.
Yeah thanks, that's what I saw too. I'm not sure why it's needed.
cc @Rocketknight1 regarding the is_flash_attn availability check for Hub checkpoints
Hey all, I think this should now be fixed following this PR. However, you may need to install transformers from source to get the fix. Closing this issue because I think it's solved, but if you're still having issues after installing from source, let us know!
This worked for me, but only after deleting the flash_attn files manually: when I tried to uninstall it, pip said the package was not available, so I deleted the files by hand and then it worked.
System Info
Transformers .from_pretrained() may break while loading models, in my case models from the Florence-2 family. It yields an ImportError caused by a failed flash_attn import. Which, to be honest, is strange, because up to this point using flash_attn was not necessary -- I don't know if it's needed now, but unless flash_attn can be imported, the model will not load.
This happens for the following combination of python, transformers, torch and flash_attn:
It can be fixed by upgrading python to 3.11 or 3.12.
Who can help?
@amyeroberts
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
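The original reproduction snippet didn't survive formatting; a minimal call that should hit the reported code path is sketched below (assuming the microsoft/Florence-2-base-ft checkpoint and trust_remote_code, as in the issue title):

```python
from transformers import AutoModelForCausalLM

# Loading any Florence-2 checkpoint with remote code triggers the import check
# that fails when flash_attn is installed but cannot be imported.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
)
```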
Expected behavior
The model loads properly, without needing to import flash_attn.