diff --git a/src/transformers/tf_utils.py b/src/transformers/tf_utils.py index 0ee6c4724102..75e302947e80 100644 --- a/src/transformers/tf_utils.py +++ b/src/transformers/tf_utils.py @@ -17,6 +17,7 @@ import numpy as np import tensorflow as tf +from .feature_extraction_utils import BatchFeature from .tokenization_utils_base import BatchEncoding from .utils import logging @@ -257,10 +258,10 @@ def _expand_single_1d_tensor(t): def convert_batch_encoding(*args, **kwargs): - # Convert HF BatchEncoding objects in the inputs to dicts that Keras understands - if args and isinstance(args[0], BatchEncoding): + # Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands + if args and isinstance(args[0], (BatchEncoding, BatchFeature)): args = list(args) args[0] = dict(args[0]) - elif "x" in kwargs and isinstance(kwargs["x"], BatchEncoding): + elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)): kwargs["x"] = dict(kwargs["x"]) return args, kwargs