From 6fd93fe93abeb5c66feaead859b57ff6356dc442 Mon Sep 17 00:00:00 2001
From: Jacky Lee <39754370+jla524@users.noreply.github.com>
Date: Sat, 30 Mar 2024 08:30:52 -0700
Subject: [PATCH] Fix rope theta for OpenLlama (#29893)

fix: rope_theta for open llama
---
 .../models/deprecated/open_llama/configuration_open_llama.py | 4 ++++
 .../models/deprecated/open_llama/modeling_open_llama.py      | 1 +
 2 files changed, 5 insertions(+)

diff --git a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py
index 6b56c352bf78..0111e031251a 100644
--- a/src/transformers/models/deprecated/open_llama/configuration_open_llama.py
+++ b/src/transformers/models/deprecated/open_llama/configuration_open_llama.py
@@ -66,6 +66,8 @@ class OpenLlamaConfig(PretrainedConfig):
             relevant if `config.is_decoder=True`.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
             strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
@@ -113,6 +115,7 @@ def __init__(
         attention_dropout_prob=0.1,
         use_stable_embedding=True,
         shared_input_output_embedding=True,
+        rope_theta=10000.0,
         rope_scaling=None,
         **kwargs,
     ):
@@ -133,6 +136,7 @@ def __init__(
         self.attention_dropout_prob = attention_dropout_prob
         self.use_stable_embedding = use_stable_embedding
         self.shared_input_output_embedding = shared_input_output_embedding
+        self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()
 
diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
index 71c42447cd2b..098f8c7da50d 100644
--- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
+++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -214,6 +214,7 @@ def __init__(self, config: OpenLlamaConfig):
         self.head_dim = self.hidden_size // self.num_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.dropout_prob = config.attention_dropout_prob
+        self.rope_theta = config.rope_theta
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
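
For context, a minimal usage sketch (not part of the patch) of what this change enables, assuming a transformers release that still ships the deprecated Open-Llama model and exports `OpenLlamaConfig` at the top level:

    from transformers import OpenLlamaConfig

    # After this patch, rope_theta is an explicit config field (default 10000.0)
    # and is forwarded to the attention layers instead of being hard-coded there.
    config = OpenLlamaConfig(rope_theta=500000.0)
    print(config.rope_theta)  # 500000.0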