T5-v1.1 loss goes to nan when fp16 training is enabled #14189
Comments
Linked PR: #10956
As suggested by Lysandre, @Liangtaiwan please check if this PR helps: #10956
@stas00 @patrickvonplaten @LysandreJik
I am working with @HaokunLiu on a project that uses T5, and he found a great solution to this problem. The idea is to scale down the weights of the model in a specific pattern that maintains the relationship between the weights. I am not sure if this transformation is loss-preserving, but here's his script: in …, you need to add …; then, in the … function, you need the following lines: …
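A minimal sketch of the general idea (this is not @HaokunLiu's actual script, which is not reproduced above; the model name, the scaling factor, and the choice of which weights to rescale are my assumptions): scale every weight that writes into T5's residual stream by one constant. Because T5's layer norm is an RMS norm with no bias, each sublayer and the final layer norm before lm_head still see the same normalized input, so with t5-v1.1's untied lm_head the loss should be approximately preserved while the hidden-state magnitudes shrink into fp16 range.

```python
# Hedged sketch, not the script referenced above: scale every weight that
# writes into T5's residual stream by one constant. T5's layer norm is an
# RMS norm (no mean subtraction, no bias), so each sublayer and the final
# layer norm before lm_head still see the same normalized input, and with
# t5-v1.1's untied lm_head the loss is approximately preserved.
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")
factor = 0.1  # illustrative; pick something that keeps hidden states in fp16 range

with torch.no_grad():
    # token embeddings feed the residual stream of both encoder and decoder
    model.shared.weight *= factor
    for stack in (model.encoder, model.decoder):
        for block in stack.block:
            for layer in block.layer:
                # only the output projections of attention and the FFN add
                # into the residual stream; everything else reads the
                # layer-normed (scale-invariant) input
                if hasattr(layer, "SelfAttention"):
                    layer.SelfAttention.o.weight *= factor
                if hasattr(layer, "EncDecAttention"):
                    layer.EncDecAttention.o.weight *= factor
                if hasattr(layer, "DenseReluDense"):
                    layer.DenseReluDense.wo.weight *= factor
```

Note that this relies on t5-v1.1 not tying lm_head to the shared embedding; for the original T5 checkpoints, which do tie them, scaling the embedding would also scale the logits, so an extra correction step would be needed.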
Interesting, it seems we have similar ideas! My approach is slightly different, but it seems to work as well. Where yours scales down all the weights, mine aims to change the weights as little as possible. The weights to change are found using a search pattern (going through the encoder layers, then the decoder layers), scaling down weights until the model is able to infer and train without NaN. I have found that changing the weights of the FFN in the last few encoder layers (about 3%-5% of the total model weights) is sufficient, and we can just scale them down by a factor of 2. At least on the model's existing pre-trained tasks, it still seems to be more or less working, so I'm taking that as a good sign. I have also fine-tuned on my own task without NaN so far (tested t5-large and t5-3B). Example: https://github.com/tlkh/t5-fp16-surgery/blob/main/t5-3B.ipynb GitHub repo: https://github.com/tlkh/t5-fp16-surgery
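A rough sketch of that kind of targeted surgery (the linked notebook and repo contain the actual procedure; the number of encoder layers touched and the choice of scaling only the wo projection are assumptions on my part):

```python
# Rough sketch of that kind of surgery (see the linked repo/notebook for the
# real procedure): halve the residual write of the FFN in the last few
# encoder layers. The layer range and the choice of wo are assumptions.
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-large")

with torch.no_grad():
    for block in model.encoder.block[-3:]:    # last 3 encoder layers (guess)
        ff = block.layer[-1].DenseReluDense   # the block's feed-forward module
        ff.wo.weight /= 2.0                   # scale down its output projection
```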
It is loss preserving. The last line …
@ibeltagy Thank you so much for sharing this!
Environment info
I tested in two different environments: my native environment and the NVIDIA container pytorch_21.09.
For more details, please refer to https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel_21-09.html#rel_21-09
transformers version: 4.11.3
Who can help
@patrickvonplaten, @patil-suraj
Information
Model I am using: t5-v1.1 (small, base). With mixed precision, the loss goes to nan.
The problem arises when using:
The task I am working on is:
The bug can be reproduced with run_summarization.py & run_summarization_no_trainer.py.
To reproduce
Steps to reproduce the behavior:
1. ❯ …
Both of the scripts above can reproduce the result: the loss goes to nan. (For Trainer, I print the loss before trainer.training_step returns.) A minimal sketch of such a run is shown below.
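A minimal fp16 training-loop sketch that watches the loss for NaN (the model name, dummy batch, and hyperparameters are placeholders rather than the original run_summarization command, and whether or when the loss actually overflows depends on the data and number of steps):

```python
# Minimal fp16 training loop that watches the loss for NaN/inf. Model name,
# dummy batch, and hyperparameters are placeholders, not the original
# run_summarization command; whether/when it overflows depends on the data.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small").cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

inputs = tokenizer(["summarize: a long input document ..."], return_tensors="pt").to("cuda")
labels = tokenizer(["a short summary"], return_tensors="pt").input_ids.to("cuda")

for step in range(1000):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(**inputs, labels=labels).loss
    if not torch.isfinite(loss):
        print(f"loss became non-finite at step {step}: {loss.item()}")
        break
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```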
Possible Reason
In #10496, models clamp inf values only when hidden_states.dtype == torch.float16. However, even when fp16 training is enabled, hidden_states.dtype is still torch.float32, which might be due to the layer_norm operation, so the clamp is never applied. Here is some more information that might be useful to you.
When using BART and T5 with fp16 training, hidden_states.dtype is also still torch.float32; however, their loss won't go to nan.
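For reference, the clamp guarded by that dtype check looks roughly like the sketch below (a paraphrase of the guard added in #10496, wrapped in a standalone demo; see modeling_t5.py for the exact code). Under torch.cuda.amp the residual hidden states stay float32, so the condition is false and the clamp never fires even though fp16 training is enabled.

```python
# Paraphrase of the guard added in #10496 (see modeling_t5.py for the exact
# code), wrapped in a standalone demo. Under torch.cuda.amp the residual
# hidden states stay float32, so the dtype check is False and no clamping
# happens even though fp16 training is enabled.
import torch

def clamp_if_fp16(hidden_states: torch.Tensor) -> torch.Tensor:
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states

x = torch.full((2, 4), 1e5)          # values above the fp16 max (~65504)
print(clamp_if_fp16(x))              # float32: dtype check fails, returned as-is
print(clamp_if_fp16(x.half()))       # float16: inf values get clamped to ~64504
```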