From 91be6a5eb21f42ed22c414fb9ebdd2e9f344f642 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 16 Jan 2025 16:37:53 +0000 Subject: [PATCH] Modular: support for importing functions from any file (#35692) * fix function imports * improve comment * Update modeling_switch_function.py * make checks more robust * improvement * rename * final test update --- .../modeling_add_function.py | 66 +++++++ .../modular-transformers/modeling_dummy.py | 17 +- .../modeling_multimodal1.py | 17 +- .../modeling_my_new_model2.py | 17 +- .../modeling_new_task_model.py | 1 - .../modular-transformers/modeling_super.py | 17 +- .../modeling_switch_function.py | 170 ++++++++++++++++++ .../modular_add_function.py | 15 ++ .../modular_switch_function.py | 10 ++ utils/modular_model_converter.py | 18 +- 10 files changed, 305 insertions(+), 43 deletions(-) create mode 100644 examples/modular-transformers/modeling_add_function.py create mode 100644 examples/modular-transformers/modeling_switch_function.py create mode 100644 examples/modular-transformers/modular_add_function.py create mode 100644 examples/modular-transformers/modular_switch_function.py diff --git a/examples/modular-transformers/modeling_add_function.py b/examples/modular-transformers/modeling_add_function.py new file mode 100644 index 000000000000..acf140f025d9 --- /dev/null +++ b/examples/modular-transformers/modeling_add_function.py @@ -0,0 +1,66 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_add_function.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_add_function.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Note that zamba does not have the `apply_rotary_pos_emb` function! +from typing import Optional, Tuple + +import torch +from torch import nn + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class TestAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + + Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: + The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. + The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + Additionally, replaced + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + """ + + def __init__(self): + pass + + def forward(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + _ = apply_rotary_pos_emb(1, 1, 1, 1) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 382b87bd3847..0c61848924a4 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -45,13 +45,8 @@ def extra_repr(self): class DummyRotaryEmbedding(nn.Module): - def __init__( - self, - config: DummyConfig, - device=None, - ): + def __init__(self, config: DummyConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -63,7 +58,7 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -75,13 +70,14 @@ def _dynamic_frequency_update(self, position_ids, device): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -356,6 +352,7 @@ class DummyPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True 
_supports_static_cache = True diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index df23a83b3411..45b10a5b206a 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -45,13 +45,8 @@ def extra_repr(self): class Multimodal1TextRotaryEmbedding(nn.Module): - def __init__( - self, - config: Multimodal1TextConfig, - device=None, - ): + def __init__(self, config: Multimodal1TextConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -63,7 +58,7 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -75,13 +70,14 @@ def _dynamic_frequency_update(self, position_ids, device): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -356,6 +352,7 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 9288b1a29305..ae71d724c25a 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -61,13 +61,8 @@ def forward(self, x): class MyNewModel2RotaryEmbedding(nn.Module): - def __init__( - self, - config: MyNewModel2Config, - device=None, - ): + def __init__(self, config: MyNewModel2Config, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -79,7 +74,7 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) 
self.original_inv_freq = self.inv_freq @@ -91,13 +86,14 @@ def _dynamic_frequency_update(self, position_ids, device): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -356,6 +352,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index da0b354fe76e..f07ac7f3348b 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -107,7 +107,6 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True - _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 1f5aa55c4690..e44c4bde1987 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -45,13 +45,8 @@ def extra_repr(self): class SuperRotaryEmbedding(nn.Module): - def __init__( - self, - config: SuperConfig, - device=None, - ): + def __init__(self, config: SuperConfig, device=None): super().__init__() - self.rope_kwargs = {} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -63,7 +58,7 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -75,13 +70,14 @@ def _dynamic_frequency_update(self, position_ids, device): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # 
This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -356,6 +352,7 @@ class SuperPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/examples/modular-transformers/modeling_switch_function.py b/examples/modular-transformers/modeling_switch_function.py new file mode 100644 index 000000000000..3b89284537ae --- /dev/null +++ b/examples/modular-transformers/modeling_switch_function.py @@ -0,0 +1,170 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_switch_function.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_switch_function.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Note that llama and cohere have different definitions for rotate_half +from typing import Callable, Optional, Tuple + +import torch +from torch import nn + +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from .configuration_switch_function import SwitchFunctionConfig + + +logger = logging.get_logger(__name__) + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SwitchFunctionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: SwitchFunctionConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, 
key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
diff --git a/examples/modular-transformers/modular_add_function.py b/examples/modular-transformers/modular_add_function.py
new file mode 100644
index 000000000000..6a2426a67236
--- /dev/null
+++ b/examples/modular-transformers/modular_add_function.py
@@ -0,0 +1,15 @@
+# Note that zamba does not have the `apply_rotary_pos_emb` function!
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+from transformers.models.zamba.modeling_zamba import ZambaAttention
+
+
+# When following ZambaAttention dependencies, the function `apply_rotary_pos_emb` is not present
+# by default, as it is absent from the class definition (and from the file altogether).
+# Note that this syntax should add not only `apply_rotary_pos_emb` (imported directly), but also
+# `rotate_half`, as a dependency of the imported function!
+class TestAttention(ZambaAttention):
+    def __init__(self):
+        pass
+
+    def forward(self):
+        _ = apply_rotary_pos_emb(1, 1, 1, 1)
diff --git a/examples/modular-transformers/modular_switch_function.py b/examples/modular-transformers/modular_switch_function.py
new file mode 100644
index 000000000000..3c0c716a4397
--- /dev/null
+++ b/examples/modular-transformers/modular_switch_function.py
@@ -0,0 +1,10 @@
+# Note that llama and cohere have different definitions for rotate_half
+from transformers.models.cohere.modeling_cohere import rotate_half  # noqa
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+
+# When following LlamaAttention dependencies, we will grab the function `rotate_half` defined
+# in `modeling_llama.py`. But here we imported it explicitly from Cohere, so it should use Cohere's
+# definition instead
+class SwitchFunctionAttention(LlamaAttention):
+    pass
diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py
index 7cba82f6df13..8126d130ae72 100644
--- a/utils/modular_model_converter.py
+++ b/utils/modular_model_converter.py
@@ -776,7 +776,7 @@ def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, in
             else:
                 merged_dependencies.append(class_dep)
         # Sort both list according to the order in their respective file
-        original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
+        original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
         merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

         # Add all original node first, then merged ones
@@ -801,7 +801,7 @@ def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, in
             else:
                 original_dependencies.append(dep)
         # Sort both list according to the order in their respective file
-        original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
+        original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
         merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

         # Add all original node first, then merged ones
@@ -1321,6 +1321,20 @@ def merge_model_specific_imports(self, visited_modules):
                             self.added_objects_file_mapping[dep] = file
                             self.functions[dep] = visited_module.global_nodes[dep]

+                # Add/overwrite the imported functions to other visited modules as well, in case they are absent/different
+                # in the modeling source file of the inherited class. See `examples/modular-transformers/modular_switch_function.py`
+                # and `examples/modular-transformers/modular_add_function.py` for examples
+                recursive_dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, set())
+                node_recursive_dependencies_mapping = {
+                    dep: visited_module.global_nodes[dep] for dep in recursive_dependencies
+                }
+                for filename, module_mapper in self.visited_modules.items():
+                    if filename != file:
+                        module_mapper.global_nodes[object_name] = visited_module.functions[object_name]
+                        if len(recursive_dependencies) > 0:
+                            module_mapper.object_recursive_dependency_mapping[object_name] = recursive_dependencies
+                            module_mapper.global_nodes.update(node_recursive_dependencies_mapping)
+
             # Add assignments and their dependencies
             elif object_name in visited_module.assignments and object_name not in self.assignments:
                 self.assignments[object_name] = visited_module.assignments[object_name]
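
The two modular examples above revolve around `rotate_half` having different definitions in llama and cohere: llama rotates two contiguous halves of the head dimension, while cohere rotates interleaved even/odd channels. A self-contained sketch of that difference follows; it is not part of the patch, the two function bodies mirror the generated modeling_add_function.py and modeling_switch_function.py above, and the toy input is an arbitrary illustration.

import torch


def rotate_half_llama(x):
    # Contiguous-half rotation, as in modeling_add_function.py above (llama-style).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_cohere(x):
    # Interleaved rotation, as in modeling_switch_function.py above (cohere-style).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).flatten(-2)


x = torch.arange(4.0)         # tensor([0., 1., 2., 3.])
print(rotate_half_llama(x))   # tensor([-2., -3.,  0.,  1.])
print(rotate_half_cohere(x))  # tensor([-1.,  0., -3.,  2.])

This is why modular_switch_function.py imports cohere's `rotate_half` explicitly even though it inherits from `LlamaAttention`: the generated `apply_rotary_pos_emb` must call the cohere definition, which is exactly what the converter change in this patch enables.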
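
The `unsqueeze_dim` paragraph in the `apply_rotary_pos_emb` docstring boils down to a broadcasting rule. A small shape check, assuming the [batch, heads, seq_len, head_dim] query/key layout and [batch, seq_len, head_dim] cos/sin layout described there (the concrete sizes are arbitrary, and only the function body is taken from the generated files above):

import torch


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Same body as in the generated files above.
    cos = cos.unsqueeze(unsqueeze_dim)  # [batch, 1, seq_len, head_dim]: broadcasts over the heads dim
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)
sin = torch.randn(batch, seq_len, head_dim)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1
assert q_embed.shape == q.shape and k_embed.shape == k.shape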
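
The `repeat_kv` docstring in modeling_switch_function.py claims equivalence with `torch.repeat_interleave(x, dim=1, repeats=n_rep)`. A quick self-contained check of that claim (the shapes are arbitrary; the function body is copied from the generated file above):

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same body as in modeling_switch_function.py above.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


x = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))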
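
The new `# This .to() is needed ...` comment added to each `_dynamic_frequency_update` reflects a general PyTorch behavior: `Module.to(device)` moves registered buffers but not plain tensor attributes, so `original_inv_freq` (a plain attribute) can be left behind on the original device. A minimal sketch of that behavior, using the `meta` device as a stand-in for an accelerator so it runs anywhere; the toy module below is an assumption for illustration, not transformers code:

import torch
from torch import nn


class ToyRotaryEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.arange(4.0)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Plain attribute: aliases the buffer at init time, but Module.to() will not touch it.
        self.original_inv_freq = self.inv_freq


m = ToyRotaryEmbedding()
m.to("meta")                       # stands in for m.to("cuda")
print(m.inv_freq.device)           # meta -- the registered buffer was moved
print(m.original_inv_freq.device)  # cpu  -- the plain attribute stayed behind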