support fixed ntk rope in modeling_rope_utils.py #32424
Closed

Changes from 2 commits
@@ -158,6 +158,56 @@ def _compute_dynamic_ntk_parameters(
    return inv_freq, attention_factor


def _compute_ntk_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length, used to update the dynamic RoPE at inference time.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
        max_position_embeddings = rope_kwargs["max_position_embeddings"]
        factor = rope_kwargs["factor"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
        max_position_embeddings = config.max_position_embeddings
        factor = config.rope_scaling["factor"]

    attention_factor = 1.0  # Unused in this type of RoPE

    # seq_len: default to max_position_embeddings, e.g. at init time
    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
Inline review comment on the `seq_len` lines above — suggested change: "(unused, as it is not dynamic)", i.e. `seq_len` is not needed by this fixed-NTK variant, since it does not rescale at inference time.
    # Compute the inverse frequencies
    base = base * factor ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, attention_factor


def _compute_yarn_parameters(
    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
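For reference, the fixed-NTK computation added above reduces to rescaling the RoPE base once before building the inverse frequencies. A minimal standalone sketch, using illustrative values for `base`, `dim`, and `factor` (not taken from the PR or any particular model):

import torch

# Illustrative values: a 128-dim rotary head with a 2x NTK factor.
base, dim, factor = 10000.0, 128, 2.0

# Same formula as in `_compute_ntk_parameters` above: scale the base up front, then
# build the usual RoPE inverse frequencies from the scaled base.
scaled_base = base * factor ** (dim / (dim - 2))
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.shape)  # torch.Size([64])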
@@ -350,6 +400,7 @@ def _compute_llama3_parameters(
    "default": _compute_default_rope_parameters,
    "linear": _compute_linear_scaling_rope_parameters,
    "dynamic": _compute_dynamic_ntk_parameters,
    "ntk": _compute_ntk_parameters,
    "yarn": _compute_yarn_parameters,
    "longrope": _compute_longrope_parameters,
    "llama3": _compute_llama3_parameters,
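With the registration above, a model config can opt into the new scaling type through `rope_scaling`. A hedged sketch, assuming this patch is applied (the "ntk" key does not exist upstream) and using Llama only as an illustrative host model:

from transformers import LlamaConfig

# Hypothetical config: "ntk" is only recognized once this PR's registration is in place.
config = LlamaConfig(
    max_position_embeddings=4096,
    rope_scaling={"rope_type": "ntk", "factor": 2.0},
)
# At rotary-embedding init, the "ntk" entry then resolves to `_compute_ntk_parameters`.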
@@ -409,6 +460,19 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_ntk_scaling_rope_parameters(config: PretrainedConfig):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_yarn_parameters(config: PretrainedConfig):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
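The rule the new validator enforces is small enough to restate on its own. A standalone sketch of the same checks, with an illustrative helper name (`check_ntk_rope_scaling` is not part of the library) and `print` standing in for the library's logger:

def check_ntk_rope_scaling(rope_scaling: dict) -> None:
    # Only these keys are expected for the fixed-NTK type.
    required_keys = {"rope_type", "factor"}
    unexpected = set(rope_scaling) - required_keys
    if unexpected:
        print(f"Unrecognized keys in `rope_scaling`: {unexpected}")

    # `factor` must be a float >= 1.0, mirroring the warning in the diff above.
    factor = rope_scaling.get("factor")
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        print(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

check_ntk_rope_scaling({"rope_type": "ntk", "factor": 2.0})  # silent: valid
check_ntk_rope_scaling({"rope_type": "ntk", "factor": 0.5})  # prints the factor warning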
@@ -529,6 +593,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
    "default": _validate_default_rope_parameters,
    "linear": _validate_linear_scaling_rope_parameters,
    "dynamic": _validate_dynamic_scaling_rope_parameters,
    "ntk": _validate_ntk_scaling_rope_parameters,
    "yarn": _validate_yarn_parameters,
    "longrope": _validate_longrope_parameters,
    "llama3": _validate_llama3_parameters,
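Taken together, the two registrations keep validation and initialization in sync: the same `rope_type` key selects both functions. A hedged sketch of the dispatch, assuming the patch is applied, reusing the `config` from the earlier example, and assuming the registry names `ROPE_INIT_FUNCTIONS` / `ROPE_VALIDATION_FUNCTIONS` exposed by modeling_rope_utils.py:

from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, ROPE_VALIDATION_FUNCTIONS

# "rope_type" with a BC fallback to the older "type" key, as in the validators above.
rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))  # "ntk"
ROPE_VALIDATION_FUNCTIONS[rope_type](config)  # warns if `factor` is missing or < 1.0
inv_freq, attention_factor = ROPE_INIT_FUNCTIONS[rope_type](config, device="cpu")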