Modular: support for importing functions from any file (#35692)

* fix function imports
* improve comment
* Update modeling_switch_function.py
* make checks more robust
* improvement
* rename
* final test update
1 parent 8ebe9d7 · commit 91be6a5
Showing 10 changed files with 305 additions and 43 deletions.
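This change lets a modular_*.py file import helper functions from any other modeling file; the modular converter then copies the imported function, together with its local dependencies (here rotate_half for apply_rotary_pos_emb), into the generated modeling file. As a rough, hypothetical sketch only (the actual example lives in examples/modular-transformers/modular_add_function.py; the import paths and base class below are assumptions reconstructed from the generated output shown further down), the modular source might look like this:

```python
# Hypothetical sketch of a modular file, NOT the exact content of the commit.
# `apply_rotary_pos_emb` is assumed to come from the llama modeling file, since
# the generated output notes that zamba does not define that function.
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.zamba.modeling_zamba import ZambaAttention


class TestAttention(ZambaAttention):
    def __init__(self):
        pass

    def forward(self):
        # Calling the imported function forces the converter to copy its
        # definition (and `rotate_half`) into the generated modeling file.
        _ = apply_rotary_pos_emb(1, 1, 1, 1)
```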
New file, 66 lines added (@@ -0,0 +1,66 @@):

```python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from examples/modular-transformers/modular_add_function.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_add_function.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Note that zamba does not have the `apply_rotary_pos_emb` function!
from typing import Optional, Tuple

import torch
from torch import nn


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class TestAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
    The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
    The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
    (see fig. 2 in https://arxiv.org/pdf/2405.16712).
    Additionally, replaced
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
    """

    def __init__(self):
        pass

    def forward(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        _ = apply_rotary_pos_emb(1, 1, 1, 1)
```
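The docstring above describes how unsqueeze_dim aligns cos and sin with q and k; a minimal usage sketch of apply_rotary_pos_emb (shapes and random values chosen here purely for illustration; the function is imported from the llama modeling file, where it is defined identically to the copy above):

```python
import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16

# q/k in [batch_size, heads, seq_len, head_dim] layout
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)

# cos/sin as a rotary embedding module would provide them: [batch_size, seq_len, head_dim]
cos = torch.randn(batch_size, seq_len, head_dim)
sin = torch.randn(batch_size, seq_len, head_dim)

# unsqueeze_dim=1 inserts the missing heads axis so cos/sin broadcast against q and k
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)
assert q_embed.shape == (batch_size, num_heads, seq_len, head_dim)
assert k_embed.shape == (batch_size, num_heads, seq_len, head_dim)
```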
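The TestAttention docstring also notes that the usual 1/sqrt(head_dim) scaling is replaced by 1/sqrt(head_dim/2), because the attention input is the concatenation of two hidden-state streams. A small sketch of the two scalings side by side (shapes invented for illustration):

```python
import math

import torch

num_heads, seq_len, head_dim = 4, 8, 64
query_states = torch.randn(1, num_heads, seq_len, head_dim)
key_states = torch.randn(1, num_heads, seq_len, head_dim)

scores = torch.matmul(query_states, key_states.transpose(2, 3))

# Standard scaling, as in Mistral-style attention
attn_weights_standard = scores / math.sqrt(head_dim)

# Zamba-style scaling from the docstring: head_dim is effectively doubled by the
# concatenated input, so the divisor becomes sqrt(head_dim / 2)
attn_weights_adjusted = scores / math.sqrt(head_dim / 2)
```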
(Diffs for the remaining changed files in this commit are not reproduced here.)