support fixed ntk rope in modeling_rope_utils.py #32424

Closed · wants to merge 4 commits
57 changes: 57 additions & 0 deletions src/transformers/modeling_rope_utils.py
@@ -158,6 +158,48 @@ def _compute_dynamic_ntk_parameters(
return inv_freq, attention_factor


def _compute_ntk_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with fixed NTK scaling, i.e. the RoPE base is rescaled once by `factor ** (dim / (dim - 2))`, independently of the current sequence length. Credits to the Reddit users /u/bloc97 and /u/emozilla
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
factor = rope_kwargs["factor"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
factor = config.rope_scaling["factor"]

attention_factor = 1.0 # Unused in this type of RoPE

# Compute the inverse frequencies
base = base * factor ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor


def _compute_yarn_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
@@ -350,6 +392,7 @@ def _compute_llama3_parameters(
"default": _compute_default_rope_parameters,
"linear": _compute_linear_scaling_rope_parameters,
"dynamic": _compute_dynamic_ntk_parameters,
"ntk": _compute_ntk_parameters,
"yarn": _compute_yarn_parameters,
"longrope": _compute_longrope_parameters,
"llama3": _compute_llama3_parameters,
@@ -409,6 +452,19 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_ntk_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)

factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_yarn_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
@@ -529,6 +585,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
"default": _validate_default_rope_parameters,
"linear": _validate_linear_scaling_rope_parameters,
"dynamic": _validate_dynamic_scaling_rope_parameters,
"ntk": _validate_ntk_scaling_rope_parameters,
"yarn": _validate_yarn_parameters,
"longrope": _validate_longrope_parameters,
"llama3": _validate_llama3_parameters,
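The new strategy is the "fixed" NTK-aware interpolation: rather than dividing every position by `factor` (the "linear" type) or rescaling the base as the sequence grows (the "dynamic" type), it rescales the base once by `factor ** (dim / (dim - 2))`. Below is a minimal standalone sketch of the resulting frequencies, assuming only the formula shown in the diff; the helper name `fixed_ntk_inv_freq` is ours, not part of the PR.

import torch

def fixed_ntk_inv_freq(base: float, dim: int, factor: float) -> torch.Tensor:
    # Same formula as the new function above: rescale the base once, independently of seq_len
    scaled_base = base * factor ** (dim / (dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

base, dim, factor = 10000.0, 128, 8.0                  # Llama-style defaults: rope_theta=10000, head_dim=128
default_inv_freq = fixed_ntk_inv_freq(base, dim, 1.0)  # factor=1.0 reduces to plain RoPE
ntk_inv_freq = fixed_ntk_inv_freq(base, dim, factor)

# The interpolation ratio is frequency-dependent, factor ** (arange(0, dim, 2) / (dim - 2)),
# unlike "linear" scaling, which divides every frequency by the same constant `factor`.
ratio = default_inv_freq / ntk_inv_freq
print(ratio[0].item())    # 1.0   -> the highest RoPE frequency is left untouched
print(ratio[-1].item())   # ~8.0  -> the lowest frequency is stretched by exactly `factor`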
44 changes: 43 additions & 1 deletion tests/utils/test_modeling_rope_utils.py
@@ -46,7 +46,7 @@ def test_rope_validation(self):

# Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
valid_param_mapping = {
"factor": ["linear", "dynamic", "yarn", "longrope"],
"factor": ["linear", "ntk", "dynamic", "yarn", "longrope"],
"attention_factor": ["yarn", "longrope"],
"beta_fast": ["yarn"],
"beta_slow": ["yarn"],
@@ -117,6 +117,24 @@ def test_dynamic_rope_function_bc(self):
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)

def test_ntk_rope_function_bc(self):
config = LlamaConfig()
config.rope_scaling = {"rope_type": "ntk", "factor": 10.0}
device = torch_device

rope_kwargs = {
"rope_type": "dynamic",
"dim": config.hidden_size // config.num_attention_heads,
"max_position_embeddings": config.max_position_embeddings,
"base": config.rope_theta,
"factor": 10.0,
}

rope_fn = ROPE_INIT_FUNCTIONS["ntk"]
config_freqs = rope_fn(config=config, device=device)[0]
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)

def test_default_rope_numerically(self):
# Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
# multiple RoPE strategies will fail.
@@ -217,6 +235,30 @@ def test_dynamic_rope_numerically(self):
torch.testing.assert_close(inv_freq, default_inv_freq / factor)
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)

def test_ntk_rope_numerically(self):
# Fixed NTK scaling rescales the RoPE base by factor ** (dim / (dim - 2)), so the resulting
# inverse frequencies are not a constant multiple of the default ones.

# input sanity checks: if these change, the output will also change
config = LlamaConfig()
self.assertEqual(config.rope_scaling, None)
self.assertEqual(config.hidden_size, 4096)
self.assertEqual(config.num_attention_heads, 32)
self.assertEqual(config.rope_theta, 10000.0)
self.assertFalse(hasattr(config, "partial_rotary_factor"))

default_rope_fn = ROPE_INIT_FUNCTIONS["default"]
default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)

rope_fn = ROPE_INIT_FUNCTIONS["ntk"]

dim = int(config.hidden_size // config.num_attention_heads)
for factor in (2.0, 10.0, 20.0):
config.rope_scaling = {"rope_type": "ntk", "factor": factor}
inv_freq, _ = rope_fn(config=config, device=torch_device)
with self.assertRaises(AssertionError): # It is NOT a linear factor
torch.testing.assert_close(inv_freq, default_inv_freq / (factor ** (dim / (dim - 2))))

def test_yarn_rope_numerically(self):
# fmt: off
EXPECTED_INV_FREQ = torch.tensor(
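End-to-end, the new branch is selected through `config.rope_scaling["rope_type"] == "ntk"`. A hypothetical usage sketch follows, assuming a transformers build that contains this PR's changes; the PR is marked closed, so the "ntk" rope type may not be available in released versions.

from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# Assumes a checkout that includes this PR; "ntk" is not a rope_type in released transformers.
config = LlamaConfig()
config.rope_scaling = {"rope_type": "ntk", "factor": 8.0}

inv_freq, attention_factor = ROPE_INIT_FUNCTIONS["ntk"](config=config, device="cpu")
print(inv_freq.shape)       # torch.Size([64]) -- one frequency per rotated pair of the 128-dim head
print(attention_factor)     # 1.0, unused for this rope type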