From 3994fa5bafa56db6581d962d562f3c54fac291df Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 21 Feb 2024 09:47:41 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20Llama:=20update=20rope=20scaling?= =?UTF-8?q?=20to=20match=20static=20cache=20changes=20(#29143)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../open_llama/modeling_open_llama.py | 4 +- .../models/falcon/modeling_falcon.py | 6 +- .../models/llama/modeling_llama.py | 59 ++++++++----------- .../models/persimmon/modeling_persimmon.py | 4 +- src/transformers/models/phi/modeling_phi.py | 4 +- .../models/stablelm/modeling_stablelm.py | 4 +- tests/models/llama/test_modeling_llama.py | 1 - 7 files changed, 38 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index d2ea931a44f..71c42447cd2 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -100,7 +100,7 @@ def forward(self, x, seq_len=None): ) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama +# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->OpenLlama class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): """OpenLlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" @@ -120,7 +120,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama +# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->OpenLlama class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 9767b797b00..7ef857748ca 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -167,7 +167,8 @@ def forward(self, x, seq_len=None): ) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon +# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon +# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied) class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): """FalconRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" @@ -187,7 +188,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon +# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon +# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied) class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9e2efe79d9b..5fb7e8459a2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -94,7 +94,6 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -118,6 +117,9 @@ def cos_cached(self): return self._cos_cached def forward(self, x, position_ids, seq_len=None): + if seq_len is not None: + logger.warning_once("The `seq_len` argument is deprecated and unused. 
It will be removed in v4.40.") + # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() @@ -138,16 +140,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + def forward(self, x, position_ids, seq_len=None): + # difference to the original RoPE: a scaling factor is applied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids, seq_len) + return cos, sin class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): @@ -157,23 +154,20 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - + def forward(self, x, position_ids, seq_len=None): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / 
(self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + cos, sin = super().forward(x, position_ids, seq_len) + return cos, sin def rotate_half(x): @@ -183,7 +177,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -191,9 +185,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -360,8 +353,8 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache @@ -447,8 +440,8 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) past_key_value = getattr(self, "past_key_value", past_key_value) @@ -645,8 +638,8 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, 
key_states, cos, sin) past_key_value = getattr(self, "past_key_value", past_key_value) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index f0de7ef2934..c83ba413952 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -77,7 +77,7 @@ def forward(self, x, seq_len=None): ) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon +# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Persimmon class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding): """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" @@ -97,7 +97,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon +# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Persimmon class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding): """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index b4d261d07f4..9704d4ccf52 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -120,7 +120,7 @@ def forward(self, x, seq_len=None): ) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi +# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding): """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" @@ -140,7 +140,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi +# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding): """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 9baaac1f513..00b02b1431a 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -103,7 +103,7 @@ def forward(self, x, seq_len=None): ) -# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm +# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding): """StableLmRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" @@ -123,7 +123,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm +# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding): """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 4efc5da5c40..a393950232f 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -362,7 +362,6 @@ def test_save_load_fast_init_from_base(self): pass @parameterized.expand([("linear",), ("dynamic",)]) - @unittest.skip("TODO @gante fix this for Llama") def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size)