Commit 6cbaa69

separating mt5 exps with and without custom tokenizer; adding fp16 ca…
bri25yu committed Oct 2, 2022
1 parent 10b7a7f commit 6cbaa69
Showing 3 changed files with 298 additions and 33 deletions.
116 changes: 83 additions & 33 deletions attention_driven/experiments/finetune_mt5.py
@@ -11,25 +11,18 @@
from attention_driven.experiments.baseline_v2 import BaselineV2Experiment
from attention_driven.data_processors import LDTibetanEnglishDataV2Processor
from attention_driven.data_processors.utils import convert_df_to_hf_dataset
from attention_driven.modeling.mt5_fp16_utils import scale_weights_for_fp16_t5
from attention_driven.modeling.mt5_fp16 import MT5Fp16ForConditionalGeneration


__all__ = ["FinetuneMT5BaseExperiment", "FinetuneMT5LargeExperiment", "FinetuneMT5XLExperiment"]


# We use a special fp16-capable version of MT5
class FinetuneMT5ExperimentBase(BaselineV2Experiment):
MODEL_NAME = None

def get_tokenizer(self) -> PreTrainedTokenizer:
"""
We don't train the tokenizer on Tibetan corpora at the moment, but this is probably something we want to do.
https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#t5-like-span-masked-language-modeling
"""

"""
tokenizer = AutoTokenizer.from_pretrained("buddhist-nlp/mt5-tibetan-tokenizer")
"""

model_name = self.MODEL_NAME

tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -40,34 +33,15 @@ def get_model(self, tokenizer: PreTrainedTokenizer) -> PreTrainedModel:
model_name = self.MODEL_NAME
max_input_length = self.MAX_INPUT_LENGTH

"""
# Load pretrained parameter weights
base_model_parameter_dict = AutoModelForSeq2SeqLM.from_pretrained(model_name).state_dict()
base_model_parameter_dict = OrderedDict(base_model_parameter_dict) # Make `base_model_parameter_dict` modifiable
keys_to_modify = ["shared.weight", "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
pretrained_embedding_weights = {k: base_model_parameter_dict.pop(k) for k in keys_to_modify}
# Create new model
config = AutoConfig.from_pretrained(model_name, vocab_size=tokenizer.vocab_size + 2)
model = MT5ForConditionalGeneration(config)
# We don't have access to bf16-capable Ampere+ GPUs, so we need to work around that
model = MT5Fp16ForConditionalGeneration.from_pretrained(model_name)
scale_weights_for_fp16_t5(model)

# Load pretrained weights into new model with a slight change to embeddings
# since we have a larger vocab size
model.load_state_dict(base_model_parameter_dict, strict=False)
model_parameter_dict = model.state_dict()
with torch.no_grad():
for weight_name, pretrained_embedding_weight in pretrained_embedding_weights.items():
pretrained_vocab_size, hidden_dim = pretrained_embedding_weight.size()
model_parameter_dict[weight_name][:pretrained_vocab_size, :hidden_dim].copy_(pretrained_embedding_weight)
"""

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.config.max_length = max_input_length

return model

# This is the exact same function as `BaselineV2Experiment.load_data` unless noted otherwise
def load_data(self, tokenizer: PreTrainedTokenizer) -> DatasetDict:
"""
This function assumes that https://github.com/Linguae-Dharmae/language-models
@@ -90,12 +64,37 @@ def load_data(self, tokenizer: PreTrainedTokenizer) -> DatasetDict:
print("Human readable dataset:", dataset)

def tokenize_fn(examples):

###########################
# START add mt5 prefix
###########################

# Original code
# model_inputs = tokenizer(examples["tibetan"], max_length=max_input_length, truncation=True)

prefix = "translate to english: "
tibetan_inputs = [prefix + t for t in examples["tibetan"]]
model_inputs = tokenizer(tibetan_inputs, max_length=max_input_length, truncation=True)

###########################
# END add mt5 prefix
###########################

###########################
# START use text_target rather than tokenizer target context
###########################

# Original code
# Set up the tokenizer for targets
# with tokenizer.as_target_tokenizer():
# labels = tokenizer(examples["english"], max_length=max_input_length, truncation=True)

labels = tokenizer(text_target=examples["english"], max_length=max_input_length, truncation=True)

###########################
# END use text_target rather than tokenizer target context
###########################

model_inputs["labels"] = labels["input_ids"]
return model_inputs

@@ -105,6 +104,45 @@ def tokenize_fn(examples):
return tokenized_dataset


# We add a custom Tibetan tokenizer in v2
class FinetuneMT5V2ExperimentBase(FinetuneMT5ExperimentBase):
def get_tokenizer(self) -> PreTrainedTokenizer:
tokenizer = AutoTokenizer.from_pretrained("buddhist-nlp/mt5-tibetan-tokenizer")

return tokenizer

def get_model(self, tokenizer: PreTrainedTokenizer) -> PreTrainedModel:
model_name = self.MODEL_NAME
max_input_length = self.MAX_INPUT_LENGTH

# Load pretrained parameter weights
base_model_parameter_dict = AutoModelForSeq2SeqLM.from_pretrained(model_name).state_dict()
base_model_parameter_dict = OrderedDict(base_model_parameter_dict) # Make `base_model_parameter_dict` modifiable

keys_to_modify = ["shared.weight", "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
pretrained_embedding_weights = {k: base_model_parameter_dict.pop(k) for k in keys_to_modify}

# Create new model
config = AutoConfig.from_pretrained(model_name, vocab_size=tokenizer.vocab_size + 2)
model = MT5Fp16ForConditionalGeneration(config)

# Load pretrained weights into new model with a slight change to embeddings
# since we have a larger vocab size
model.load_state_dict(base_model_parameter_dict, strict=False)
model_parameter_dict = model.state_dict()
with torch.no_grad():
for weight_name, pretrained_embedding_weight in pretrained_embedding_weights.items():
pretrained_vocab_size, hidden_dim = pretrained_embedding_weight.size()
model_parameter_dict[weight_name][:pretrained_vocab_size, :hidden_dim].copy_(pretrained_embedding_weight)

# We don't have access to bf16-capable Ampere+ GPUs, so we need to work around that
scale_weights_for_fp16_t5(model)

model.config.max_length = max_input_length

return model


class FinetuneMT5BaseExperiment(FinetuneMT5ExperimentBase):
MODEL_NAME = "google/mt5-base"

@@ -115,3 +153,15 @@ class FinetuneMT5LargeExperiment(FinetuneMT5ExperimentBase):

class FinetuneMT5XLExperiment(FinetuneMT5ExperimentBase):
MODEL_NAME = "google/mt5-xl"


class FinetuneMT5BaseV2Experiment(FinetuneMT5V2ExperimentBase):
MODEL_NAME = "google/mt5-base"


class FinetuneMT5LargeV2Experiment(FinetuneMT5V2ExperimentBase):
MODEL_NAME = "google/mt5-large"


class FinetuneMT5XLV2Experiment(FinetuneMT5V2ExperimentBase):
MODEL_NAME = "google/mt5-xl"
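
For reference, a minimal self-contained sketch (not part of the commit) of the tokenize_fn change above: prefix the Tibetan source text and tokenize the English targets via `text_target` instead of the older `as_target_tokenizer` context manager. "google/mt5-base" and the max length of 256 are stand-ins for MODEL_NAME and MAX_INPUT_LENGTH, which this page does not show.

from transformers import AutoTokenizer

# Stand-in checkpoint; the experiments above use self.MODEL_NAME instead.
tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")

examples = {
    "tibetan": ["<tibetan sentence>"],
    "english": ["<english translation>"],
}

prefix = "translate to english: "
model_inputs = tokenizer(
    [prefix + t for t in examples["tibetan"]],
    max_length=256,  # stand-in for MAX_INPUT_LENGTH
    truncation=True,
)
labels = tokenizer(
    text_target=examples["english"],  # replaces the as_target_tokenizer context
    max_length=256,
    truncation=True,
)
model_inputs["labels"] = labels["input_ids"]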
160 changes: 160 additions & 0 deletions attention_driven/modeling/mt5_fp16.py
@@ -0,0 +1,160 @@
"""
Adapted from https://github.com/huggingface/transformers/issues/14189#issuecomment-961571628
"""
from typing import Tuple, Optional, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import MT5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput


__all__ = ["MT5Fp16ForConditionalGeneration"]


class MT5Fp16ForConditionalGeneration(MT5ForConditionalGeneration):
def __init__(self, config):
super().__init__(config)

self.lm_scale_modifier = nn.Parameter(torch.ones(config.d_model))

# This is an exact copy of `T5ForConditionalGeneration.forward` unless specified otherwise
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

###############################
# START remove head_mask warning
###############################

# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
# if head_mask is not None and decoder_head_mask is None:
# if self.config.num_layers == self.config.num_decoder_layers:
# warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
# decoder_head_mask = head_mask

###############################
# END remove head_mask warning
###############################

# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)

hidden_states = encoder_outputs[0]

if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)

if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)

# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = decoder_outputs[0]

# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)

if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)

###############################
# START add lm_scale_modifier
###############################

sequence_output = sequence_output * self.lm_scale_modifier

###############################
# END add lm_scale_modifier
###############################

lm_logits = self.lm_head(sequence_output)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output

return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
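
As a rough sanity check (not from the commit) of why lm_scale_modifier exists: with tied word embeddings, scaling the shared embedding matrix down for fp16 also scales the lm_head projection down, and multiplying the decoder output by the inverse factor restores the logits. A toy sketch with made-up sizes, mirroring the 1/32 embedding scaling used in mt5_fp16_utils.py below:

import torch
from torch import nn

# Toy sketch: tie lm_head to a shared embedding, scale the shared weight by
# 1/32, and compensate on the output side with a 32x lm_scale_modifier.
d_model, vocab_size = 8, 16
shared = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)
lm_head.weight = shared.weight  # tied word embeddings, as in mT5

hidden = torch.randn(1, 3, d_model)  # pretend decoder output
logits_ref = lm_head(hidden)

with torch.no_grad():
    shared.weight *= 1 / 32.0  # the "embeddings" scaling
lm_scale_modifier = torch.full((d_model,), 32.0)  # 1 / embeddings scaling

logits_scaled = lm_head(hidden * lm_scale_modifier)
print(torch.allclose(logits_ref, logits_scaled))  # True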
55 changes: 55 additions & 0 deletions attention_driven/modeling/mt5_fp16_utils.py
@@ -0,0 +1,55 @@
"""
Adapted from https://github.com/huggingface/transformers/issues/14189#issuecomment-961571628
"""
from typing import Union

import torch

from transformers import T5PreTrainedModel


__all__ = ["scale_weights_for_fp16_t5"]


# Weights belonging to the same kind of submodule share the same scaling factor
# The embedding scaling (1/32) is the largest total scaling; the per-submodule
# factors along each path multiply out to it (see the asserts below)
ARCH_SCALING = {
"embeddings": 1 / 32.0,
"attention_value": 1 / 4.0,
"attention_output": 1 / 8.0,
"feedforward_weights_in": 1 / 4.0,
"feedforward_weights_out": 1 / 4.0,
"feedforward_layernorm": 1 / 2.0,
}

assert ARCH_SCALING["attention_value"] * ARCH_SCALING["attention_output"] == ARCH_SCALING["embeddings"]
assert ARCH_SCALING["feedforward_weights_in"] * ARCH_SCALING["feedforward_weights_out"] * ARCH_SCALING["feedforward_layernorm"] == ARCH_SCALING["embeddings"]


WEIGHT_SCALING = {
"shared.weight": ARCH_SCALING["embeddings"],
"SelfAttention.v": ARCH_SCALING["attention_value"],
"SelfAttention.o": ARCH_SCALING["attention_output"],
"EncDecAttention.v": ARCH_SCALING["attention_value"],
"EncDecAttention.o": ARCH_SCALING["attention_output"],
"lm_scale_modifier": 1 / ARCH_SCALING["embeddings"]
}


def scale_weights_for_fp16_t5(model: T5PreTrainedModel) -> None:
assert hasattr(model, "lm_scale_modifier"), "This function is only to be used with a modified mt5 model"

def search_for_scaling(weight_name: str) -> Union[None, float]:
for weight_name_infix in WEIGHT_SCALING:
if weight_name_infix in weight_name:
return WEIGHT_SCALING[weight_name_infix]

return None

with torch.no_grad():
# `state_dict()` returns tensors that share storage with the model parameters,
# so the in-place scaling below updates the model directly
for weight_name, weight in model.state_dict().items():
scaling = search_for_scaling(weight_name)
if scaling is None:
continue

weight *= scaling
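
Putting the two new modules together, a minimal hedged usage sketch (simpler than FinetuneMT5V2ExperimentBase.get_model above, which also extends the embedding matrix for the custom tokenizer); "google/mt5-base" stands in for MODEL_NAME:

from attention_driven.modeling.mt5_fp16 import MT5Fp16ForConditionalGeneration
from attention_driven.modeling.mt5_fp16_utils import scale_weights_for_fp16_t5

# Load a pretrained mT5 checkpoint into the fp16-friendly class (the newly added
# lm_scale_modifier stays at its initialization of ones), then rescale the
# weights in place so that fp16 training does not overflow.
model = MT5Fp16ForConditionalGeneration.from_pretrained("google/mt5-base")
scale_weights_for_fp16_t5(model)  # modifies the model in place; returns None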
