
Fine-Tuning the Multilingual Text-To-Text Transfer Transformer (mT5) on XNLI for Language Classification

  • This repository demonstrates how to fine-tune the pretrained multilingual text-to-text mT5 model. mT5 is pretrained on 101 languages.
  • We use the Cross-lingual Natural Language Inference (XNLI) dataset for fine-tuning, with the sentence1 and sentence2 columns as input and the language column as the target.
  • The XNLI dataset can be downloaded here.
  • We use xnli.test.tsv for cross-validation; once the hyper-parameters are selected, we train the model on the full xnli.test.tsv and evaluate on xnli.dev.tsv. xnli.dev.tsv is not touched until all hyper-parameters have been chosen.
  • Reference: I followed this official t5 notebook for fine-tuning the model. A sketch of the resulting task setup follows this list.
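To make the setup concrete, here is a minimal sketch of how the XNLI TSV could be registered as a text-to-text task with the t5 library. The bucket paths, the "classify language:" prompt, and the column indices are assumptions; check them against your copy of the data.

```python
import tensorflow.compat.v1 as tf
import t5

# Hypothetical bucket paths; replace with your own copies of the XNLI TSVs.
XNLI_TSV = {"train": "gs://your-bucket/xnli.test.tsv",
            "validation": "gs://your-bucket/xnli.dev.tsv"}

def xnli_dataset_fn(split, shuffle_files=False):
    # language, sentence1, sentence2 are columns 0, 6, 7 in the XNLI TSV
    # header; verify the indices against your file before relying on them.
    ds = tf.data.experimental.CsvDataset(
        XNLI_TSV[split],
        record_defaults=["", "", ""],
        field_delim="\t",
        header=True,
        select_cols=[0, 6, 7])
    # Join the two sentences into one input; the target is the language code.
    return ds.map(
        lambda lang, s1, s2: {
            "inputs": tf.strings.join(["classify language: ", s1, " ", s2]),
            "targets": lang},
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

t5.data.TaskRegistry.add(
    "xnli_lang_classification",
    dataset_fn=xnli_dataset_fn,
    splits=["train", "validation"],
    text_preprocessor=[],
    metric_fns=[t5.evaluation.metrics.accuracy])
```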

Experimenting with different hyper-parameters

  1. First, we tried different batch sizes (32, 128, 256) and found that batch sizes of 128 and 256 perform the same as long as the number of training steps is the same.
  2. Second, we tried different epoch counts (10, 8, 6, 3, 2) and found that the right number of epochs depends on the batch size and learning rate: a small batch (e.g. 128) needs fewer epochs, a larger batch needs more, and a higher learning rate also allows fewer epochs.
  3. Third, we tried different learning rates (0.0001, 0.001, 0.003) and found that with a higher learning rate the model trains faster, so fewer epochs are needed. A sweep along these lines is sketched after this list.
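A sweep of this kind might look like the following sketch, which holds the number of fine-tuning steps fixed so runs with different batch sizes stay comparable. TPU_ADDRESS, the bucket paths, the pretrained-model directory, and the step budget are placeholders.

```python
import t5

TPU_ADDRESS = "grpc://10.0.0.1:8470"  # placeholder; resolved from the Colab runtime
PRETRAINED_DIR = "gs://t5-data/pretrained_models/mt5/base"  # assumed mT5 checkpoint

def make_model(model_dir, batch_size, learning_rate):
    # Wrap the t5 MtfModel constructor so each run varies only what we sweep.
    return t5.models.MtfModel(
        model_dir=model_dir,
        tpu=TPU_ADDRESS,
        tpu_topology="v2-8",  # Colab TPU
        model_parallelism=1,
        batch_size=batch_size,
        sequence_length={"inputs": 128, "targets": 8},
        learning_rate_schedule=learning_rate,
        save_checkpoints_steps=500,
        iterations_per_loop=100)

for batch_size in (32, 128, 256):
    for lr in (0.0001, 0.001, 0.003):
        model = make_model(
            f"gs://your-bucket/models/bs{batch_size}_lr{lr}", batch_size, lr)
        # A fixed step budget keeps the comparison fair across batch sizes.
        model.finetune(
            mixture_or_task_name="xnli_lang_classification",
            pretrained_model_dir=PRETRAINED_DIR,
            finetune_steps=2000)
        model.eval("xnli_lang_classification", checkpoint_steps=-1)
```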

After several iterations, we finally selected the following hyper-parameters:

  1. Batch Size = 128
  2. EPOCH = 2
  3. LR = 0.003
  • We trained the model on the full dataset using these hyper-parameters; the final run is sketched below, followed by the metrics on the evaluation set.
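The final run can reuse the make_model helper from the sweep sketch above. Note that MtfModel counts steps rather than epochs, so the epoch count has to be converted; the row count below is an assumption to verify against your copy of xnli.test.tsv.

```python
NUM_ROWS = 75_150   # assumed size of xnli.test.tsv (5,010 pairs x 15 languages)
BATCH_SIZE = 128
EPOCHS = 2
FINETUNE_STEPS = EPOCHS * NUM_ROWS // BATCH_SIZE  # ~1,174 steps

model = make_model("gs://your-bucket/models/final", BATCH_SIZE, learning_rate=0.003)
model.finetune(
    mixture_or_task_name="xnli_lang_classification",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS)
model.eval("xnli_lang_classification", checkpoint_steps=-1)
```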

[Screenshot: evaluation metrics]

Prediction on unseen data from the web, with 100% accuracy.

[Screenshot: example predictions]
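For reference, prediction with the t5 library goes through text files: write one prompt per line, then call model.predict. The file paths and sentences below are placeholders.

```python
import tensorflow.compat.v1 as tf

sentences = ["Wie geht es dir heute?", "La vie est belle."]  # example inputs

# predict() reads prompts from a file and writes one prediction per line.
with tf.io.gfile.GFile("gs://your-bucket/predict_inputs.txt", "w") as f:
    for s in sentences:
        f.write("classify language: " + s + "\n")

model.batch_size = 8  # a smaller batch is fine for inference
model.predict(
    input_file="gs://your-bucket/predict_inputs.txt",
    output_file="gs://your-bucket/predict_outputs.txt",
    checkpoint_steps=-1,  # use the latest checkpoint
    temperature=0)        # greedy decoding for deterministic labels
```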

Observations

  1. We observed that the model confuses Hindi and Urdu. The wrong predictions are mostly on input sentences written in the Latin alphabet rather than the Hindi or Urdu scripts.
  2. The model also confuses Bulgarian and Russian, possibly because the two languages belong to the same language family.
  3. A few sentences are incorrectly labeled as English.
  4. Some input sentences are very short, which leads to wrong predictions. Since the training set is reasonably large, such short sentences could simply be filtered out, as sketched after this list.
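One simple mitigation, sketched below under the assumption that the task is built from the xnli_dataset_fn above, is to drop inputs under a minimum word count before training; the threshold is an arbitrary choice.

```python
import tensorflow.compat.v1 as tf

MIN_WORDS = 4  # assumed threshold; tune on the validation split

def long_enough(example):
    # Count whitespace-separated tokens in the input text.
    return tf.size(tf.strings.split([example["inputs"]]).values) >= MIN_WORDS

filtered_train = xnli_dataset_fn("train").filter(long_enough)
```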

How to run this notebook

  • It is highly recommended to run this notebook in a Colab TPU runtime. The Colab notebook is shared here; please run it directly.
  • To change the runtime: Runtime --> Change runtime type --> TPU. If you are only fine-tuning the model, a GPU is also fine.
  • All the data and fine-tuned checkpoints are already available in my GCS bucket.
  • To get predictions on unseen data with the best model, just run all the cells as they are. You can also input your own sentences.
  • The notebook can also be run on a local machine after installing the dependencies with pip install -r requirements.txt. However, all data, checkpoints, and outputs must still be stored in a GCS bucket. The typical TPU/GCS setup is sketched after this list.
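On Colab, the TPU address and GCS credentials are typically set up as in the official t5 notebooks; a minimal version of that boilerplate looks like this sketch.

```python
import tensorflow.compat.v1 as tf
from google.colab import auth

# Detect the TPU attached to the Colab runtime (fails under a CPU/GPU runtime).
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
TPU_ADDRESS = tpu.get_master()
print("Running on TPU:", TPU_ADDRESS)

# Authenticate so the notebook can read and write the GCS bucket.
auth.authenticate_user()
tf.config.experimental_connect_to_host(TPU_ADDRESS)
```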

P.S.:

  • The GCS bucket is publicly available only for a limited time, after which the links will expire. Please replace them with your own GCS links and upload the training data there.
  • Known issue: for some reason the t5 library fails to update the checkpoint's *.gin config file, which makes the prediction function throw an error. I manually updated the *.gin config file of the best model to get predictions; a possible workaround is sketched below.
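A possible workaround is to patch the stale paths in the checkpoint's gin config by hand before calling predict. The file name and the strings being replaced below are assumptions; adjust them to whatever your config actually contains.

```python
import tensorflow.compat.v1 as tf

GIN_PATH = "gs://your-bucket/models/final/operative_config.gin"  # placeholder

with tf.io.gfile.GFile(GIN_PATH) as f:
    config = f.read()

# Point any stale bucket references at the current model directory.
config = config.replace("gs://old-bucket/", "gs://your-bucket/")

with tf.io.gfile.GFile(GIN_PATH, "w") as f:
    f.write(config)
```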
