Add YaRN and Dynamic-YaRN RoPE Scaling Methods #30910
Conversation
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:
- LLaMA
- Falcon
- GPT-NeoX
- Olmo
- Persimmon
- Phi
- StableLM
- OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs. For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
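For readers skimming the PR, here is a rough, self-contained sketch of the idea behind YaRN's NTK-by-parts interpolation and attention scaling. Function names, defaults, and helper structure are illustrative, loosely following the paper and its reference repository rather than the exact code added in this PR.

```python
import math

import torch


def yarn_inv_freq_and_mscale(
    dim=128,
    base=10000.0,
    scaling_factor=4.0,
    original_max_position_embeddings=2048,
    beta_fast=32.0,
    beta_slow=1.0,
):
    """Illustrative YaRN sketch: blend interpolated and extrapolated RoPE frequencies."""
    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs  # plain RoPE frequencies
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)  # position interpolation

    # Dimension index whose wavelength completes `num_rotations` turns over the original context window.
    def correction_dim(num_rotations):
        return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
            2 * math.log(base)
        )

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    if high == low:
        high += 0.001  # avoid a zero-width ramp

    # Ramp from 0 (keep original frequency, i.e. extrapolate) to 1 (fully interpolate).
    ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / (high - low), 0, 1)
    inv_freq = inv_freq_interpolation * ramp + inv_freq_extrapolation * (1 - ramp)

    # Attention scaling from the paper, applied to the cos/sin tables.
    mscale = 0.1 * math.log(scaling_factor) + 1.0 if scaling_factor > 1 else 1.0
    return inv_freq, mscale
```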
cc @ArthurZucker too

Hey! Thanks a lot for taking the time to implement this! 🤗 This is unrelated to Llama so we might need some modularity for this!

🤗
Hello Miguel and Miguel! 👋 (@miguelm-almeida @mig-mfreitas)
I have a couple of requests regarding user experience and recent changes in our repo. My number 1 suggestion would be to delete the diff in all models except `Llama`, and leave the model copies for another PR. It's much faster for everyone (you and me) to iterate over a model, and then copy the design when we're happy 🤗 In this particular case (RoPE models), we also have different implementations that we need to iron out on our end before adding Yarn there.

My review comments below are on `Llama`, and not on `Falcon`. Some of the suggested changes only work on architectures that are up to date, like `Llama` (and unlike `Falcon`).
Finally: one of the goals of this PR should be to be able to load the original YaRN models using `transformers`. Currently, there are some models on the Hub that have custom code (e.g. https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k). At the moment, these models require adding `trust_remote_code=True` to `from_pretrained` (which loads the custom code in the repo). With this PR, we remove the need for that flag and would be using the code in `transformers` instead :)
Ping me if you have further questions (and feel free to ping me by email if I'm taking too long to reply) 🤗
@@ -66,13 +66,31 @@ class OpenLlamaConfig(PretrainedConfig):
    rope_theta (`float`, *optional*, defaults to 10000.0):
        The base period of the RoPE embeddings.
    rope_scaling (`Dict`, *optional*):
        Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
        strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
        Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling
models in the `deprecated` folder should not be updated :) (let's remove the changes on `open_llama`)
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update | ||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how | ||
these scaling strategies behave: | ||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an | ||
experimental feature, subject to breaking API changes in future versions. | ||
yarn_rope_scaling (`Dict`, *optional*, defaults to `{'original_max_position_embeddings': 2048, 'extrapolation_factor': 1.0, 'attention_factor': 1.0, 'beta_fast': 32.0, 'beta_slow': 1.0, 'finetuned': False}`): |
Two comments:

1. The contents of `yarn_rope_scaling` should be part of `rope_scaling`. A single config dict for everything related to RoPE scaling is preferable so we can easily upgrade it to a standalone config class in a future PR :) It would also allow loading existing models by the original authors, e.g. https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k (have a look at their custom config code and the model's config file -- both assume all rope scaling params are in `rope_scaling`).
2. We have found through experience that the best default in config files is no default :) That way, we (huggingface):
   a) don't have to push changes to repositories on the Hub in case we find bugs
   b) can easily distinguish defaults from user-defined values that happen to be equal to the default

If the point in 1. is addressed, then no change is needed to the existing default (`None`). Defaults in the classes and the validation code are very helpful, though!
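For illustration, a single merged dict along the lines suggested here could look like the following. Field names are taken from the docstring and validation discussed in this PR; treat it as a sketch of the proposed schema, not the final one.

```python
from transformers import LlamaConfig

# All YaRN-related fields live in a single `rope_scaling` dict, so existing Hub
# checkpoints (e.g. NousResearch/Yarn-Llama-2-7b-64k) could be loaded as-is.
config = LlamaConfig(
    max_position_embeddings=8192,
    rope_scaling={
        "type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 2048,
        "attention_factor": 1.0,
        "beta_fast": 32.0,
        "beta_slow": 1.0,
    },
)
```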
@@ -201,3 +229,55 @@ def _rope_scaling_validation(self):
        )
    if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
        raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

# Copied from transformers.models.llama.configuration_llama.LlamaConfig._yarn_rope_scaling_validation
def _yarn_rope_scaling_validation(self):
Likewise, the contents of this function should be moved into `_rope_scaling_validation`, and the flags should only be checked if the rope scaling method is a yarn one and the flags exist in the dictionary.
@@ -162,12 +188,14 @@ def __init__(
    self.max_position_embeddings = max_position_embeddings
    self.rope_theta = rope_theta
    self.rope_scaling = rope_scaling
    self.yarn_rope_scaling = yarn_rope_scaling
self.yarn_rope_scaling = yarn_rope_scaling
(as per the comment above)
    extrapolation_factor (`float`, defaults to 1):
        Factor to adjust the n-dimensional rotational scaling for extrapolation.
I see this parameter (`extrapolation_factor`) is in the original implementation. However, if we dig further, we can see that it is not used in practice (unless I'm missing something -- feel free to correct me!):
- The default value of `1.0` does not change the computation
- There are no references to it in the yarn paper
- I couldn't find any Yarn model on the hub that has set this parameter in `config.json`, meaning the default `1` is always used
- All references in the original repo use the default value
- In an older PR, the author writes "extrapolation_factor and ntk_factor are used for validation purposes, and should not be changed unless it is necessary."

As such, I believe we can:
- delete this variable from the config
- delete all related code :)
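For context, in the reference implementation the parameter only rescales the ramp mask, so the default of 1.0 drops out of the computation entirely. A tiny sketch (variable names loosely follow the original repo, not this PR):

```python
import torch


def blend_inv_freq(interpolation, extrapolation, ramp, extrapolation_factor=1.0):
    # Mask selecting how much of the original (extrapolated) frequency to keep.
    inv_freq_mask = (1 - ramp) * extrapolation_factor
    return interpolation * (1 - inv_freq_mask) + extrapolation * inv_freq_mask


ramp = torch.linspace(0, 1, 8)
interp, extrap = torch.full((8,), 0.25), torch.ones(8)

# With the default factor of 1.0, the blend is exactly the plain ramp mix,
# i.e. the parameter has no effect on the result.
assert torch.allclose(
    blend_inv_freq(interp, extrap, ramp, extrapolation_factor=1.0),
    interp * ramp + extrap * (1 - ramp),
)
```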
            return 1.0
        return 0.1 * math.log(scaling_factor) + 1.0

    def forward(self, x, seq_len=None):
See the pattern we have in `LlamaRotaryEmbedding.forward` -- the pattern was changed a few minor versions ago from the one you have here, where `sin` and `cos` are cached, to a different one. The new pattern is faster and is compatible with `torch.compile`.

From a quick glance: I think you may be able to call `super().forward` and simply apply `* self.mscale` on the results.
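If that pans out, the YaRN class could stay very small -- something like the sketch below. The class name, the `mscale` attribute, and the exact parent `forward` signature are assumptions to verify against current `main`.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding


class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):
    """Sketch: reuse the parent forward and rescale cos/sin by YaRN's attention factor."""

    def __init__(self, *args, mscale: float = 1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.mscale = mscale  # attention scaling factor from YaRN

    @torch.no_grad()
    def forward(self, x, position_ids):
        cos, sin = super().forward(x, position_ids)
        # YaRN's attention scaling is equivalent to scaling the cos/sin tables.
        return cos * self.mscale, sin * self.mscale
```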
        Parameter to set the boundary for extrapolation (only) in the linear ramp function.
    beta_slow (`float`, *optional*, defaults to 1):
        Parameter to set the boundary for interpolation (only) in the linear ramp function.
    finetuned (`bool`, *optional*, defaults to `False`):
This can also be removed (see the comments on the new dynamic class)
            self._sin_cached[:seq_len, ...].to(dtype=x.dtype),
        )

    def yarn(self, device):
Suggested change: rename `def yarn(self, device):` to `def compute_yarn_scaling(self, device):`
(Or a similar name. Let's use descriptive function names :) )
            device,
        )

        if finetuned:
Suggested change: replace `if finetuned:` with `if self.max_position_embeddings != self.original_max_position_embeddings:`
This should be true for fine-tuned models, saving us a flag :)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_cos_long, original_cos_long)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_sin_long, original_sin_long)
let's also check that yarn_sin/cos_short != original_sin/cos_short (i.e. that applying yarn should change all values)
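i.e., something like the following, mirroring the long-sequence assertions quoted above (tensor names are assumed to follow the same pattern as the existing test):

```python
# YaRN rescales every position, so even short sequences should differ
# from the unscaled RoPE values.
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_sin_short, original_sin_short)
```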
Thank you very much for this in-depth review and suggestions! We'll iterate on it and reach back shortly 🤗

Iterate on YaRN implementation for LLaMA and remove diff from remaining models for increased PR modularity. This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned') from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt>

@gante Hi! We believe we've covered the mentioned topics in this iteration. Please tell us what you think! 🤗
In general LGTM (aside from two minor cleanup tasks) 👍 @mig-mfreitas @miguelm-almeida
Before I give the green light, I'd like you to run a sanity check: does the original implementation of YaRN match this implementation? This check doesn't fit as a test (it probably takes too long), but I'd like you to share a script with us that shows that the two implementations generate the same output text :)
After we confirm that they are matching, we can add a slow integration test to ensure we don't regress.
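A minimal comparison script could look roughly like this (checkpoint, prompt, and generation settings are illustrative; the first load uses the authors' custom code on the Hub, the second relies on the implementation in this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "NousResearch/Yarn-Llama-2-7b-64k"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("The YaRN method extends the context window by", return_tensors="pt")

# Reference: the authors' custom modeling code shipped with the checkpoint.
reference = AutoModelForCausalLM.from_pretrained(
    checkpoint, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto"
)
ref_out = reference.generate(**inputs.to(reference.device), max_new_tokens=64, do_sample=False)

# Candidate: the YaRN implementation from this PR (no remote code).
candidate = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, device_map="auto"
)
cand_out = candidate.generate(**inputs.to(candidate.device), max_new_tokens=64, do_sample=False)

print(tokenizer.decode(ref_out[0], skip_special_tokens=True))
print(tokenizer.decode(cand_out[0], skip_special_tokens=True))
assert tokenizer.decode(ref_out[0]) == tokenizer.decode(cand_out[0])
```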
        if self.attention_factor is None:
            self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0
Let's add a note saying that this is the default value according to the yarn paper, so it doesn't look like a magic number :)
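For instance, something along these lines (the 0.1 · ln(s) + 1 expression is the fit reported in the YaRN paper for the attention temperature):

```python
if self.attention_factor is None:
    # Default from the YaRN paper (https://arxiv.org/abs/2309.00071): the authors
    # recommend scaling attention by sqrt(1/t) = 0.1 * ln(s) + 1, where s is the
    # context-extension (scaling) factor.
    self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0
```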
        self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype())
        self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype())
The logic to build these tensors was removed in a recent PR (#30743) -- we can delete them as well as all related functions/variables (`emb`, `get_pos_embeddings`, ...).

This comment applies to the other class too :)
Chatting offline: yarn is matching the original repo, dynamic yarn is not. Check this gist.
Hey! Great work, I think here we have a philosophical choice:

We should start thinking about switching to sin and cos being computed once, at the beginning of the modeling, and that is it. RoPE scaling is mostly used during generation, thus `generate` should be able to handle any rope scaling + we don't re-compute it every forward layer.

If you want to work on that, it would be truly nice, and aligned with the FA2 refactor / our idea to add kwargs to the forward.
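As a very rough illustration of that direction (module, argument, and attribute names here are hypothetical, not what transformers ships):

```python
import torch.nn as nn


class DecoderStack(nn.Module):
    """Sketch: compute the RoPE cos/sin tables once per forward pass and reuse them."""

    def __init__(self, rotary_emb, layers):
        super().__init__()
        self.rotary_emb = rotary_emb  # e.g. a YaRN-aware rotary embedding module
        self.layers = nn.ModuleList(layers)

    def forward(self, hidden_states, position_ids, **kwargs):
        # Computed once here, instead of once per attention layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for layer in self.layers:
            hidden_states = layer(
                hidden_states, position_embeddings=position_embeddings, **kwargs
            )
        return hidden_states
```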
        if rope_scaling_type not in ["yarn", "dynamic-yarn"]:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
            raise ValueError(
                "`rope_scaling` with type "
                f"{rope_scaling_type}"
                " must be a dictionary with a maximum of six fields, `type`, `factor`, "
                "`original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, "
                f"got {self.rope_scaling}"
            )
        original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
        attention_factor = self.rope_scaling.get("attention_factor", None)
        beta_fast = self.rope_scaling.get("beta_fast", None)
        beta_slow = self.rope_scaling.get("beta_slow", None)

        if original_max_position_embeddings is not None and not isinstance(original_max_position_embeddings, int):
            raise ValueError(
                f"`rope_scaling`'s original_max_position_embeddings field must be an int, got {original_max_position_embeddings}"
            )
        if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
            raise ValueError(
                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
            )
        if beta_fast is not None and not isinstance(beta_fast, float):
            raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
        if beta_slow is not None and not isinstance(beta_slow, float):
            raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

        b_fast = beta_fast if beta_fast is not None else 32
        b_slow = beta_slow if beta_slow is not None else 1
        if b_fast < b_slow:
            raise ValueError(
                f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={b_fast} and beta_slow={b_slow}"
            )
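For reference, under this validation a YaRN config would pass or fail roughly as follows (illustrative values, assuming `LlamaConfig` as modified by this PR accepts the `yarn` type):

```python
from transformers import LlamaConfig

# Accepted: all YaRN-specific fields are optional besides `type` and `factor`.
LlamaConfig(rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 2048})

# Rejected: beta_fast must not be smaller than beta_slow.
try:
    LlamaConfig(rope_scaling={"type": "yarn", "factor": 4.0, "beta_fast": 1.0, "beta_slow": 32.0})
except ValueError as err:
    print(err)
```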
I don't see a reason to have this in the LlamaConfig @gante, I think we need a better way to support rope customisation, maybe, but not by changing our current modeling code!
Agreed (but let's do it in a separate PR, to avoid adding library reorganization tasks on top of a 2 month old contributor PR :P )
Yes, but these RoPE variants can't be added to Llama as is 😓
It's pretty much against our philosophy
We should not have to add code here!
- Comply with the tensor building logic introduced in huggingface#30743
- Add referencing to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
Chatted with @ArthurZucker offline: let's prepare to merge this PR as is, after removing Dynamic YaRN.

In parallel, I will be working on a PR on top of this one to move RoPE scaling code out of the models themselves, into a dedicated place. RoPE scaling would then behave as a plug-in to models with RoPE.

When the two PRs are ready: this PR will get merged and, straight away, mine is applied on top of it. We'll then have YaRN from @mig-mfreitas and the new structure from my PR :)
Done! 🤗🚀

@mig-mfreitas working on the new PR on top of this one today 🤗

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

😊😊😊

Thank you very much for all the support!

Thank you for all the guidance, this contribution has been a great experience 🤗🚀
What does this PR do?
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes.
Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments.
We implement YaRN and Dynamic-YaRN for the following list of models:
- LLaMA
- Falcon
- GPT-NeoX
- Olmo
- Persimmon
- Phi
- StableLM
- OpenLLaMA
New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs.
For more details, please refer to https://arxiv.org/abs/2309.00071.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@gante