
Add custom head_dim support to Llama #32502

Closed
wants to merge 12 commits

Conversation

@suhara (Contributor) commented Aug 7, 2024

What does this PR do?

Llama assumes that head_dim * num_heads == hidden_size and does not accommodate models with a custom head_dim. This PR relaxes that assumption so that Llama can use a custom head_dim.
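
For illustration, a minimal sketch (with made-up sizes; this is not the actual diff in this PR) of what the relaxed assumption means for the attention projections: their shapes are derived from num_heads * head_dim rather than from hidden_size, and o_proj maps back from num_heads * head_dim to hidden_size.

import torch.nn as nn

# Hypothetical sizes where head_dim != hidden_size // num_heads (3072 // 32 == 96).
hidden_size, num_heads, num_kv_heads, head_dim = 3072, 32, 8, 128

# Query/key/value projections are sized by num_heads * head_dim instead of hidden_size.
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
# The output projection maps num_heads * head_dim back to hidden_size.
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)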

This PR has a dependency on the following PR:

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Comment on lines 353 to 358
if config.head_dim is None:
    if (self.head_dim * self.num_heads) != self.hidden_size:
        raise ValueError(
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
            f" and `num_heads`: {self.num_heads})."
        )
@amyeroberts (Collaborator) commented:

Suggested change
-if config.head_dim is None:
-    if (self.head_dim * self.num_heads) != self.hidden_size:
-        raise ValueError(
-            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-            f" and `num_heads`: {self.num_heads})."
-        )
+if config.head_dim is None and (self.head_dim * self.num_heads) != self.hidden_size:
+    raise ValueError(
+        f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+        f" and `num_heads`: {self.num_heads})."
+    )

@suhara (Contributor Author) replied:

@amyeroberts Thanks for the suggestion! Updated the if block accordingly.

@amyeroberts (Collaborator) commented:

cc @ArthurZucker

@suhara suhara marked this pull request as ready for review August 7, 2024 22:09
@ArthurZucker (Collaborator) left a comment:

Hey! Not sure we need this (what I meant by a regression is that I thought we already allowed a custom head dim for Llama). Other models had this constraint lifted, like Gemma, I think.

@suhara (Contributor Author) commented Aug 8, 2024

Hi @ArthurZucker

The motivation is that some Llama-architecture models with custom head_dim sizes cannot be loaded by LlamaModel because of this constraint. (Some context: NVIDIA/NeMo#10078)

That said, I understand your concern about a regression. What would you suggest? If the existing Llama class is meant to support only the official Llama models, would creating a new class to cover custom Llama-based variants be an option?

@ArthurZucker (Collaborator) left a comment:

The suggestions should fix the CI; let's go with this.
Sorry for the delayed review, I was OOO for a bit.

@@ -187,6 +191,7 @@ def __init__(
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
+        self.head_dim = head_dim
Collaborator commented:

What you can do here is: `if head_dim is None: self.head_dim = self.hidden_size // self.num_heads`
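
A minimal sketch of that fallback pattern (illustrative only; TinyLlamaLikeConfig is a made-up name, not the actual LlamaConfig):

class TinyLlamaLikeConfig:
    """Toy stand-in for a Llama-style config, showing only the head_dim fallback."""

    def __init__(self, hidden_size=4096, num_attention_heads=32, head_dim=None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # Fall back to the classic hidden_size // num_heads when head_dim is not
        # given, so configs that never set the new field behave exactly as before.
        self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads

With that default, TinyLlamaLikeConfig().head_dim is 128 (4096 // 32), while TinyLlamaLikeConfig(head_dim=96) overrides it explicitly.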

@suhara (Contributor Author) replied:

Thanks for the suggestion! Added.

@ArthurZucker (Collaborator) commented:

Cool, can you just run `make fixup` and `make fix-copies`?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker mentioned this pull request Aug 16, 2024
@suhara (Contributor Author) commented Aug 16, 2024

@ArthurZucker It seems that you created another PR and fixed the remaining issues for this. Thank you!

@suhara (Contributor Author) commented Aug 17, 2024

Hi @ArthurZucker

I saw @Qubitium's message in #32857. The newly created PR is missing the fix for o_proj.
I fixed the CI issues and this PR should be ready to merge. Can you check?

Thanks!

@suhara suhara force-pushed the suhara/llama-kv-channels branch from e0af552 to 06cc89d Compare August 18, 2024 18:55
@suhara (Contributor Author) commented Aug 19, 2024

Hi @ArthurZucker

The CI passed. Can you merge this PR (or #32857 after fixing the issue)? Thanks!

@suhara (Contributor Author) commented Aug 20, 2024

#32857 has been merged. Closing this PR.

@suhara suhara closed this Aug 20, 2024
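
For anyone landing here later, a hedged usage sketch of what the change enables once it is in a released transformers version (sizes below are arbitrary; exact behaviour depends on the installed version):

from transformers import LlamaConfig, LlamaModel

# Assumes a transformers release that includes this change; older releases may
# not honour a custom head_dim. Here 3072 // 32 == 96, but we ask for
# head_dim=128. Small layer/vocab sizes keep the example lightweight.
config = LlamaConfig(
    hidden_size=3072,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    num_hidden_layers=2,
    intermediate_size=512,
    vocab_size=1024,
)
model = LlamaModel(config)
print(model.config.head_dim)  # 128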
helunwencser added a series of commits to pytorch/executorch that referenced this pull request (Nov 14 – Nov 25, 2024) with the following message:

Pull Request resolved: #6872

This is for resolving the ask in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/).

Similar change in HF: huggingface/transformers#32502

Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)

Co-authored-by: Lunwen He <lwhecser@gmail.com>