Refactor Tensor Building Logic for YaRN
- Comply with the tensor building logic introduced in huggingface#30743
- Add a reference to the optimized Attention Factor equation (see the equation below)
- Remove Dynamic YaRN for a more agile deployment
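
For context, the "optimized Attention Factor equation" referenced above appears to be the temperature recommendation from the YaRN paper (https://arxiv.org/pdf/2309.00071, Eq. 22):

$$\sqrt{\tfrac{1}{t}} = 0.1 \ln(s) + 1$$

where $s$ is the RoPE scaling factor and $\sqrt{1/t}$ scales the attention logits; the modeling change below applies it as `attention_factor = 0.1 * math.log(scaling_factor) + 1.0` whenever no explicit `attention_factor` is supplied.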

Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
miguelm-almeida and mig-mfreitas committed Jul 10, 2024
1 parent 85552b3 commit d84baa9
Showing 5 changed files with 42 additions and 100 deletions.
32 changes: 32 additions & 0 deletions 2.14.0
@@ -0,0 +1,32 @@
Requirement already satisfied: datasets in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (2.10.1)
Requirement already satisfied: numpy>=1.17 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (1.26.4)
Requirement already satisfied: pyarrow>=6.0.0 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (6.0.1)
Requirement already satisfied: dill<0.3.7,>=0.3.0 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (0.3.4)
Requirement already satisfied: pandas in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (2.2.1)
Requirement already satisfied: requests>=2.19.0 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (2.31.0)
Requirement already satisfied: tqdm>=4.62.1 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (4.66.2)
Requirement already satisfied: xxhash in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (3.4.1)
Requirement already satisfied: multiprocess in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (0.70.12.2)
Requirement already satisfied: fsspec>=2021.11.1 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from fsspec[http]>=2021.11.1->datasets) (2024.2.0)
Requirement already satisfied: aiohttp in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (3.9.3)
Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (0.21.2)
Requirement already satisfied: packaging in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (23.2)
Requirement already satisfied: responses<0.19 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (0.18.0)
Requirement already satisfied: pyyaml>=5.1 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from datasets) (6.0.1)
Requirement already satisfied: aiosignal>=1.1.2 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from aiohttp->datasets) (1.3.1)
Requirement already satisfied: attrs>=17.3.0 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from aiohttp->datasets) (23.2.0)
Requirement already satisfied: frozenlist>=1.1.1 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from aiohttp->datasets) (1.4.1)
Requirement already satisfied: multidict<7.0,>=4.5 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from aiohttp->datasets) (6.0.5)
Requirement already satisfied: yarl<2.0,>=1.0 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from aiohttp->datasets) (1.9.4)
Requirement already satisfied: async-timeout<5.0,>=4.0 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from aiohttp->datasets) (4.0.3)
Requirement already satisfied: filelock in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.13.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.10.0)
Requirement already satisfied: charset-normalizer<4,>=2 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from requests>=2.19.0->datasets) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from requests>=2.19.0->datasets) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from requests>=2.19.0->datasets) (1.26.18)
Requirement already satisfied: certifi>=2017.4.17 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from requests>=2.19.0->datasets) (2024.2.2)
Requirement already satisfied: colorama in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from tqdm>=4.62.1->datasets) (0.4.6)
Requirement already satisfied: python-dateutil>=2.8.2 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from pandas->datasets) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from pandas->datasets) (2024.1)
Requirement already satisfied: tzdata>=2022.7 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from pandas->datasets) (2024.1)
Requirement already satisfied: six>=1.5 in c:\users\migue\anaconda3\envs\transformers\lib\site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)
12 changes: 6 additions & 6 deletions src/transformers/models/llama/configuration_llama.py
@@ -84,14 +84,14 @@ class LlamaConfig(PretrainedConfig):
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports four scaling
strategies: linear, dynamic, yarn and dynamic-yarn. Their scaling factor must be a float greater than 1. The expected format is
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
strategies: linear, dynamic and yarn. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
For `yarn` and `dynamic-yarn` strategies, the dictionary may also contain the following fields:
For the `yarn` strategy, the dictionary may also contain the following fields:
`original_max_position_embeddings` (`int`, *optional*):
The original maximum sequence length. This is used to scale the RoPE embeddings.
`attention_factor` (`float`, *optional*):
@@ -194,14 +194,14 @@ def _rope_scaling_validation(self):
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn", "dynamic-yarn"]:
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic", "yarn"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn', 'dynamic-yarn'], got {rope_scaling_type}"
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic', 'yarn'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

if rope_scaling_type not in ["yarn", "dynamic-yarn"]:
if rope_scaling_type != "yarn":
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) > 6:
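As a usage illustration (not part of the commit), a configuration that should pass the validation above might look like the following sketch; the concrete values are assumptions chosen only to satisfy the documented constraints:

```python
# Minimal sketch: `type` must be one of linear/dynamic/yarn and `factor` a float > 1.0;
# `original_max_position_embeddings` is one of the optional YaRN-only fields.
from transformers import LlamaConfig

config = LlamaConfig(
    rope_scaling={
        "type": "yarn",
        "factor": 2.0,                             # float strictly greater than 1.0
        "original_max_position_embeddings": 4096,  # optional, YaRN-specific
    },
)
```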
75 changes: 2 additions & 73 deletions src/transformers/models/llama/modeling_llama.py
@@ -170,24 +170,12 @@ def __init__(
self.beta_slow = beta_slow

if self.attention_factor is None:
# Recommended attention factor for LLaMA models.
# For more details please refer to https://arxiv.org/pdf/2309.00071, Eq. 22.
self.attention_factor = 0.1 * math.log(scaling_factor) + 1.0

self.compute_yarn_scaling(device)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
emb = self.get_pos_embeddings(device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype())
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype())

# Get positional embeddings based on the current max sequence length
def get_pos_embeddings(self, device):
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb

# Inverse dimension formula to find the dimension based on the number of rotations
def find_correction_dim(self, num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
@@ -232,57 +220,6 @@ def compute_yarn_scaling(self, device):
self.mscale = self.attention_factor


class LlamaDynamicYarnScalingRotaryEmbedding(LlamaYarnScalingRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1,
original_max_position_embeddings=2048,
attention_factor=None,
beta_fast=32,
beta_slow=1,
device=None,
):
super().__init__(
dim,
max_position_embeddings,
base,
scaling_factor,
original_max_position_embeddings,
attention_factor,
beta_fast,
beta_slow,
device,
)

if self.max_position_embeddings != self.original_max_position_embeddings:
self.scaling_factor = self.max_position_embeddings / self.original_max_position_embeddings
self.compute_yarn_scaling(device)
else:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.mscale = 1

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
emb = self.get_pos_embeddings(device)

self._cos_cached = (emb.cos() * self.mscale)[None, :, :].to(torch.get_default_dtype())
self._sin_cached = (emb.sin() * self.mscale)[None, :, :].to(torch.get_default_dtype())

def forward(self, x, position_ids=None):
# Difference to the standard YaRN: the scaling factor is updated when the max sequence length is exceeded
# x: [bs, num_attention_heads, seq_len, head_size]
seq_len = torch.max(position_ids) + 1
self.scaling_factor = seq_len / self.original_max_position_embeddings
self.compute_yarn_scaling(x.device)

cos, sin = super().forward(x, position_ids)
return cos, sin


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -440,14 +377,6 @@ def _init_rope(self):
base=self.rope_theta,
**kwargs,
)
elif scaling_type == "dynamic-yarn":
self.rotary_emb = LlamaDynamicYarnScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

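To make the YaRN ramp logic in this file concrete, here is a self-contained sketch built around the `find_correction_dim` formula shown above; `find_correction_range` is an assumption about how such a helper is typically paired with it in YaRN implementations, not code taken from the diff:

```python
import math

def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # Inverse dimension formula: the (fractional) RoPE dimension whose frequency
    # completes `num_rotations` full rotations over the original context length.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

def find_correction_range(beta_fast, beta_slow, dim, base=10000, max_position_embeddings=2048):
    # Dimensions below `low` keep their original (extrapolated) frequencies, dimensions
    # above `high` are fully interpolated, and a linear ramp blends the two in between.
    low = math.floor(find_correction_dim(beta_fast, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(beta_slow, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)

# Example with the defaults visible in the diff (beta_fast=32, beta_slow=1):
low, high = find_correction_range(32, 1, dim=128, max_position_embeddings=4096)
```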
21 changes: 1 addition & 20 deletions tests/models/llama/test_modeling_llama.py
@@ -53,7 +53,6 @@
)
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaDynamicYarnScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
LlamaYarnScalingRotaryEmbedding,
@@ -399,7 +398,7 @@ def test_llama_token_classification_model(self):
def test_save_load_fast_init_from_base(self):
pass

@parameterized.expand([("linear",), ("dynamic",), ("yarn",), ("dynamic-yarn",)])
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
@@ -513,24 +512,6 @@ def test_model_rope_scaling(self):
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)

# Sanity check Dynamic Yarn RoPE scaling
dynamic_yarn_scaling_rope = LlamaDynamicYarnScalingRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
).to(torch_device)
dynamic_yarn_cos_short, dynamic_yarn_sin_short = dynamic_yarn_scaling_rope(x, position_ids_short)
dynamic_yarn_cos_long, dynamic_yarn_sin_long = dynamic_yarn_scaling_rope(x, position_ids_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(dynamic_yarn_sin_long, original_sin_long)

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
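For readers unfamiliar with the assertion pattern in the surviving YaRN check above, the idea is that the scaled sin/cos tables must differ from the unscaled RoPE baseline, so `torch.testing.assert_close` is expected to raise. A small illustrative helper (an assumption, not part of the test suite) might read:

```python
import torch

def differs_from_baseline(scaled: torch.Tensor, baseline: torch.Tensor) -> bool:
    # An AssertionError here is the *desired* outcome: scaling changed the embeddings.
    try:
        torch.testing.assert_close(scaled, baseline)
    except AssertionError:
        return True
    return False
```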
2 changes: 1 addition & 1 deletion tests/models/stablelm/test_modeling_stablelm.py
@@ -52,7 +52,7 @@
)


# Copied from transformers.tests.models.StableLm.test_modeling_StableLm.StableLmModelTester with StableLm -> StableLm
# Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm
class StableLmModelTester:
# Ignore copy
def __init__(
