-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Overhaul TF serving signatures + dummy inputs #23234
Changes from all commits
f5df9cb
e3338c0
459e2c6
507ec1f
6435f3b
f379e89
0ed9bc0
b060e7c
9d8fd7d
e27e490
db4f20b
8543596
b5f3b31
79f4b50
e0ec348
831d56d
9670730
7da8c8f
1b0f380
b867d7d
5a3bb9c
526821c
7791755
ef2191a
e43ae84
2cded64
ddea02e
26f4f67
f61f6d3
4328bf7
2b4baf4
bb73138
74e991d
38dbc86
057454f
692921e
1b90dad
3bfe4c3
0efc89b
1ff2c38
257f3de
599ce59
1a9d3dd
f348dbe
cd887ef
60a41fa
05f7584
52e5b6b
99da521
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,6 @@ | |
from .generation import GenerationConfig, TFGenerationMixin | ||
from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list | ||
from .utils import ( | ||
DUMMY_INPUTS, | ||
SAFE_WEIGHTS_INDEX_NAME, | ||
SAFE_WEIGHTS_NAME, | ||
TF2_WEIGHTS_INDEX_NAME, | ||
|
@@ -1114,9 +1113,25 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: | |
Returns: | ||
`Dict[str, tf.Tensor]`: The dummy inputs. | ||
""" | ||
return { | ||
"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32), | ||
} | ||
dummies = {} | ||
sig = self._prune_signature(self.input_signature) | ||
for key, spec in sig.items(): | ||
# 3 is the most correct arbitrary size. I will not be taking questions | ||
dummies[key] = tf.ones(shape=[dim if dim is not None else 3 for dim in spec.shape], dtype=spec.dtype) | ||
if key == "token_type_ids": | ||
# Some models have token_type_ids but with a vocab_size of 1 | ||
dummies[key] = tf.zeros_like(dummies[key]) | ||
if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters: | ||
if "encoder_hidden_states" not in dummies: | ||
if self.main_input_name == "input_ids": | ||
dummies["encoder_hidden_states"] = tf.ones( | ||
shape=(3, 3, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states" | ||
) | ||
else: | ||
raise NotImplementedError( | ||
"Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!" | ||
) | ||
return dummies | ||
|
||
@property | ||
def framework(self) -> str: | ||
|
@@ -1137,6 +1152,10 @@ def __init__(self, config, *inputs, **kwargs): | |
self.config = config | ||
self.name_or_path = config.name_or_path | ||
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None | ||
if not hasattr(self, "serving"): # Don't overwrite existing serving signatures | ||
self.serving = tf.function( | ||
self.eager_serving, input_signature=[self._prune_signature(self.input_signature)] | ||
) | ||
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec | ||
self._set_save_spec(self.serving.input_signature[0]) | ||
|
||
|
@@ -1201,36 +1220,82 @@ def eager_serving(self, inputs): | |
|
||
return self.serving_output(output) | ||
|
||
@tf.function( | ||
input_signature=[ | ||
{ | ||
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), | ||
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), | ||
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"), | ||
} | ||
] | ||
) | ||
def serving(self, inputs): | ||
@property | ||
def input_signature(self) -> Dict[str, tf.TensorSpec]: | ||
""" | ||
Method used for serving the model. | ||
Args: | ||
inputs (`Dict[str, tf.Tensor]`): | ||
The input of the saved model as a dictionary of tensors. | ||
This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected | ||
shape and dtype for model inputs. It is used for both serving and for generating the dummy inputs used to build | ||
the model. | ||
""" | ||
output = self.call(inputs) | ||
model_inputs = list(inspect.signature(self.call).parameters) | ||
sig = {} | ||
if "input_ids" in model_inputs: | ||
if self.__class__.__name__.endswith("ForMultipleChoice"): | ||
text_dims = 3 | ||
else: | ||
text_dims = 2 | ||
for input_name in ( | ||
"input_ids", | ||
"attention_mask", | ||
"token_type_ids", | ||
"decoder_input_ids", | ||
"decoder_attention_mask", | ||
): | ||
if input_name in model_inputs: | ||
sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name) | ||
if "pixel_values" in model_inputs: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. One edge-case to consider here is video models, which will have an additional dimension. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think in those cases we can override. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think overriding is good and we can revisit when there's more video models added 👍 |
||
pixel_values_shape = [None, None, None, None] | ||
if hasattr(self.config, "vision_config"): | ||
vision_config = self.config.vision_config | ||
else: | ||
vision_config = self.config | ||
if hasattr(vision_config, "num_channels"): | ||
pixel_values_shape[1] = vision_config.num_channels | ||
else: | ||
raise NotImplementedError( | ||
"Could not infer number of channels from config, please override input_signature to specify input shapes." | ||
) | ||
if hasattr(vision_config, "image_size"): | ||
pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size | ||
elif hasattr(vision_config, "input_size"): | ||
pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size | ||
else: | ||
raise NotImplementedError( | ||
"Could not infer input image shape from config, please override input_signature to specify input shapes." | ||
) | ||
sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values") | ||
if "input_features" in model_inputs: | ||
raise NotImplementedError("Audio models need a manually defined input_signature") | ||
return sig | ||
|
||
return self.serving_output(output) | ||
def _prune_signature(self, signature): | ||
"""Keeps only the keys of a given input signature that are valid for this model.""" | ||
model_inputs = list(inspect.signature(self.call).parameters) | ||
return {key: val for key, val in signature.items() if key in model_inputs} | ||
|
||
def serving_output(self, output): | ||
""" | ||
Prepare the output of the saved model. Each model must implement this function. | ||
Args: | ||
output ([`TFBaseModelOutput`]): | ||
The output returned by the model. | ||
""" | ||
raise NotImplementedError | ||
Prepare the output of the saved model. Can be overridden if specific serving modifications are required. | ||
""" | ||
if not isinstance(output, ModelOutput): | ||
return output | ||
for key in output: | ||
if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False): | ||
output[key] = None | ||
elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False): | ||
output[key] = None | ||
elif key == "past_key_values" and not getattr(self.config, "use_cache", False): | ||
output[key] = None | ||
elif key == "cross_attentions" and not ( | ||
getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False) | ||
): | ||
output[key] = None | ||
if isinstance(output[key], (tuple, list)): | ||
try: | ||
output[key] = tf.convert_to_tensor(output[key]) | ||
except (ValueError, tf.errors.InvalidArgumentError): | ||
pass # Layers may not have the same dimensions | ||
return output | ||
|
||
def can_generate(self) -> bool: | ||
""" | ||
|
@@ -1384,7 +1449,7 @@ def prepare_tf_dataset( | |
|
||
if not isinstance(dataset, datasets.Dataset): | ||
raise TypeError("Dataset argument should be a datasets.Dataset!") | ||
model_inputs = list(dict(inspect.signature(self.call).parameters).keys()) | ||
model_inputs = list(inspect.signature(self.call).parameters) | ||
model_labels = find_labels(self.__class__) | ||
if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()): | ||
output_signature, _ = dataset._get_output_signature( | ||
|
@@ -1496,7 +1561,7 @@ def compute_loss(self, *args, **kwargs): | |
return self.hf_compute_loss(*args, **kwargs) | ||
|
||
def get_label_to_output_name_mapping(self): | ||
arg_names = list(dict(inspect.signature(self.call).parameters).keys()) | ||
arg_names = list(inspect.signature(self.call).parameters) | ||
if self._label_to_output_map is not None: | ||
return self._label_to_output_map | ||
elif "start_positions" in arg_names: | ||
|
@@ -1519,7 +1584,7 @@ def train_step(self, data): | |
""" | ||
|
||
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` | ||
arg_names = list(dict(inspect.signature(self.call).parameters).keys()) | ||
arg_names = list(inspect.signature(self.call).parameters) | ||
label_kwargs = find_labels(self.__class__) | ||
label_to_output = self.get_label_to_output_name_mapping() | ||
output_to_label = {val: key for key, val in label_to_output.items()} | ||
|
@@ -1626,7 +1691,7 @@ def test_step(self, data): | |
that they are available to the model during the forward pass. | ||
""" | ||
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map` | ||
arg_names = list(dict(inspect.signature(self.call).parameters).keys()) | ||
arg_names = list(inspect.signature(self.call).parameters) | ||
label_kwargs = find_labels(self.__class__) | ||
label_to_output = self.get_label_to_output_name_mapping() | ||
output_to_label = {val: key for key, val in label_to_output.items()} | ||
|
@@ -1645,7 +1710,7 @@ def test_step(self, data): | |
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments, | ||
# if those keys are not already present in the input dict | ||
if self._using_dummy_loss and y is not None: | ||
arg_names = list(dict(inspect.signature(self.call).parameters).keys()) | ||
arg_names = list(inspect.signature(self.call).parameters) | ||
# If y is a tensor and the model only has one label-like input, map y to that input | ||
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor): | ||
if isinstance(x, tf.Tensor): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a
if "decoder_input_ids" in model_inputs:
case? It would automate the encoder-decoder exceptions :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about this, but there are probably situations where we don't want it in the dummies even if it is a valid input for the model, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me try it and see what fails
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
...nothing failed. Huh. Let me see what dummies I can remove!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed a bunch of redundant signatures with no issues. Thanks for the suggestion!