diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 92f713a97068..a1de9a1cdb8e 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -78,6 +78,16 @@ if TYPE_CHECKING: from . import PreTrainedTokenizerBase +logger = logging.get_logger(__name__) + +if "TF_USE_LEGACY_KERAS" not in os.environ: + os.environ["TF_USE_LEGACY_KERAS"] = "1" # Compatibility fix to make sure tf.keras stays at Keras 2 +elif os.environ["TF_USE_LEGACY_KERAS"] != "1": + logger.warning( + "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. " + "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models." + ) + try: import tf_keras as keras from tf_keras import backend as K @@ -93,7 +103,6 @@ ) -logger = logging.get_logger(__name__) tf_logger = tf.get_logger() TFModelInputType = Union[