From 2d9887570a37d1be78f56215df24101bfe395260 Mon Sep 17 00:00:00 2001 From: Amit Garg Date: Fri, 24 Jan 2025 02:10:15 +0000 Subject: [PATCH 1/3] Added support for partial_rotary_factor --- src/transformers/models/phi3/configuration_phi3.py | 13 +++++++++---- src/transformers/models/phi3/modeling_phi3.py | 9 +++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index 361c43c99eca..96774fe8c969 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -81,6 +81,8 @@ class Phi3Config(PretrainedConfig): contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size divided by the number of attention heads divided by 2. + partial_rotary_factor (`float`, *optional*, defaults to 1.0): + Percentage of the query and keys which will have rotary embedding. bos_token_id (`int`, *optional*, defaults to 1): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 32000): @@ -128,6 +130,7 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, + partial_rotary_factor=1, bos_token_id=1, eos_token_id=32000, pad_token_id=32000, @@ -155,6 +158,7 @@ def __init__( self.use_cache = use_cache self.rope_theta = rope_theta self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor self._rope_scaling_adjustment() self._rope_scaling_validation() self.sliding_window = sliding_window @@ -204,9 +208,10 @@ def _rope_scaling_validation(self): raise ValueError( f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" ) - if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor) + if not len(rope_scaling_short_factor) == rotary_ndims // 2: raise ValueError( - f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}" ) if not ( isinstance(rope_scaling_long_factor, list) @@ -215,9 +220,9 @@ def _rope_scaling_validation(self): raise ValueError( f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" ) - if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + if not len(rope_scaling_long_factor) == rotary_ndims // 2: raise ValueError( - f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}" ) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e86e028b4027..151c0163d755 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -104,8 +104,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * 
cos) + (rotate_half(k) * sin) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) return q_embed, k_embed From 66e25a8d19ddb7a09635c2f99155b336d2d04ffd Mon Sep 17 00:00:00 2001 From: Amit Garg Date: Fri, 24 Jan 2025 17:23:38 +0000 Subject: [PATCH 2/3] addressed comments --- src/transformers/models/phi3/configuration_phi3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py index 96774fe8c969..65f6ff21bfa4 100644 --- a/src/transformers/models/phi3/configuration_phi3.py +++ b/src/transformers/models/phi3/configuration_phi3.py @@ -82,7 +82,7 @@ class Phi3Config(PretrainedConfig): the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size divided by the number of attention heads divided by 2. partial_rotary_factor (`float`, *optional*, defaults to 1.0): - Percentage of the query and keys which will have rotary embedding. + Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0. bos_token_id (`int`, *optional*, defaults to 1): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 32000): @@ -130,7 +130,7 @@ def __init__( tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, - partial_rotary_factor=1, + partial_rotary_factor=1.0, bos_token_id=1, eos_token_id=32000, pad_token_id=32000, From a8790ab2fbef99a24c67040a0a3ee1a6d15801d7 Mon Sep 17 00:00:00 2001 From: Amit Garg Date: Tue, 28 Jan 2025 22:35:17 +0000 Subject: [PATCH 3/3] refactored --- src/transformers/models/phi3/modeling_phi3.py | 64 +++++++++---------- src/transformers/models/phi3/modular_phi3.py | 34 +++++++++- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 151c0163d755..3ee5d8c5d0d8 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -82,38 +82,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
- Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - rotary_dim = cos.shape[-1] - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) - k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) - return q_embed, k_embed - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -152,6 +120,38 @@ def eager_attention_forward( return attn_output, attn_weights +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + class Phi3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index 27f7c42f5bb8..caddb77f6fcd 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -34,8 +34,8 @@ MistralForTokenClassification, MistralPreTrainedModel, MistralRotaryEmbedding, - apply_rotary_pos_emb, eager_attention_forward, + rotate_half, ) from .configuration_phi3 import Phi3Config @@ -64,6 +64,38 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: return self.down_proj(up_states) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. 
+ position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + class Phi3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper"""
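
The diff above only shows the patched apply_rotary_pos_emb as fragments, so here is a minimal standalone sketch of the partial-rotary behaviour this series introduces. The shapes, the toy cos/sin table, and partial_rotary_factor=0.5 are illustrative assumptions, not values taken from the PR; only the split/rotate/concatenate logic mirrors the patched function.

import torch

def rotate_half(x):
    # Same helper as in modeling_phi3.py: swap the two halves of the last dim with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin only cover the rotary slice of the head dimension; channels past
    # cos.shape[-1] are passed through untouched, as in the patched function.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
    k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
    return q_embed, k_embed

# Illustrative shapes: head_dim=8 with partial_rotary_factor=0.5 rotates only 4 channels.
batch, heads, seq_len, head_dim = 1, 2, 4, 8
partial_rotary_factor = 0.5
rotary_dim = int(head_dim * partial_rotary_factor)

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Toy cos/sin tables of width rotary_dim, shaped [batch, seq_len, rotary_dim].
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
emb = torch.cat((angles, angles), dim=-1)
cos, sin = emb.cos()[None], emb.sin()[None]

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)

# The non-rotary channels come back unchanged, and a longrope short_factor /
# long_factor list for this config would need length rotary_dim // 2, matching
# the updated _rope_scaling_validation.
assert torch.equal(q_embed[..., rotary_dim:], q[..., rotary_dim:])
assert torch.equal(k_embed[..., rotary_dim:], k[..., rotary_dim:])

With partial_rotary_factor=1.0 (the default), rotary_dim equals head_dim, the pass-through slices are empty, and the function reduces to the previous full-rotary behaviour, which is why existing Phi-3 checkpoints are unaffected.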