modular_model_converter bugfix on assignments (huggingface#35642)
* added bugfix in modular converter to keep modular assignments for docstrings, expected outputs etc.

* revert starcoder2 docstring copying, add forward in EMU3 to enable docstring assignment, remove verbatim assignments in modular converter

* added _FOR_DOC in assignments to keep, corrected wrong checkpoint name in ijepa's configuration
nikosanto13 authored Jan 21, 2025
1 parent 234168c commit 920f34a
Showing 12 changed files with 113 additions and 98 deletions.
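
The core of the fix, summarized from the commit message and the `utils/modular_model_converter.py` diff further down: previously only assignments listed verbatim in `ASSIGNMENTS_TO_KEEP` (just `_CHECKPOINT_FOR_DOC`) survived conversion; now any top-level assignment whose name matches one of a few regex patterns keeps the value written in the `modular_xxx.py` file. A minimal sketch of the new matching rule (the helper function and the sample names are illustrative, not code from the converter):

```python
import re

# Patterns added in this commit; a top-level assignment whose name matches any of
# them keeps the value written in the modular file rather than the inherited one.
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]

def keeps_modular_value(assignment_name: str) -> bool:
    """True if the modular file's value should override the auto-generated one."""
    return any(re.search(pattern, assignment_name) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)

for name in ("_CHECKPOINT_FOR_DOC", "_EXPECTED_OUTPUT_SHAPE", "_CONFIG_FOR_DOC", "SOME_OTHER_CONSTANT"):
    print(f"{name}: {keeps_modular_value(name)}")
# Prints True for the first three names and False for SOME_OTHER_CONSTANT.
```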
1 change: 1 addition & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -61,6 +61,7 @@


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BambaConfig"


1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -52,6 +52,7 @@


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "CohereConfig"


1 change: 1 addition & 0 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -43,6 +43,7 @@


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Cohere2Config"


166 changes: 85 additions & 81 deletions src/transformers/models/emu3/modeling_emu3.py
@@ -1257,7 +1257,7 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


EMU3_INPUTS_DOCSTRING = r"""
EMU3_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -1292,19 +1292,15 @@ def forward(self, x, position_ids):
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
past_key_values (`Cache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
Has to be an instance of [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
@@ -1366,7 +1362,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value

@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -1598,77 +1594,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


EMU3_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Has to be an instance of [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
@@ -1790,6 +1715,85 @@ def forward(
)


EMU3_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)`):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
[`Emu3ImageProcessor`] for processing images).
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
[`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
[`Emu3ImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Has to be an instance of [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]

4 changes: 4 additions & 0 deletions src/transformers/models/emu3/modular_emu3.py
@@ -1059,6 +1059,10 @@ def __init__(self, config: Emu3Config):
[Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)

@add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
def forward(self, **super_kwargs):
super().forward(**super_kwargs)


class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
config_class = Emu3TextConfig
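
Per the commit message, the `forward` override added in the modular file above exists only to enable the docstring assignment: declaring a thin `forward` carrying the `@add_start_docstrings_to_model_forward(...)` decorator lets the generated text model's `forward` pick up `EMU3_TEXT_INPUTS_DOCSTRING`, while the body simply defers to the parent. A hedged sketch of that pattern with made-up names (not the real Emu3 definitions, which build on the Llama classes as shown in this diff):

```python
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.utils import add_start_docstrings_to_model_forward

MYMODEL_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
"""


class MyModelTextModel(LlamaModel):
    # The decorator attaches the docstring to the generated class's forward;
    # the body just delegates to the parent implementation.
    @add_start_docstrings_to_model_forward(MYMODEL_TEXT_INPUTS_DOCSTRING)
    def forward(self, **super_kwargs):
        return super().forward(**super_kwargs)
```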
2 changes: 1 addition & 1 deletion src/transformers/models/ijepa/configuration_ijepa.py
@@ -22,7 +22,7 @@ class IJepaConfig(PretrainedConfig):
This is the configuration class to store the configuration of a [`IJepaModel`]. It is used to instantiate an IJEPA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the I-JEPA
[google/ijepa-base-patch16-224](https://huggingface.co/google/ijepa-base-patch16-224) architecture.
[facebook/ijepa_vith14_1k](https://huggingface.co/facebook/ijepa_vith14_1k) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
7 changes: 4 additions & 3 deletions src/transformers/models/ijepa/modeling_ijepa.py
@@ -527,7 +527,9 @@ def forward(self, hidden_states):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]


_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]


IJEPA_START_DOCSTRING = r"""
@@ -640,8 +642,7 @@ def forward(
)


# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224"
_IMAGE_CLASS_CHECKPOINT = "facebook/ijepa_vith14_1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


1 change: 0 additions & 1 deletion src/transformers/models/moonshine/modeling_moonshine.py
@@ -49,7 +49,6 @@


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MoonshineConfig"


2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -36,7 +36,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "meta-phi/Phi-2-7b-hf"
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig"


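
The corrected checkpoint matters because `_CHECKPOINT_FOR_DOC` feeds the code-sample docstrings in the Phi model; the old value looks like a leftover from the Llama template and does not correspond to a real Hub repository. A quick sanity check one could run (an assumption of this note, not part of the diff; requires network access and `transformers` installed):

```python
from transformers import AutoConfig

# "microsoft/phi-1" is the checkpoint the Phi docstrings now reference;
# fetching its config is a cheap way to confirm the repository name resolves.
config = AutoConfig.from_pretrained("microsoft/phi-1")
print(type(config).__name__, config.model_type)  # expected: PhiConfig phi
```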
3 changes: 3 additions & 0 deletions src/transformers/models/phi/modular_phi.py
@@ -26,6 +26,9 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig"


class PhiAttention(LlamaAttention):
def __init__(self, config: PhiConfig, layer_idx: int):
1 change: 1 addition & 0 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -56,6 +56,7 @@

logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"

_CONFIG_FOR_DOC = "Starcoder2Config"


22 changes: 11 additions & 11 deletions utils/modular_model_converter.py
@@ -268,7 +268,7 @@ def merge_docstrings(original_docstring, updated_docstring):
updated_docstring = "".join(
[
parts[0].rstrip(" \n") + new_parts[0],
f"\n{original_level*' '}```",
f"\n{original_level * ' '}```",
parts[1],
"```",
parts[2],
@@ -515,10 +515,8 @@ def forward(...):
return all_dependencies_with_parent


# These top-level variables will always use the value in the `modular_xxx.py` file
ASSIGNMENTS_TO_KEEP = {
"_CHECKPOINT_FOR_DOC",
}
# Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]


class ClassDependencyMapper(CSTVisitor):
@@ -828,12 +826,14 @@ def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: di
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
"""Update the global nodes with the assignment from the modular file.
Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is
in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the
Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it matches
a pattern in `ASSIGNMENTS_REGEX_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the
big docstrings.
"""
for assignment, node in assignments.items():
if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments:
should_keep = any(re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)

if should_keep or assignment not in self.assignments:
self.assignments[assignment] = node
if assignment in object_mapping:
self.object_dependency_mapping[assignment] = object_mapping[assignment]
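
A condensed sketch of the decision the hunk above implements (plain dictionaries stand in for the CST nodes, and the helper name is illustrative; the real method also carries over the dependency mapping):

```python
import re

ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]

def merge_assignments(original: dict, modular: dict) -> dict:
    """Per top-level name, prefer the modular value only for doc-style constants or names
    the original file does not define; otherwise keep the original (e.g. big docstrings)."""
    merged = dict(original)
    for name, value in modular.items():
        should_keep = any(re.search(pattern, name) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)
        if should_keep or name not in original:
            merged[name] = value
    return merged

original = {"MYMODEL_INPUTS_DOCSTRING": "<long inherited docstring>", "_EXPECTED_OUTPUT_SHAPE": [1, 197, 768]}
modular = {"MYMODEL_INPUTS_DOCSTRING": "<stub>", "_EXPECTED_OUTPUT_SHAPE": [1, 256, 1280]}
print(merge_assignments(original, modular))
# {'MYMODEL_INPUTS_DOCSTRING': '<long inherited docstring>', '_EXPECTED_OUTPUT_SHAPE': [1, 256, 1280]}
```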
@@ -1404,7 +1404,7 @@ class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
]
if len(modeling_bases) > 1:
raise ValueError(
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*modeling_bases,}."
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {(*modeling_bases,)}."
)
if len(modeling_bases) == 1:
filename = self.model_specific_imported_objects[modeling_bases[0]]
@@ -1432,7 +1432,7 @@ class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
if final_name != cased_default_name and has_prefix_collision:
if len(prefixes_counter) > 1:
logger.warning(
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. However, the "
f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. However, the "
f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
"and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
@@ -1448,7 +1448,7 @@
final_name = cased_default_name
elif len(prefixes_counter) > 1:
logger.warning(
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. We will only "
f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. We will only "
f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
"in all the modular (best)."