Add support for partial rotary embeddings in Phi3 model #35947

Merged
13 changes: 9 additions & 4 deletions src/transformers/models/phi3/configuration_phi3.py
@@ -81,6 +81,8 @@ class Phi3Config(PretrainedConfig):
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
divided by the number of attention heads divided by 2.
partial_rotary_factor (`float`, *optional*, defaults to 1.0):
Fraction of the query and key head dimensions that receive rotary embedding. Must be between 0.0 and 1.0.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 32000):
@@ -134,6 +136,7 @@ def __init__(
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=1.0,
bos_token_id=1,
eos_token_id=32000,
pad_token_id=32000,
@@ -161,6 +164,7 @@ def __init__(
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self._rope_scaling_adjustment()
self._rope_scaling_validation()
self.sliding_window = sliding_window
@@ -210,9 +214,10 @@ def _rope_scaling_validation(self):
raise ValueError(
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
)
if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
if not len(rope_scaling_short_factor) == rotary_ndims // 2:
raise ValueError(
f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
)
if not (
isinstance(rope_scaling_long_factor, list)
@@ -221,9 +226,9 @@ def _rope_scaling_validation(self):
raise ValueError(
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
)
if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
if not len(rope_scaling_long_factor) == rotary_ndims // 2:
raise ValueError(
f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
)


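For context, the validation above now ties the expected length of the `longrope` factor lists to `partial_rotary_factor`. A minimal sketch of the arithmetic, using made-up dimensions rather than any released Phi-3 checkpoint:

```python
# Hypothetical sizes, chosen only to illustrate the check in _rope_scaling_validation.
hidden_size = 3072
num_attention_heads = 24
partial_rotary_factor = 0.75

head_dim = hidden_size // num_attention_heads         # 128
rotary_ndims = int(head_dim * partial_rotary_factor)  # 96 dims receive RoPE

# `short_factor` and `long_factor` must each contain rotary_ndims // 2 numbers.
expected_len = rotary_ndims // 2                       # 48
short_factor = [1.0] * expected_len
long_factor = [1.0] * expected_len
assert len(short_factor) == expected_len and len(long_factor) == expected_len
```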
59 changes: 32 additions & 27 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -82,33 +82,6 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -147,6 +120,38 @@ def eager_attention_forward(
return attn_output, attn_weights


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)

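# Partial rotary: only the first `rotary_dim` features of q/k are rotated; the rest pass through unchanged.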
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
return q_embed, k_embed


class Phi3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

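To make the new behaviour concrete, here is a self-contained sketch (arbitrary shapes, helper names of my own choosing) that mirrors the partial-rotary `apply_rotary_pos_emb` above and checks that dimensions beyond `rotary_dim` pass through untouched:

```python
import torch


def rotate_half(x):
    # Split the last dimension in half and rotate, as in modeling_phi3.py.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rope(q, k, cos, sin, unsqueeze_dim=1):
    # Mirrors the new apply_rotary_pos_emb: rotate only the first cos.shape[-1]
    # features, then concatenate the untouched remainder back on.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
    k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
    return q_embed, k_embed


batch, heads, seq, head_dim, rotary_dim = 2, 4, 6, 16, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
angles = torch.randn(batch, seq, rotary_dim)  # stand-in for the rotary embedding output
cos, sin = angles.cos(), angles.sin()

q_embed, k_embed = apply_partial_rope(q, k, cos, sin)
# Only the first rotary_dim features were rotated; the rest are bit-identical.
assert torch.equal(q_embed[..., rotary_dim:], q[..., rotary_dim:])
assert torch.equal(k_embed[..., rotary_dim:], k[..., rotary_dim:])
```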
34 changes: 33 additions & 1 deletion src/transformers/models/phi3/modular_phi3.py
@@ -34,8 +34,8 @@
MistralForTokenClassification,
MistralPreTrainedModel,
MistralRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
rotate_half,
)
from .configuration_phi3 import Phi3Config

@@ -64,6 +64,38 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
return self.down_proj(up_states)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)

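# Partial rotary: only the first `rotary_dim` features of q/k are rotated; the rest pass through unchanged.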
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
return q_embed, k_embed


class Phi3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

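As a quick illustration of the `unsqueeze_dim` note in the docstring, a shape check with arbitrary sizes (default `unsqueeze_dim=1`):

```python
import torch

batch, heads, seq, head_dim, rotary_dim = 2, 8, 10, 32, 16
q = torch.randn(batch, heads, seq, head_dim)  # [batch, heads, seq, head_dim]
cos = torch.randn(batch, seq, rotary_dim)     # [batch, seq, rotary_dim], one row per position

# [batch, seq, rotary_dim] does not broadcast against q[..., :rotary_dim] of shape
# [batch, heads, seq, rotary_dim]; unsqueezing dim 1 inserts the missing heads axis.
cos = cos.unsqueeze(1)                        # -> [batch, 1, seq, rotary_dim]
print((q[..., :rotary_dim] * cos).shape)      # torch.Size([2, 8, 10, 16])
```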