It looks like MobileBERT was pretrained on TPUs using bfloat16, which often results in NaNs when using FP16 for further fine-tuning (see #11076 or #10956). You'll be best off training in FP32 or using another model compatible with FP16.
Makes sense! It's interesting that this affects training on GPUs! I will pass this info on to my colleague who deals with reproducibility, and for now I shall stick with FP32 when fine-tuning the MobileBERT model.
> You'll be best off training in FP32 or using another model compatible with FP16.
And at some point we should also add a --bf16 mode to the Trainer, for those who want to do fine-tuning and inference on hardware that supports it, e.g. high-end Ampere GPUs such as the RTX 3090 and A100 should already support it, and of course TPU v2+.
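For reference, a minimal sketch of both options as TrainingArguments. Note the assumptions: the bf16 flag is not available in 4.5.1 (it only appears in later transformers releases) and needs bfloat16-capable hardware, and the output directories are placeholders.

```python
from transformers import TrainingArguments

# Workaround: plain FP32 fine-tuning (simply leave fp16 disabled, which is the default).
fp32_args = TrainingArguments(output_dir="mobilebert-ner-fp32", fp16=False)

# On later transformers releases and bfloat16-capable hardware (Ampere GPUs, TPUs),
# bfloat16 mixed precision avoids the fp16 overflow entirely.
bf16_args = TrainingArguments(output_dir="mobilebert-ner-bf16", bf16=True)
```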
Environment info
transformers version: 4.5.1
Who can help
@sgugger @stas00 @patil-suraj
Information
Model I am using: MobileBERT
The problem arises when using: the official example scripts (run_ner.py)
The task I am working on is: token classification (NER)
To reproduce
Using the example: https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py
Steps to reproduce the behavior:
1. Fine-tune MobileBERT with the run_ner.py example.
2. Enable mixed precision (--fp16).
3. The loss will return NaN.
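For concreteness, here is a minimal, self-contained sketch of the same setup (not the exact run_ner.py invocation): it assumes the google/mobilebert-uncased checkpoint and a CUDA GPU, and uses a toy token-classification dataset in place of a real NER corpus, just to drive a few training steps with fp16 enabled.

```python
# Minimal repro sketch (assumptions: google/mobilebert-uncased checkpoint, CUDA GPU,
# toy dataset standing in for the real NER corpus used with run_ner.py).
from torch.utils.data import Dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

class ToyNerDataset(Dataset):
    """One repeated sentence with dummy per-token labels, just to drive training steps."""
    def __init__(self, tokenizer, size=64):
        enc = tokenizer("hugging face is based in new york city", truncation=True)
        example = {**enc, "labels": [0] * len(enc["input_ids"])}
        self.examples = [example] * size

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]

tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = AutoModelForTokenClassification.from_pretrained("google/mobilebert-uncased", num_labels=2)

args = TrainingArguments(
    output_dir="mobilebert-fp16-repro",   # placeholder output directory
    per_device_train_batch_size=8,
    num_train_epochs=1,
    logging_steps=1,
    fp16=True,                            # mixed precision; requires a CUDA GPU
)

trainer = Trainer(model=model, args=args, train_dataset=ToyNerDataset(tokenizer))
trainer.train()   # with fp16=True the logged loss is expected to turn into nan
```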
NaNs were first observed coming out of the encoder within the forward call in the MobileBertModel class:
https://huggingface.co/transformers/_modules/transformers/modeling_mobilebert.html
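One way to pin this down (a sketch, assuming the google/mobilebert-uncased checkpoint and a CUDA device; the input sentence is a placeholder) is to run a single forward pass in full precision and again under fp16 autocast, then check each layer's hidden states for non-finite values:

```python
# Compare MobileBERT encoder activations in fp32 vs. under fp16 autocast and report
# which layers (if any) produce inf/nan values.
import torch
from transformers import AutoTokenizer, MobileBertModel

device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertModel.from_pretrained("google/mobilebert-uncased").to(device).eval()

inputs = tokenizer("a short test sentence", return_tensors="pt").to(device)

with torch.no_grad():
    fp32_out = model(**inputs, output_hidden_states=True)
    with torch.cuda.amp.autocast():
        fp16_out = model(**inputs, output_hidden_states=True)

for i, (h32, h16) in enumerate(zip(fp32_out.hidden_states, fp16_out.hidden_states)):
    print(
        f"layer {i:2d} | fp32 finite: {torch.isfinite(h32).all().item()} | "
        f"fp16 finite: {torch.isfinite(h16).all().item()} | "
        f"fp16 max |h|: {h16.float().abs().max().item():.1f}"
    )
```

FP16 can only represent values up to about 65504, so activations that were unproblematic under bfloat16's much wider exponent range can overflow to inf and then propagate as NaN through the rest of the network.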
Expected behavior
When running without FP16, the model trains as expected. Other models that I have tested did not have this issue and converged well with fp16 enabled: RoBERTa, BERT, and DistilBERT.