separating mt5 exps with and without custom tokenizer; adding fp16 capable mt5 from huggingface/transformers#14189 (comment)
Showing 3 changed files with 298 additions and 33 deletions.
@@ -0,0 +1,160 @@
"""
Adapted from https://github.com/huggingface/transformers/issues/14189#issuecomment-961571628
"""
from typing import Tuple, Optional, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import MT5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput


__all__ = ["MT5Fp16ForConditionalGeneration"]


class MT5Fp16ForConditionalGeneration(MT5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)

        self.lm_scale_modifier = nn.Parameter(torch.ones(config.d_model))

    # This is an exact copy of `T5ForConditionalGeneration.forward` unless specified otherwise
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        ###############################
        # START remove head_mask warning
        ###############################

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        # if head_mask is not None and decoder_head_mask is None:
        #     if self.config.num_layers == self.config.num_decoder_layers:
        #         warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
        #         decoder_head_mask = head_mask

        ###############################
        # END remove head_mask warning
        ###############################

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        ###############################
        # START add lm_scale_modifier
        ###############################

        sequence_output = sequence_output * self.lm_scale_modifier

        ###############################
        # END add lm_scale_modifier
        ###############################

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
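
For reference, a minimal loading sketch for the class above (illustrative, not part of this diff; the module name `mt5_fp16` and the checkpoint `google/mt5-small` are assumptions):

# Illustrative usage only, assuming the file above is importable as `mt5_fp16`.
# Stock mT5 checkpoints carry no `lm_scale_modifier`, so `from_pretrained` should report
# it as a missing key and it should keep its all-ones initialization, which leaves the
# logits unchanged until the weights are rescaled.
from transformers import AutoTokenizer

from mt5_fp16 import MT5Fp16ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = MT5Fp16ForConditionalGeneration.from_pretrained("google/mt5-small")

batch = tokenizer("Hello world", return_tensors="pt")
outputs = model.generate(**batch, max_length=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))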
@@ -0,0 +1,55 @@
"""
Adapted from https://github.com/huggingface/transformers/issues/14189#issuecomment-961571628
"""
from typing import Union

import torch

from transformers import T5PreTrainedModel


__all__ = ["scale_weights_for_fp16_t5"]


# The same architectures in the model share the same weight scaling
# The embeddings scaling is the maximum scaling performed
ARCH_SCALING = {
    "embeddings": 1 / 32.0,
    "attention_value": 1 / 4.0,
    "attention_output": 1 / 8.0,
    "feedforward_weights_in": 1 / 4.0,
    "feedforward_weights_out": 1 / 4.0,
    "feedforward_layernorm": 1 / 2.0,
}

assert ARCH_SCALING["attention_value"] * ARCH_SCALING["attention_output"] == ARCH_SCALING["embeddings"]
assert ARCH_SCALING["feedforward_weights_in"] * ARCH_SCALING["feedforward_weights_out"] * ARCH_SCALING["feedforward_layernorm"] == ARCH_SCALING["embeddings"]


WEIGHT_SCALING = {
    "shared.weight": ARCH_SCALING["embeddings"],
    "SelfAttention.v": ARCH_SCALING["attention_value"],
    "SelfAttention.o": ARCH_SCALING["attention_output"],
    "EncDecAttention.v": ARCH_SCALING["attention_value"],
    "EncDecAttention.o": ARCH_SCALING["attention_output"],
    "lm_scale_modifier": 1 / ARCH_SCALING["embeddings"],
}


def scale_weights_for_fp16_t5(model: T5PreTrainedModel) -> None:
    assert hasattr(model, "lm_scale_modifier"), "This function is only to be used with a modified mt5 model"

    def search_for_scaling(weight_name: str) -> Union[None, float]:
        for weight_name_infix in WEIGHT_SCALING:
            if weight_name_infix in weight_name:
                return WEIGHT_SCALING[weight_name_infix]

        return None

    with torch.no_grad():
        # Iterating a state dict directly yields only the key names; use .items() to get
        # (name, tensor) pairs, then scale the matching tensors in place.
        for weight_name, weight in model.state_dict().items():
            scaling = search_for_scaling(weight_name)
            if scaling is None:
                continue

            weight *= scaling
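
Since `scale_weights_for_fp16_t5` asserts that the model carries `lm_scale_modifier`, the two files are meant to be used together. A hedged end-to-end sketch follows; the module names, the checkpoint, and the plain `.half()` cast are assumptions, and the intent described in huggingface/transformers#14189 is that the rescaling keeps intermediate activations inside fp16 range rather than overflowing.

# Illustrative end-to-end usage (assumed module names `mt5_fp16` and `scale_weights_fp16`):
# rescale the weights once, then run or fine-tune the model in fp16.
from transformers import AutoTokenizer

from mt5_fp16 import MT5Fp16ForConditionalGeneration
from scale_weights_fp16 import scale_weights_for_fp16_t5

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = MT5Fp16ForConditionalGeneration.from_pretrained("google/mt5-small")

scale_weights_for_fp16_t5(model)   # apply once, before casting or training
model = model.half().cuda()        # or keep fp32 master weights and use torch.cuda.amp

batch = tokenizer("Hello world", return_tensors="pt").to("cuda")
labels = tokenizer("Hallo Welt", return_tensors="pt").input_ids.to("cuda")
loss = model(**batch, labels=labels).loss
print(loss)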