TF BERT not FP16 compatible? #3320
Comments
I've faced the same issue. Maybe the data type is hard-coded somewhere? Have you found a solution? |
Tried this on Colab TPU, same error. |
Same here, would be convenient as hell :) |
Having the same error also for:
#!/usr/bin/env python3
from transformers import TFBertModel, BertTokenizer
from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
tok = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")
input_ids = tok("The dog is cute", return_tensors="tf").input_ids
model(input_ids) # throws error on GPU |
Encountering the same issue here:
import tensorflow as tf
from transformers.modeling_tf_distilbert import TFDistilBertModel
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
model = TFDistilBertModel.from_pretrained('distilbert-base-uncased') |
Put this issue on my TF ToDo-List :-) |
+1 |
Hi @patrickvonplaten, is this problem fixed? |
This is still an open problem... I haven't found the time to take a look yet! Will link this issue to the TF projects. |
This is already solved in the new version. |
🐛 Bug
Information
Model I am using (Bert, XLNet ...): TFBertForQuestionAnswering
Language I am using the model on (English, Chinese ...): English
The problem arises when using:
The task I am working on is:
To reproduce
Simple example to reproduce error:
The error occurs here:
transformers/modeling_tf_bert.py", line 174, in _embedding
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
And this is the error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:AddV2] name: tf_bert_for_question_answering/bert/embeddings/add/
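For anyone who wants to see this class of error without loading BERT, here is a minimal sketch. It assumes TensorFlow 2.x in eager mode; the tensor names `half` and `full` are made up for illustration and are not taken from `modeling_tf_bert.py`. It only shows that TensorFlow does not promote dtypes automatically when adding a float16 tensor to a float32 tensor, and that an explicit `tf.cast` avoids the failure.

```python
import tensorflow as tf

# Hypothetical stand-ins for the embedding tensors: one in float16, one in float32.
half = tf.constant([1.0, 2.0], dtype=tf.float16)
full = tf.constant([0.5, 0.5], dtype=tf.float32)

# TensorFlow does not silently promote mixed dtypes, so the add is rejected.
try:
    _ = half + full
    mismatch_raised = False
except (tf.errors.InvalidArgumentError, TypeError) as err:
    mismatch_raised = True
    print("add failed:", type(err).__name__)

# Casting the float32 operand to the dtype of the other operand makes the add valid.
fixed = half + tf.cast(full, half.dtype)
print(fixed.dtype)
```

Whether the actual fix in the library casts at this exact spot is not something this thread confirms; the sketch only illustrates the mechanism behind the `AddV2` message.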
Expected behavior
I want to use TF BERT with mixed precision (for faster inference on Tensor Core GPUs). I know that full fp16 does not work out of the box, because the model weights would need to be in fp16 as well. Mixed precision, however, should work, because only the computations are performed in fp16 while the weights stay in fp32.
Instead I get a dtype error. It seems the model is not fp16 compatible yet? Will this be fixed in the future?
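To make the mixed-precision expectation concrete, here is a small sketch with a plain Keras layer rather than BERT. It assumes TensorFlow 2.4 or later, where `tf.keras.mixed_precision.Policy` is available outside the `experimental` module used elsewhere in this thread. The policy is passed to a single `Dense` layer so that no global state is changed.

```python
import tensorflow as tf

# Per-layer mixed precision policy: compute in float16, keep variables in float32.
policy = tf.keras.mixed_precision.Policy("mixed_float16")
layer = tf.keras.layers.Dense(4, dtype=policy)

# The float32 input is cast to float16 by the layer before the matmul.
x = tf.ones((2, 3), dtype=tf.float32)
y = layer(x)

compute_dtype = layer.compute_dtype      # dtype used for the computation
variable_dtype = layer.variable_dtype    # dtype used to store the weights
kernel_dtype = tf.as_dtype(layer.kernel.dtype)

print(compute_dtype, variable_dtype, y.dtype)
```

The layer output is float16, so any float32 tensor created inside a model without regard to the policy has to be cast before it is combined with such an output. That is the kind of mismatch the traceback above points at.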
Environment info
transformers version: 2.5.0