Cast bfloat16 to float32 for Numpy conversions #29755

Merged · 2 commits · Mar 21, 2024
src/transformers/modeling_tf_pytorch_utils.py (4 additions, 1 deletion)

@@ -249,7 +249,10 @@ def load_pytorch_weights_in_tf2_model(
         )
         raise
 
-    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+    # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
+    pt_state_dict = {
+        k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+    }
     return load_pytorch_state_dict_in_tf2_model(
         tf_model,
         pt_state_dict,
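
For context, a minimal sketch of the failure this change works around (illustration only, not part of the diff): NumPy has no bfloat16 dtype, so PyTorch raises a TypeError when .numpy() is called on a bfloat16 tensor. Upcasting to float32 first succeeds, and loses nothing because bfloat16 is a truncated float32.

import torch

t = torch.ones(2, 2, dtype=torch.bfloat16)

try:
    t.numpy()  # NumPy has no bfloat16 dtype, so this raises TypeError
except TypeError as err:
    print(f"Direct conversion fails: {err}")

# Upcast first: every bfloat16 value is exactly representable in float32
arr = t.float().numpy()
print(arr.dtype)  # float32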
tests/test_modeling_tf_utils.py (11 additions, 0 deletions)

@@ -63,6 +63,7 @@
     PreTrainedModel,
     PushToHubCallback,
     RagRetriever,
+    TFAutoModel,
     TFBertForMaskedLM,
     TFBertForSequenceClassification,
     TFBertModel,
@@ -435,6 +436,16 @@ def test_safetensors_checkpoint_sharding_local(self):
         for p1, p2 in zip(model.weights, new_model.weights):
             self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
 
+    @is_pt_tf_cross_test
+    @require_safetensors
+    def test_bfloat16_torch_loading(self):
+        # Assert that neither of these raises an error - both repos contain bfloat16 tensors
+        model1 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16-pt", from_pt=True)
+        model2 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16")  # PT-format safetensors
+        # Check that the PT and safetensors loading paths end up with the same values
+        for weight1, weight2 in zip(model1.weights, model2.weights):
+            self.assertTrue(tf.reduce_all(weight1 == weight2))
+
     @slow
     def test_save_pretrained_signatures(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
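
A note on why float32 is a safe upcast target (a sketch from the standard format definitions, not part of the PR): bfloat16 keeps float32's sign bit and 8-bit exponent and truncates the mantissa to 7 bits, so converting bfloat16 -> float32 -> bfloat16 reproduces every value exactly. That losslessness is consistent with the test above comparing the two loading paths with exact equality rather than a tolerance.

import torch

# Round-trip check: the float32 upcast loses no information,
# so truncating back to bfloat16 recovers the original values
x = torch.randn(10_000, dtype=torch.bfloat16)
assert torch.equal(x.float().to(torch.bfloat16), x)
print("bfloat16 -> float32 -> bfloat16 round trip is exact")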