
Adding [T5/MT5/UMT5]ForTokenClassification #28443

Merged: 12 commits, Feb 1, 2024
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/mt5.md
@@ -101,6 +101,10 @@ See [`T5TokenizerFast`] for all details.

[[autodoc]] MT5ForSequenceClassification

## MT5ForTokenClassification

[[autodoc]] MT5ForTokenClassification

## MT5ForQuestionAnswering

[[autodoc]] MT5ForQuestionAnswering
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/t5.md
@@ -402,6 +402,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] T5ForSequenceClassification
- forward

## T5ForTokenClassification

[[autodoc]] T5ForTokenClassification
- forward

## T5ForQuestionAnswering

[[autodoc]] T5ForQuestionAnswering
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/umt5.md
@@ -100,6 +100,11 @@ Refer to [T5's documentation page](t5) for more tips, code examples and notebook
[[autodoc]] UMT5ForSequenceClassification
- forward

## UMT5ForTokenClassification

[[autodoc]] UMT5ForTokenClassification
- forward

## UMT5ForQuestionAnswering

[[autodoc]] UMT5ForQuestionAnswering
2 changes: 1 addition & 1 deletion docs/source/en/tasks/token_classification.md
@@ -32,7 +32,7 @@ The task illustrated in this tutorial is supported by the following model archit

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [BROS](../model_doc/bros), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Phi](../model_doc/phi), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [BROS](../model_doc/bros), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [Phi](../model_doc/phi), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)

<!--End of the generated tip-->

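With T5, MT5, and UMT5 added to the supported-architecture list above, the task guide's standard recipe applies to them unchanged. A minimal sketch follows, assuming an untuned "t5-small" checkpoint and an illustrative 3-label scheme (neither is from this PR), so predictions are meaningless until fine-tuning:

```python
# Minimal sketch: reach the new token-classification head through the auto class.
from transformers import AutoModelForTokenClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForTokenClassification.from_pretrained("t5-small", num_labels=3)

inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
logits = model(**inputs).logits        # (batch, sequence_length, num_labels)
predicted_ids = logits.argmax(dim=-1)  # one class id per sentencepiece token
```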
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -2731,6 +2731,7 @@
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5ForTokenClassification",
"MT5Model",
"MT5PreTrainedModel",
]
@@ -3299,6 +3300,7 @@
"T5ForConditionalGeneration",
"T5ForQuestionAnswering",
"T5ForSequenceClassification",
"T5ForTokenClassification",
"T5Model",
"T5PreTrainedModel",
"load_tf_weights_in_t5",
@@ -3370,6 +3372,7 @@
"UMT5ForConditionalGeneration",
"UMT5ForQuestionAnswering",
"UMT5ForSequenceClassification",
"UMT5ForTokenClassification",
"UMT5Model",
"UMT5PreTrainedModel",
]
@@ -7223,6 +7226,7 @@
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5ForTokenClassification,
MT5Model,
MT5PreTrainedModel,
)
@@ -7688,6 +7692,7 @@
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5ForTokenClassification,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,
@@ -7743,6 +7748,7 @@
UMT5ForConditionalGeneration,
UMT5ForQuestionAnswering,
UMT5ForSequenceClassification,
UMT5ForTokenClassification,
UMT5Model,
UMT5PreTrainedModel,
)
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -950,6 +950,7 @@
("mpnet", "MPNetForTokenClassification"),
("mpt", "MptForTokenClassification"),
("mra", "MraForTokenClassification"),
("mt5", "MT5ForTokenClassification"),
("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("phi", "PhiForTokenClassification"),
@@ -960,6 +961,8 @@
("roc_bert", "RoCBertForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("t5", "T5ForTokenClassification"),
("umt5", "UMT5ForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
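These three mapping entries are what let the `AutoModelForTokenClassification` factory resolve the new heads from a config's `model_type` alone. A quick sketch of the resolution (default config values, illustration only):

```python
# The auto class looks up model_type ("t5", "mt5", "umt5") in the mapping above.
from transformers import AutoModelForTokenClassification, MT5Config, T5Config

t5_head = AutoModelForTokenClassification.from_config(T5Config(num_labels=5))
mt5_head = AutoModelForTokenClassification.from_config(MT5Config(num_labels=5))

print(type(t5_head).__name__)   # T5ForTokenClassification
print(type(mt5_head).__name__)  # MT5ForTokenClassification
```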
2 changes: 2 additions & 0 deletions src/transformers/models/mt5/__init__.py
@@ -52,6 +52,7 @@
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5ForTokenClassification",
"MT5Model",
"MT5PreTrainedModel",
"MT5Stack",
@@ -88,6 +89,7 @@
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5ForTokenClassification,
MT5Model,
MT5PreTrainedModel,
MT5Stack,
30 changes: 10 additions & 20 deletions src/transformers/models/mt5/configuration_mt5.py
@@ -71,6 +71,7 @@ class MT5Config(PretrainedConfig):

model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}

def __init__(
self,
@@ -97,15 +98,6 @@ def __init__(
classifier_dropout=0.0,
**kwargs,
):
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
@@ -139,17 +131,15 @@ def __init__(
if feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"

@property
def hidden_size(self):
return self.d_model

@property
def num_attention_heads(self):
return self.num_heads

@property
def num_hidden_layers(self):
return self.num_layers
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)


class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
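Two things change in this file: the three read-only properties are replaced by `PretrainedConfig`'s `attribute_map`, and the `super().__init__(...)` call moves to the end of `__init__`, matching `T5Config`, which also calls the base constructor last. Unlike the old properties, the map aliases writes as well as reads; a small sketch of the behavior:

```python
# Sketch of the attribute_map aliasing that replaces the removed properties.
from transformers import MT5Config

config = MT5Config()
assert config.hidden_size == config.d_model            # read alias
assert config.num_attention_heads == config.num_heads
assert config.num_hidden_layers == config.num_layers

config.hidden_size = 1024      # with the old read-only @property this raised
assert config.d_model == 1024  # AttributeError; the map redirects the write
```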
93 changes: 92 additions & 1 deletion src/transformers/models/mt5/modeling_mt5.py
@@ -32,6 +32,7 @@
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
@@ -54,6 +55,19 @@
_CHECKPOINT_FOR_DOC = "mt5-small"


####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################
MT5_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/mt5-small",
"google/mt5-base",
"google/mt5-large",
"google/mt5-xl",
"google/mt5-xxl",
# See all mT5 models at https://huggingface.co/models?filter=mt5
]

PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is subject to change at a moment's notice.

@@ -804,6 +818,10 @@ def _init_weights(self, module):
if hasattr(module, "qa_outputs"):
module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
module.qa_outputs.bias.data.zero_()
elif isinstance(module, MT5ForTokenClassification):
if hasattr(module, "classifier"):
module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
module.classifier.bias.data.zero_()
elif isinstance(module, MT5ClassificationHead):
module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.dense, "bias") and module.dense.bias is not None:
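Note that the new `MT5ForTokenClassification` branch gives the classifier head a wider init than `qa_outputs` (no `d_model ** -0.5` scaling). In effect, with the default `initializer_factor` of 1.0:

```python
# Illustrative values only (MT5-small-like d_model, hypothetical 3 labels):
# classifier weights ~ N(0, factor * 1.0), bias zeroed.
import torch

d_model, num_labels, factor = 512, 3, 1.0
weight = torch.empty(num_labels, d_model).normal_(mean=0.0, std=factor * 1.0)
bias = torch.zeros(num_labels)
```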
@@ -1334,7 +1352,6 @@ class MT5Model(MT5PreTrainedModel):

model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
Review comment (Collaborator):

Bit surprised by this change, could we leave it as is? (for BC)

Reply (@hackyon, Contributor Author, Jan 30, 2024):

I removed this because of a test failure in test_model_weights_reload_no_missing_tied_weights():

self.assertEqual(

The test fails because "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight" is a key that will never exist in the model (for the current version of the code); the test expects the missing key to be in the state of the reloaded model.

I believe we should be OK with backwards compatibility because the current version of the code is hard-coded not to have that key:

self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False)

WDYT? The other option would be to find a way to disable that test for the MT5 model.

_keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

@@ -2158,6 +2175,80 @@ def forward(
)


@add_start_docstrings(
"""
MT5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
e.g. for Named-Entity-Recognition (NER) tasks.
""",
MT5_START_DOCSTRING,
)
class MT5ForTokenClassification(MT5PreTrainedModel):
_tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]

# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->MT5
def __init__(self, config: MT5Config):
super().__init__(config)
self.num_labels = config.num_labels

self.transformer = MT5EncoderModel(config)
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->MT5
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits, outputs[2:-1])
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


@add_start_docstrings(
"""
MT5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
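Since `MT5ForTokenClassification` wraps `MT5EncoderModel`, only the encoder runs: there is no decoder pass and no `decoder_input_ids`. An end-to-end sketch with random labels and a freshly initialized classifier, so the loss is only a smoke test:

```python
import torch
from transformers import AutoTokenizer, MT5ForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = MT5ForTokenClassification.from_pretrained("google/mt5-small", num_labels=2)

inputs = tokenizer("Studies have shown that owning a dog is good for you",
                   return_tensors="pt")
labels = torch.randint(0, 2, inputs.input_ids.shape)  # one label per token

outputs = model(**inputs, labels=labels)
print(outputs.loss)          # CrossEntropyLoss over all token positions
print(outputs.logits.shape)  # (1, sequence_length, num_labels)
```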
2 changes: 2 additions & 0 deletions src/transformers/models/t5/__init__.py
@@ -58,6 +58,7 @@
"load_tf_weights_in_t5",
"T5ForQuestionAnswering",
"T5ForSequenceClassification",
"T5ForTokenClassification",
]

try:
@@ -119,6 +120,7 @@
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5ForTokenClassification,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,
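Once the lazy-import entries above are in place (the UMT5 changes are analogous but truncated from this view), all three heads resolve from the package root. A quick smoke check, assuming a transformers build that contains this PR:

```python
from transformers import (
    MT5ForTokenClassification,
    T5ForTokenClassification,
    UMT5ForTokenClassification,
)

for cls in (T5ForTokenClassification, MT5ForTokenClassification, UMT5ForTokenClassification):
    print(cls.__name__, "->", cls.__mro__[1].__name__)  # each extends its PreTrainedModel
```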