From 7d946316517e2def41b08c87db240e5c2836076c Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Wed, 7 Aug 2024 06:40:50 -0700 Subject: [PATCH 1/3] use head_dim if in config for RoPE --- src/transformers/modeling_rope_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 839adaecd0ca..f047a5f4d893 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -58,7 +58,8 @@ def _compute_default_rope_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -143,7 +144,8 @@ def _compute_dynamic_ntk_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -185,7 +187,8 @@ def _compute_yarn_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -265,7 +268,8 @@ def _compute_longrope_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) long_factor = config.rope_scaling["long_factor"] short_factor = config.rope_scaling["short_factor"] factor = config.rope_scaling.get("factor") @@ -450,7 +454,8 @@ def _validate_longrope_parameters(config: PretrainedConfig): _check_received_keys(rope_type, received_keys, required_keys, optional_keys) partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) short_factor = rope_scaling.get("short_factor") if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): From 17e6b1d8b010fbba15fdcb57a7c8a72d3e17951f Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Wed, 7 Aug 2024 06:56:53 -0700 Subject: [PATCH 2/3] typo --- src/transformers/modeling_rope_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index f047a5f4d893..67141f1764c0 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -58,7 +58,7 @@ def _compute_default_rope_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -144,7 +144,7 @@ def _compute_dynamic_ntk_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -187,7 +187,7 @@ def _compute_yarn_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -268,7 +268,7 @@ def _compute_longrope_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads dim = int(head_dim * partial_rotary_factor) long_factor = config.rope_scaling["long_factor"] short_factor = config.rope_scaling["short_factor"] @@ -454,7 +454,7 @@ def _validate_longrope_parameters(config: PretrainedConfig): _check_received_keys(rope_type, received_keys, required_keys, optional_keys) partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads dim = int(head_dim * partial_rotary_factor) short_factor = rope_scaling.get("short_factor") From 1d688fbe9e4071ce16458619ed26514bbc53be90 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Thu, 15 Aug 2024 08:17:52 -0700 Subject: [PATCH 3/3] simplify with getattr --- src/transformers/modeling_rope_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 67141f1764c0..c09664d688c3 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -58,7 +58,7 @@ def _compute_default_rope_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -144,7 +144,7 @@ def _compute_dynamic_ntk_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -187,7 +187,7 @@ def _compute_yarn_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -268,7 +268,7 @@ def _compute_longrope_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) long_factor = config.rope_scaling["long_factor"] short_factor = config.rope_scaling["short_factor"] @@ -454,7 +454,7 @@ def _validate_longrope_parameters(config: PretrainedConfig): _check_received_keys(rope_type, received_keys, required_keys, optional_keys) partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) short_factor = rope_scaling.get("short_factor")