[Flax] dont warn for bf16 weights (huggingface#923)
dont warn for bf16 weights
patil-suraj authored Oct 21, 2022
1 parent ffc5dd6 commit 3eb83bc
Showing 1 changed file with 0 additions and 23 deletions.
modeling_flax_utils.py (0 additions, 23 deletions)
@@ -482,29 +482,6 @@ def from_pretrained(
                 " training."
             )
 
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
         return model, unflatten_dict(state)
 
     def save_pretrained(
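
With this change the loader no longer warns when a checkpoint is loaded with float16 or bfloat16 weights, so callers who care about precision have to check it themselves. Below is a minimal sketch of such a check (not part of the commit), mirroring the removed logic; it assumes `params` is the nested parameter dict returned by `from_pretrained` and uses `flax.traverse_util.flatten_dict` where the removed code iterated the already-flattened `state`.

import jax.numpy as jnp
from flax.traverse_util import flatten_dict

def report_non_fp32_params(params):
    # Map each flattened parameter key to its dtype, as the removed code did for `state`.
    param_dtypes = {k: v.dtype for k, v in flatten_dict(params).items()}
    fp16_params = [k for k, d in param_dtypes.items() if d == jnp.float16]
    bf16_params = [k for k, d in param_dtypes.items() if d == jnp.bfloat16]
    if fp16_params:
        print(f"Parameters stored in float16: {fp16_params}")
    if bf16_params:
        print(f"Parameters stored in bfloat16: {bf16_params}")

If full float32 precision is wanted, the removed warning pointed to the mixin's to_fp32 helper (referenced there as [`~ModelMixin.to_fp32`]) for upcasting the loaded parameters.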
