diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py
index 6cb30a26f7d5..e31a9b7b80d0 100644
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -482,29 +482,6 @@ def from_pretrained(
                 " training."
             )
 
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~ModelMixin.to_fp32`] for further information on how to do this."
-            )
-
         return model, unflatten_dict(state)
 
     def save_pretrained(
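With the float16/bfloat16 warnings removed from `from_pretrained`, callers who still care about parameter precision can run an equivalent check themselves on the returned params. The sketch below is illustrative only, not part of this diff: the helper name `find_non_fp32_params` and the use of `flax.traverse_util.flatten_dict` are assumptions for the example.

```python
# Minimal sketch (not diffusers API): reproduce the removed dtype check on the
# params dict returned by from_pretrained. Helper name and flatten_dict usage
# are illustrative assumptions.
import jax.numpy as jnp
from flax.traverse_util import flatten_dict


def find_non_fp32_params(params):
    """Return flattened parameter keys stored in float16 / bfloat16."""
    flat = flatten_dict(params)  # {(scope, ..., name): array}
    dtypes = {k: v.dtype for k, v in flat.items()}
    fp16_keys = [k for k, d in dtypes.items() if d == jnp.float16]
    bf16_keys = [k for k, d in dtypes.items() if d == jnp.bfloat16]
    return fp16_keys, bf16_keys
```

If any keys come back, the weights can be upcast to float32 via the model's `to_fp32` helper that the removed warning pointed to (e.g. `params = model.to_fp32(params)`).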