Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚨 Llama: update rope scaling to match static cache changes #29143

Merged
merged 2 commits into from
Feb 21, 2024

Conversation

gante
Copy link
Member

@gante gante commented Feb 20, 2024

What does this PR do?

(see title :))

What's breaking? The shape of the returned sin/cos caches are changed (sin/cos for all positions -> sin/cos for the positions in position_ids). Note that this breaking change was also present in the static cache PR, for the main RoPE class (#27931).

Review suggestion:

  1. Review changes in Llama
  2. Review the rest

@@ -362,7 +362,6 @@ def test_save_load_fast_init_from_base(self):
pass

@parameterized.expand([("linear",), ("dynamic",)])
@unittest.skip("TODO @gante fix this for Llama")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was fixed as a result of the changes in this PR :)

@gante gante requested a review from ArthurZucker February 20, 2024 14:29
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧼 nice cleanup!
Main concern: BC, let's keep the cos_cache and sin_cache for 1 release and then we can directly open a PR on main to remove it!

Comment on lines -149 to +143
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, position_ids, seq_len=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am alright with this but it is breaking for any libs that rely on sin cached and cos cached. Same for the static cache PR!
Let's just add a mention that it will be removed next release and still compute cos and sin!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the cool part -- it calls super's forward, which in turn caches sin/cos (see here). BC is preserved 🙌

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but we need a warning to deprecate !
Follow up is fine

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow -- the warning is here. Or were you thinking of some other warning?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect! Had not seen this when I checked the diff

emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
cos, sin = super().forward(x, position_ids, seq_len)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot cleaner!

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@younesbelkada also pointed out that the shape of the output of the rope layer is different from before. Thus this is a bit breaking. If so, let's add a big 🔴 on the PR to make sure we know that there are breaking changes!

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests all pass on PEFT end ! Thanks for the notice 💪

@gante gante changed the title Llama: update rope scaling to match static cache changes 🚨 Llama: update rope scaling to match static cache changes Feb 21, 2024
@gante gante merged commit 3994fa5 into huggingface:main Feb 21, 2024
19 checks passed
@gante gante deleted the update_rope_scaling branch February 21, 2024 09:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants