Fix num_heads in _upad_input #26490
Conversation
The variable num_key_value_heads was incorrectly named num_heads, which led to query_layer being reshaped with the wrong attention head count. (Using the correct variable self.num_heads instead of num_heads would have been enough, but I renamed num_heads to num_key_value_heads for clarity.)
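For context, a rough sketch of the relevant reshapes in _upad_input (simplified from the Llama-style FlashAttention code; the unpadding of the tensors with the attention mask is omitted, so this is illustrative rather than a verbatim diff):

```python
# Simplified sketch of the reshapes in _upad_input (unpadding logic omitted).
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
    # key_layer: (batch_size, kv_seq_len, num_key_value_heads, head_dim)
    # Before this PR the third dimension was unpacked into a variable called
    # `num_heads`, even though it is really the number of key/value heads.
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    # Keys/values: the KV head count is the right one here.
    key_layer = key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim)
    value_layer = value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim)

    # Queries: must use the *query* head count, self.num_heads.
    # The bug reshaped queries with the misnamed `num_heads` (actually the KV head
    # count), which only works when a model has as many KV heads as query heads
    # (i.e. no grouped-query attention).
    query_layer = query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim)
    ...
```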
Looks great to me, thanks for spotting the bug! I can confirm all tests pass on my end.
I think we never caught it because the KV always had the same number of heads as the query.
Can you propagate these changes across the other models?
make fix-copies
@younesbelkada Just did 👍
Falcon and llama tests pass, thanks for fixing!
Nice catch! If you have time, adding a small test might be great! After test_flash_attn_2_generate_padding_right we could add a test_dummy_flash_attn test with a small model that uses GQA! If you don't have time, I'm leaving this as a TODO, as the 70B model will need this.
I do not have time at the moment. Leave it as a TODO and I might pick up the task in the coming weeks.
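As a starting point for that TODO, here is a minimal sketch of such a test, not part of this PR. It assumes a CUDA GPU with flash-attn installed and a transformers version that accepts the attn_implementation argument; the tiny config values are made up purely so that num_key_value_heads differs from num_attention_heads:

```python
# Hypothetical sketch of the TODO test (illustrative, not from this PR).
# Assumes a CUDA GPU, flash-attn installed, and a transformers version that
# supports attn_implementation="flash_attention_2".
import torch
from transformers import AutoModelForCausalLM, LlamaConfig

config = LlamaConfig(
    vocab_size=100,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=2,  # GQA: KV head count differs from the query head count
)
model = AutoModelForCausalLM.from_config(
    config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to("cuda")
model.eval()

# A padded batch makes the attention mask non-trivial, which routes the
# FlashAttention path through _upad_input and exercises the fixed reshape.
input_ids = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 0]], device="cuda")
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]], device="cuda")

with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
assert logits.shape == (2, 4, config.vocab_size)
```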
What does this PR do?
The variable num_key_value_heads in the FlashAttention module was incorrectly named num_heads, which led to query_layer being reshaped with the wrong attention head count. (Using the correct variable self.num_heads instead of num_heads would have been enough, but I renamed num_heads to num_key_value_heads for clarity.)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker and @younesbelkada