support fixed ntk rope in modeling_rope_utils.py #32424

Closed
wants to merge 4 commits
Changes from 2 commits
65 changes: 65 additions & 0 deletions src/transformers/modeling_rope_utils.py
@@ -158,6 +158,56 @@ def _compute_dynamic_ntk_parameters(
return inv_freq, attention_factor


def _compute_ntk_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length, used to update the dynamic RoPE at inference time.
Review comment (Member). Suggested change:
-            The current sequence length, used to update the dynamic RoPE at inference time.
+            The current sequence length. Unused for this type of RoPE.

rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
max_position_embeddings = rope_kwargs["max_position_embeddings"]
factor = rope_kwargs["factor"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]

attention_factor = 1.0 # Unused in this type of RoPE

# seq_len: default to max_position_embeddings, e.g. at init time
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
Review comment from @gante (Member), Aug 7, 2024. Suggested change: remove these lines, since seq_len is unused here (this RoPE type is not dynamic):
-    # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

# Compute the inverse frequencies
base = base * factor ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor
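
Illustrative note (not part of the diff): a minimal standalone sketch of the fixed-NTK base adjustment applied above, with example numbers chosen purely for illustration.

import torch

dim, base, factor = 128, 10000.0, 10.0  # example values, not tied to any particular model
scaled_base = base * factor ** (dim / (dim - 2))  # same adjustment as above: 10000 * 10 ** (128 / 126), roughly 1.04e5
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
# A larger base lowers every rotation frequency, stretching the usable context once at init time;
# unlike the "dynamic" variant, the result does not change with seq_len.
print(inv_freq.shape)  # torch.Size([64])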


def _compute_yarn_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
@@ -350,6 +400,7 @@ def _compute_llama3_parameters(
"default": _compute_default_rope_parameters,
"linear": _compute_linear_scaling_rope_parameters,
"dynamic": _compute_dynamic_ntk_parameters,
"ntk": _compute_ntk_parameters,
"yarn": _compute_yarn_parameters,
"longrope": _compute_longrope_parameters,
"llama3": _compute_llama3_parameters,
@@ -409,6 +460,19 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_ntk_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)

factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
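
Illustrative note: a small sketch of what this validator accepts versus what it only warns about, assuming the private helper is importable from this branch.

from transformers import LlamaConfig
from transformers.modeling_rope_utils import _validate_ntk_scaling_rope_parameters  # private helper added in this PR

config = LlamaConfig()
config.rope_scaling = {"rope_type": "ntk", "factor": 2.0}  # float >= 1.0: passes silently
_validate_ntk_scaling_rope_parameters(config)

config.rope_scaling = {"rope_type": "ntk", "factor": 0.5}  # below 1.0: logs a warning, does not raise
_validate_ntk_scaling_rope_parameters(config)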


def _validate_yarn_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
@@ -529,6 +593,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
"default": _validate_default_rope_parameters,
"linear": _validate_linear_scaling_rope_parameters,
"dynamic": _validate_dynamic_scaling_rope_parameters,
"ntk": _validate_ntk_scaling_rope_parameters,
"yarn": _validate_yarn_parameters,
"longrope": _validate_longrope_parameters,
"llama3": _validate_llama3_parameters,
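For context (not part of the diff), a hedged usage sketch of how the new "ntk" entry registered in ROPE_INIT_FUNCTIONS above would be selected from a model config, assuming this PR's branch is installed; the exact wiring inside the modeling code may differ.

from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = LlamaConfig()
config.rope_scaling = {"rope_type": "ntk", "factor": 10.0}  # fixed NTK with a 10x context stretch
rope_fn = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]]
inv_freq, attention_factor = rope_fn(config=config, device="cpu")  # attention_factor is 1.0 for this RoPE type
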
20 changes: 19 additions & 1 deletion tests/utils/test_modeling_rope_utils.py
@@ -46,7 +46,7 @@ def test_rope_validation(self):

# Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
        valid_param_mapping = {
-            "factor": ["linear", "dynamic", "yarn", "longrope"],
+            "factor": ["linear", "ntk", "dynamic", "yarn", "longrope"],
"attention_factor": ["yarn", "longrope"],
"beta_fast": ["yarn"],
"beta_slow": ["yarn"],
@@ -117,6 +117,24 @@ def test_dynamic_rope_function_bc(self):
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)

def test_ntk_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "ntk", "factor": 10.0}
device = torch_device

        rope_kwargs = {
            "rope_type": "ntk",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}

rope_fn = ROPE_INIT_FUNCTIONS["ntk"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)

def test_default_rope_numerically(self):
# Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
# multiple RoPE strategies will fail.