Commit

Adding [T5/MT5/UMT5]ForTokenClassification (#28443)

* Adding [T5/MT5/UMT5]ForTokenClassification

* Add auto mappings for T5ForTokenClassification and variants

* Adding ForTokenClassification to the list of models

* Adding attention_mask param to the T5ForTokenClassification test

* Remove outdated comment in test

* Adding EncoderOnly and Token Classification tests for MT5 and UMT5

* Fix typo in umt5 string

* Add tests for all the existing MT5 models

* Fix wrong comment in dependency_versions_table

* Reverting change to common test for _keys_to_ignore_on_load_missing

The test is correctly picking up redundant keys in _keys_to_ignore_on_load_missing.

* Removing _keys_to_ignore_on_load_missing from MT5 since the key is not used in the model

* Add fix-copies to MT5ModelTest
hackyon authored Feb 1, 2024
1 parent 7b2bd1f commit 0d26abd
Showing 18 changed files with 1,579 additions and 54 deletions.
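For orientation, a minimal sketch of how the new token-classification head is meant to be used after this change (the checkpoint name and label count below are illustrative, not part of the commit):

```python
import torch
from transformers import AutoTokenizer, T5ForTokenClassification

# "t5-small" is only an example checkpoint; the classification head added by
# this commit is randomly initialized, so real use requires fine-tuning.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForTokenClassification.from_pretrained("t5-small", num_labels=9)

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One logit vector per input token: (batch_size, sequence_length, num_labels)
print(outputs.logits.shape)
```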
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/mt5.md
@@ -101,6 +101,10 @@ See [`T5TokenizerFast`] for all details.

[[autodoc]] MT5ForSequenceClassification

## MT5ForTokenClassification

[[autodoc]] MT5ForTokenClassification

## MT5ForQuestionAnswering

[[autodoc]] MT5ForQuestionAnswering
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/t5.md
@@ -402,6 +402,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] T5ForSequenceClassification
- forward

## T5ForTokenClassification

[[autodoc]] T5ForTokenClassification
- forward

## T5ForQuestionAnswering

[[autodoc]] T5ForQuestionAnswering
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/umt5.md
@@ -100,6 +100,11 @@ Refer to [T5's documentation page](t5) for more tips, code examples and notebook
[[autodoc]] UMT5ForSequenceClassification
- forward

## UMT5ForTokenClassification

[[autodoc]] UMT5ForTokenClassification
- forward

## UMT5ForQuestionAnswering

[[autodoc]] UMT5ForQuestionAnswering
2 changes: 1 addition & 1 deletion docs/source/en/tasks/token_classification.md
@@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [BROS](../model_doc/bros), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Phi](../model_doc/phi), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [BROS](../model_doc/bros), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Phi](../model_doc/phi), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)

<!--End of the generated tip-->

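With the T5 family added to this list, the generic token-classification pipeline covers these architectures as well; a sketch, assuming a hypothetical fine-tuned checkpoint `my-org/mt5-small-ner` (not a real model):

```python
from transformers import pipeline

# "my-org/mt5-small-ner" is a placeholder; the stock google/mt5-small weights
# ship without a trained classification head, so a fine-tuned checkpoint is assumed.
ner = pipeline(
    "token-classification",
    model="my-org/mt5-small-ner",
    aggregation_strategy="simple",
)
print(ner("Hugging Face is based in New York City"))
```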
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -2731,6 +2731,7 @@
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5ForTokenClassification",
"MT5Model",
"MT5PreTrainedModel",
]
@@ -3299,6 +3300,7 @@
"T5ForConditionalGeneration",
"T5ForQuestionAnswering",
"T5ForSequenceClassification",
"T5ForTokenClassification",
"T5Model",
"T5PreTrainedModel",
"load_tf_weights_in_t5",
@@ -3370,6 +3372,7 @@
"UMT5ForConditionalGeneration",
"UMT5ForQuestionAnswering",
"UMT5ForSequenceClassification",
"UMT5ForTokenClassification",
"UMT5Model",
"UMT5PreTrainedModel",
]
@@ -7223,6 +7226,7 @@
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5ForTokenClassification,
MT5Model,
MT5PreTrainedModel,
)
@@ -7688,6 +7692,7 @@
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5ForTokenClassification,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,
@@ -7743,6 +7748,7 @@
UMT5ForConditionalGeneration,
UMT5ForQuestionAnswering,
UMT5ForSequenceClassification,
UMT5ForTokenClassification,
UMT5Model,
UMT5PreTrainedModel,
)
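These top-level exports make the new heads importable directly from `transformers`; a quick check:

```python
# The new heads sit alongside the existing T5-family exports at the top level.
from transformers import (
    MT5ForTokenClassification,
    T5ForTokenClassification,
    UMT5ForTokenClassification,
)

for cls in (T5ForTokenClassification, MT5ForTokenClassification, UMT5ForTokenClassification):
    print(cls.__module__, cls.__name__)
```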
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -950,6 +950,7 @@
("mpnet", "MPNetForTokenClassification"),
("mpt", "MptForTokenClassification"),
("mra", "MraForTokenClassification"),
("mt5", "MT5ForTokenClassification"),
("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("phi", "PhiForTokenClassification"),
@@ -960,6 +961,8 @@
("roc_bert", "RoCBertForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("t5", "T5ForTokenClassification"),
("umt5", "UMT5ForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
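With these mappings registered, the auto classes resolve the T5 family for token classification; a small sketch (the label count is illustrative):

```python
from transformers import AutoConfig, AutoModelForTokenClassification

# The "mt5" model type now maps to MT5ForTokenClassification in the auto registry;
# num_labels=3 is only an example value.
config = AutoConfig.from_pretrained("google/mt5-small", num_labels=3)
model = AutoModelForTokenClassification.from_config(config)
print(type(model).__name__)  # MT5ForTokenClassification
```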
2 changes: 2 additions & 0 deletions src/transformers/models/mt5/__init__.py
@@ -52,6 +52,7 @@
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5ForTokenClassification",
"MT5Model",
"MT5PreTrainedModel",
"MT5Stack",
@@ -88,6 +89,7 @@
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5ForTokenClassification,
MT5Model,
MT5PreTrainedModel,
MT5Stack,
30 changes: 10 additions & 20 deletions src/transformers/models/mt5/configuration_mt5.py
@@ -71,6 +71,7 @@ class MT5Config(PretrainedConfig):

model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}

def __init__(
self,
@@ -97,15 +98,6 @@ def __init__(
classifier_dropout=0.0,
**kwargs,
):
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
@@ -139,17 +131,15 @@ def __init__(
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"

@property
def hidden_size(self):
return self.d_model

@property
def num_attention_heads(self):
return self.num_heads

@property
def num_hidden_layers(self):
return self.num_layers
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)


class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
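The new `attribute_map` replaces the removed `hidden_size`, `num_attention_heads`, and `num_hidden_layers` properties, so the generic names still resolve to the T5-style attributes; a quick illustration:

```python
from transformers import MT5Config

config = MT5Config(d_model=512, num_heads=6, num_layers=8)
# attribute_map transparently aliases the generic names onto the T5-style ones.
print(config.hidden_size)          # 512 (alias of d_model)
print(config.num_attention_heads)  # 6   (alias of num_heads)
print(config.num_hidden_layers)    # 8   (alias of num_layers)
```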
93 changes: 92 additions & 1 deletion src/transformers/models/mt5/modeling_mt5.py
@@ -32,6 +32,7 @@
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -54,6 +55,19 @@
_CHECKPOINT_FOR_DOC = "mt5-small"


####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################
MT5_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/mt5-small",
"google/mt5-base",
"google/mt5-large",
"google/mt5-xl",
"google/mt5-xxl",
# See all mT5 models at https://huggingface.co/models?filter=mt5
]

PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is a subject to change at a moment's notice.
@@ -804,6 +818,10 @@ def _init_weights(self, module):
if hasattr(module, "qa_outputs"):
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
module.qa_outputs.bias.data.zero_()
elif isinstance(module, MT5ForTokenClassification):
if hasattr(module, "classifier"):
module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
module.classifier.bias.data.zero_()
elif isinstance(module, MT5ClassificationHead):
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.dense, "bias") and module.dense.bias is not None:
@@ -1334,7 +1352,6 @@ class MT5Model(MT5PreTrainedModel):

model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

@@ -2158,6 +2175,80 @@ def forward(
)


@add_start_docstrings(
"""
MT5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
e.g. for Named-Entity-Recognition (NER) tasks.
""",
MT5_START_DOCSTRING,
)
class MT5ForTokenClassification(MT5PreTrainedModel):
_tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]

# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->MT5
def __init__(self, config: MT5Config):
super().__init__(config)
self.num_labels = config.num_labels

self.transformer = MT5EncoderModel(config)
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->MT5
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits, outputs[2:-1])
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


@add_start_docstrings(
"""
MT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
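As a rough sketch of how the new MT5 head computes a loss during fine-tuning (the checkpoint is real, but the all-zero labels are dummy values for illustration only):

```python
import torch
from transformers import AutoTokenizer, MT5ForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = MT5ForTokenClassification.from_pretrained("google/mt5-small", num_labels=2)

inputs = tokenizer("Paris is the capital of France", return_tensors="pt")
# Dummy per-token labels; real NER data would align labels to sub-word tokens
# and mark positions to ignore with -100.
labels = torch.zeros_like(inputs["input_ids"])

outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)
```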
2 changes: 2 additions & 0 deletions src/transformers/models/t5/__init__.py
@@ -58,6 +58,7 @@
"load_tf_weights_in_t5",
"T5ForQuestionAnswering",
"T5ForSequenceClassification",
"T5ForTokenClassification",
]

try:
@@ -119,6 +120,7 @@
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5ForTokenClassification,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,
(Diffs for the remaining changed files are not shown here.)
