Cache updating when use_cache = False #32843
cc @gante too for the cache
Related, but not sure if this should be a separate issue. The problem is actually slightly more general than what you've described. For example, the This has the following consequences:
Also FYI, it appears this issue has existed since the new cache structure was introduced in
Hi @ciaran-regan-ie (and @nickfraser) 👋 Thank you for opening the issue and elaborating on the problem! Before taking your comments and projects into consideration, let me share my view (and the context behind some changes in Our code, therefore, gravitated to its current state where we check it in the core modeling class and create a new cache instance if needed. From that point onwards, we assume that From what I'm reading in your comments, my assumption may be incorrect!
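As a rough illustration of the flow described above (the cache is checked, and created if needed, once in the core model class, and the layers then use whatever they receive), here is a minimal pure-Python sketch. `ToyCache`, `ToyLayer` and `ToyModel` are hypothetical stand-ins, not the transformers classes, and the layer behaviour (update any cache that is passed in, without looking at `use_cache`) is an assumption based on the rest of this thread.

```python
from typing import List, Optional


class ToyCache:
    """Hypothetical stand-in for a per-layer key cache."""

    def __init__(self) -> None:
        self.keys: List[List[float]] = []

    def update(self, key: List[float], layer_idx: int) -> List[float]:
        # Make sure a slot exists for this layer, then append the new entries.
        while len(self.keys) <= layer_idx:
            self.keys.append([])
        self.keys[layer_idx].extend(key)
        return self.keys[layer_idx]


class ToyLayer:
    def __init__(self, layer_idx: int) -> None:
        self.layer_idx = layer_idx

    def forward(self, key: List[float], cache: Optional[ToyCache]) -> List[float]:
        # Assumed layer behaviour: any cache object passed in gets updated;
        # the layer itself never looks at use_cache.
        if cache is not None:
            return cache.update(key, self.layer_idx)
        return key


class ToyModel:
    def __init__(self, num_layers: int) -> None:
        self.layers = [ToyLayer(i) for i in range(num_layers)]

    def forward(self, key: List[float], use_cache: bool = False,
                cache: Optional[ToyCache] = None) -> Optional[ToyCache]:
        # use_cache is only checked here: create a fresh cache if one is needed.
        if use_cache and cache is None:
            cache = ToyCache()
        for layer in self.layers:
            layer.forward(key, cache)
        return cache if use_cache else None
```

Under this assumption, `ToyModel(2).forward([2.0], use_cache=False, cache=existing)` still writes into `existing`, because only the top-level class ever inspects `use_cache`. That mirrors the surprise reported in this thread when a cache object reaches the layers anyway.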
Hi @gante, thanks for the detailed reply. Also, please feel free to tell me to open a new issue if that is more appropriate. I understand the new behaviour. In my case, I was calling sub-layers of a Llama-based model directly (with I find this behaviour to be quite unintuitive, but I accept that this is a niche use-case.
@nickfraser If I understand correctly, you were expecting Shifting the cache instantiation from the inner-most block (prior to Here's what I'm thinking of doing:
WDYT? It should make things much cleaner from a user perspective, while being manageable on our end 🤗 (cc @ArthurZucker)
Yes, exactly.
Makes sense.
Your suggestion makes a lot of sense to me - sounds great! Thanks for being so amenable too! <3
@gante Thank you so much!
System Info
transformers version: 4.44.0
Who can help?
@ArthurZucker
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
I'm experimenting with shuffling layers in a pre-trained model. The `layer_idx` stored inside the Attention object makes this difficult, as described in this issue. To work around this, I'm setting `use_cache = False`; however, even with `use_cache = False`, an error occurs because `past_key_value.update` is still called in the Attention forward pass. A simple solution would be to check `use_cache` in the forward pass by adding the following `and` logic:
Here is my code to reproduce. The first run succeeds because the layers have not been switched yet, but the second run fails as the cache attempts to update.
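The proposed `and` guard can be sketched as follows. This is a toy pure-Python model, not the actual Llama attention code: `ToyCache` and `attention_forward` are hypothetical names, the real method works on tensors and passes extra arguments to `past_key_value.update`, and the append-on-first-write cache below is only a simplified stand-in chosen to show how shuffled `layer_idx` values can break a cache.

```python
from typing import List, Optional, Tuple


class ToyCache:
    """Hypothetical per-layer key/value store that appends a slot on first write."""

    def __init__(self) -> None:
        self.key_cache: List[List[int]] = []
        self.value_cache: List[List[int]] = []

    def update(self, keys: List[int], values: List[int],
               layer_idx: int) -> Tuple[List[int], List[int]]:
        if len(self.key_cache) <= layer_idx:
            # First write for this layer: assumes layers arrive in index order.
            self.key_cache.append(list(keys))
            self.value_cache.append(list(values))
        else:
            self.key_cache[layer_idx].extend(keys)
            self.value_cache[layer_idx].extend(values)
        # Raises IndexError if a higher layer_idx wrote first (e.g. shuffled layers).
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


def attention_forward(keys: List[int], values: List[int], layer_idx: int,
                      past_key_value: Optional[ToyCache] = None,
                      use_cache: bool = False) -> Tuple[List[int], List[int]]:
    # Proposed guard: skip the cache entirely unless use_cache is True.
    if use_cache and past_key_value is not None:
        keys, values = past_key_value.update(keys, values, layer_idx)
    return keys, values


# Shuffled layer order: layer 1 runs before layer 0.
cache = ToyCache()
for idx in [1, 0]:
    attention_forward([idx], [idx], idx, past_key_value=cache, use_cache=False)
print(cache.key_cache)  # [] : the guard leaves the cache untouched
```

With the guard in place, a cache object that happens to be passed along is ignored when `use_cache = False`; without it, the shuffled order above would call `update` with `layer_idx=1` on an empty toy cache and raise an `IndexError`.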
Expected behavior
When `use_cache = False`, the cache should not be updating, right? Happy to help with PRs if you feel it's necessary!