
T5-v1.1 loss goes to NaN when fp16 training is enabled #14189

Closed · 4 tasks done
Liangtaiwan opened this issue Oct 28, 2021 · 8 comments

Comments

@Liangtaiwan
Contributor

Liangtaiwan commented Oct 28, 2021

Environment info

I tested in two different environments: one is my native env, the other is the NVIDIA container pytorch_21.09 (values in parentheses below are from the container).
For more details, please refer to https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel_21-09.html#rel_21-09

  • transformers version: 4.11.3
  • Platform: Arch Linux 5.14.14-arch1-1 (Ubuntu 20.04)
  • Python version: 3.9.7 (3.8)
  • PyTorch version (GPU?): 1.9.1 (1.10a)
  • Tensorflow version (GPU?): 2.6.0 (did not use)
  • Using GPU in script?: 2080Ti (V100)
  • Using distributed or parallel set-up in script?: using fp16

Who can help

@patrickvonplaten, @patil-suraj

Information

Model I am using: t5-v1.1 (small, base). With mixed precision, the loss goes to NaN.

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

The bug can be reproduced with both run_summarization.py and run_summarization_no_trainer.py.

To reproduce

Steps to reproduce the behavior:

1. Run either of the following scripts; both reproduce the issue:

# both the native amp and the apex backend hit the same issue
python run_summarization.py \
    --fp16 --fp16_backend apex \
    --model_name_or_path google/t5-v1_1-base \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=2 \
    --per_device_eval_batch_size=2 \
    --overwrite_output_dir

accelerate launch --fp16 run_summarization_no_trainer.py \
    --model_name_or_path google/t5-v1_1-base \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --per_device_train_batch_size=2 \
    --output_dir ~/tmp/tst-summarization
2. Print the loss step by step and you will see it go to NaN. (For the Trainer script, I print the loss right before Trainer.training_step returns; a sketch is shown below.)
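
For reference, this is roughly how I print the per-step loss with the Trainer script (a sketch of my setup, not part of run_summarization.py; the subclass name is my own):

from transformers import Seq2SeqTrainer

# Drop-in replacement for Seq2SeqTrainer in run_summarization.py that prints
# the loss returned by each training step.
class LossPrintingSeq2SeqTrainer(Seq2SeqTrainer):
    def training_step(self, model, inputs):
        loss = super().training_step(model, inputs)
        print(f"step loss: {loss.item()}")  # becomes nan once the fp16 overflow hits
        return loss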

Possible Reason

In #10496, models clamp inf values only when hidden_states.dtype == torch.float16.
However, even when fp16 training is enabled, the hidden_states.dtype is still torch.float32. This might be due to the layer_norm operation.
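
For context, the clamp added in #10496 looks roughly like the following (a paraphrased sketch of the pattern in modeling_t5.py, not a verbatim copy). Under amp the dtype check never triggers, because the residual stream reaching this point is float32 even though the overflow happened inside fp16 ops:

import torch

# Paraphrase of the clamp pattern from #10496; assumes `hidden_states` is the
# residual-stream tensor inside a T5 block.
def clamp_fp16_inf(hidden_states: torch.Tensor) -> torch.Tensor:
    # Only clamps when the tensor itself is fp16 -- under amp it is fp32 here,
    # so inf values produced by the fp16 matmuls slip through unclamped.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states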

Here is some more information that might be useful to you.

When using BART and T5 with fp16 training, hidden_states.dtype is also still torch.float32; however, their loss does not go to NaN.

@LysandreJik
Member

LysandreJik commented Oct 28, 2021

Linked PR #10956

@patrickvonplaten
Contributor

To be honest, I think we should just not do T5 training in fp16... cc @stas00

Related issue: #10830

@stas00
Contributor

stas00 commented Oct 29, 2021

As suggested by Lysandre, @Liangtaiwan please check if this PR helps: #10956

@Liangtaiwan
Contributor Author

@stas00 @patrickvonplaten @LysandreJik
PR #10956 does prevent T5's loss from going to NaN and achieves results comparable to fp32.
Closing this issue; let's move the discussion to PR #10956.

@ibeltagy
Contributor

ibeltagy commented Nov 5, 2021

I am working with @HaokunLiu on a project that uses T5 and he found a great solution to this problem. The idea is to scale down the weights of the model in a specific pattern that maintains the relationship between the weights. I am not sure if this transformation is loss-preserving, but logits.argmax should remain the same.

Here's his script

import torch
from transformers import T5ForConditionalGeneration


emb_scaling = 1 / 32.0
att_v_scaling = 1 / 4.0
att_o_scaling = 1 / 8.0
ff_wi_scaling = 1 / 4.0
ff_wo_scaling = 1 / 4.0
ff_ln_scaling = 1 / 2.0

assert att_v_scaling * att_o_scaling == emb_scaling
assert ff_wi_scaling * ff_wo_scaling * ff_ln_scaling == emb_scaling

new_model = T5ForConditionalGeneration.from_pretrained('t5-base')
with torch.no_grad():
    new_model.shared.weight *= emb_scaling
    for unit in new_model.encoder.block:
        unit.layer[0].SelfAttention.v.weight *= att_v_scaling
        unit.layer[0].SelfAttention.o.weight *= att_o_scaling
        unit.layer[1].DenseReluDense.wi.weight *= ff_wi_scaling
        unit.layer[1].DenseReluDense.wo.weight *= ff_wo_scaling
        unit.layer[1].layer_norm.weight *= ff_ln_scaling
    for unit in new_model.decoder.block:
        unit.layer[0].SelfAttention.v.weight *= att_v_scaling
        unit.layer[0].SelfAttention.o.weight *= att_o_scaling
        unit.layer[1].EncDecAttention.v.weight *= att_v_scaling
        unit.layer[1].EncDecAttention.o.weight *= att_o_scaling
        unit.layer[2].DenseReluDense.wi.weight *= ff_wi_scaling
        unit.layer[2].DenseReluDense.wo.weight *= ff_wo_scaling
        unit.layer[2].layer_norm.weight *= ff_ln_scaling
    new_model.lm_scale_modifier /= emb_scaling  # requires the lm_scale_modifier change described below

new_model.save_pretrained('t5-base-fp16-fixed')

This requires a small change to T5ForConditionalGeneration in modeling_t5.py. In __init__, next to

self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

you need to add:

self.lm_scale_modifier = nn.Parameter(torch.ones(config.d_model))

Then in forward, where the logits are computed with

lm_logits = self.lm_head(sequence_output)

you need the following lines:

sequence_output = sequence_output * self.lm_scale_modifier  # new code
lm_logits = self.lm_head(sequence_output)                   # existing code

@tlkh
Contributor

tlkh commented Nov 5, 2021

@ibeltagy @HaokunLiu

Interesting, it seems we have similar ideas!

My approach is slightly different, but seems to work as well. Where yours scales down all the weights, mine aims to change the weights as little as possible.

The weights to change are found with a search (going through the encoder layers, then the decoder layers), scaling weights down until the model can infer and train without NaN. I have found that scaling the FFN weights in the last few encoder layers (about 3-5% of the total model weights) down by a factor of 2 is sufficient.

At least on the model's existing pre-trained tasks it still seems to be more or less working, so I'm taking that as a good sign. I have also fine-tuned on my own task without NaN so far. (Tested with t5-large and t5-3B.) A rough sketch of this kind of surgery follows the links below.

Example: https://github.com/tlkh/t5-fp16-surgery/blob/main/t5-3B.ipynb

GitHub repo: https://github.com/tlkh/t5-fp16-surgery
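
A rough sketch of that surgery (which weights to scale and how many layers to touch are assumptions here; the actual notebook searches for them):

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-large")
num_encoder_layers_to_patch = 3   # assumption: only the last few encoder layers
scale = 0.5                       # the factor-of-2 scale-down mentioned above

with torch.no_grad():
    for block in model.encoder.block[-num_encoder_layers_to_patch:]:
        ffn = block.layer[-1].DenseReluDense   # the feed-forward sub-layer
        ffn.wo.weight *= scale                 # assumption: scale only the output projection

model.save_pretrained("t5-large-fp16-patched")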

@ibeltagy
Contributor

ibeltagy commented Nov 5, 2021

I am not sure if this transformation is loss-preserving

It is loss-preserving. The last line, new_model.lm_scale_modifier /= emb_scaling, scales up the hidden states after the last layer (before lm_head) to counter the scaling-down of the weights, thus keeping the transformation loss-preserving. This requires a small change in the T5 code to support lm_scale_modifier.
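
Here is a small self-contained check (my own sketch, not part of @HaokunLiu's script) that this kind of scaling is output-preserving on a simplified T5-style feed-forward block: the RMSNorm removes the incoming scale, and the per-layer factors multiply back to the global embedding scale.

import torch

torch.manual_seed(0)
d_model, d_ff = 8, 16
s = 1 / 32.0                                   # emb_scaling from the script above
c_ln, c_wi, c_wo = 1 / 2.0, 1 / 4.0, 1 / 4.0   # ff_ln / ff_wi / ff_wo scalings
assert c_ln * c_wi * c_wo == s

ln_w = torch.rand(d_model)
wi = torch.randn(d_ff, d_model)
wo = torch.randn(d_model, d_ff)
h = torch.randn(d_model)

def ff_block(h, ln_w, wi, wo):
    normed = h / h.pow(2).mean().sqrt() * ln_w  # T5-style RMSNorm (no mean subtraction, no bias)
    return h + wo @ torch.relu(wi @ normed)     # residual + ReLU feed-forward

ref = ff_block(h, ln_w, wi, wo)
scaled = ff_block(s * h, c_ln * ln_w, c_wi * wi, c_wo * wo)
assert torch.allclose(scaled, s * ref)          # output matches up to the global scale s
print("loss-preserving scaling check passed")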

@yuvalkirstain
Contributor

yuvalkirstain commented Feb 22, 2022

@ibeltagy Thank you so much for sharing this!
Did you by any chance check whether those changes + fp16 fine-tuning on a downstream task yield results similar to fine-tuning the vanilla model without fp16?
