Overhaul TF serving signatures + dummy inputs #23234

Merged: 49 commits, merged May 24, 2023

Commits (49)
f5df9cb
Let's try autodetecting serving sigs
Rocketknight1 May 9, 2023
e3338c0
Don't clobber existing sigs
Rocketknight1 May 9, 2023
459e2c6
Change shapes for multiplechoice models
Rocketknight1 May 9, 2023
507ec1f
Make default dummy inputs smarter too
Rocketknight1 May 9, 2023
6435f3b
Fix missing f-string
Rocketknight1 May 9, 2023
f379e89
Let's YOLO a serving output too
Rocketknight1 May 9, 2023
0ed9bc0
Read __class__.__name__ properly
Rocketknight1 May 9, 2023
b060e7c
Don't just pass naked lists in there and expect it to be okay
Rocketknight1 May 9, 2023
9d8fd7d
Code cleanup
Rocketknight1 May 10, 2023
e27e490
Update default serving sig
Rocketknight1 May 10, 2023
db4f20b
Clearer error messages
Rocketknight1 May 10, 2023
8543596
Further updates to the default serving output
Rocketknight1 May 10, 2023
b5f3b31
make fixup
Rocketknight1 May 10, 2023
79f4b50
Update the serving output a bit more
Rocketknight1 May 10, 2023
e0ec348
Cleanups and renames, raise errors appropriately when we can't infer …
Rocketknight1 May 12, 2023
831d56d
More renames
Rocketknight1 May 12, 2023
9670730
we're building in a functional context again, yolo
Rocketknight1 May 12, 2023
7da8c8f
import DUMMY_INPUTS from the right place
Rocketknight1 May 12, 2023
1b0f380
import DUMMY_INPUTS from the right place
Rocketknight1 May 12, 2023
b867d7d
Support cross-attention in the dummies
Rocketknight1 May 12, 2023
5a3bb9c
Support cross-attention in the dummies
Rocketknight1 May 12, 2023
526821c
Complete removal of dummy/serving overrides in BERT
Rocketknight1 May 12, 2023
7791755
Complete removal of dummy/serving overrides in RoBERTa
Rocketknight1 May 12, 2023
ef2191a
Obliterate lots and lots of serving sig and dummy overrides
Rocketknight1 May 12, 2023
e43ae84
merge type hint changes
Rocketknight1 May 12, 2023
2cded64
Fix for token_type_ids with vocab_size 1
Rocketknight1 May 12, 2023
ddea02e
Add missing property decorator
Rocketknight1 May 12, 2023
26f4f67
Fix T5 and hopefully some models that take conv inputs
Rocketknight1 May 15, 2023
f61f6d3
More signature pruning
Rocketknight1 May 15, 2023
4328bf7
Fix T5's signature
Rocketknight1 May 15, 2023
2b4baf4
Fix Wav2Vec2 signature
Rocketknight1 May 15, 2023
bb73138
Fix LongformerForMultipleChoice input signature
Rocketknight1 May 15, 2023
74e991d
Fix BLIP and LED
Rocketknight1 May 15, 2023
38dbc86
Better default serving output error handling
Rocketknight1 May 16, 2023
057454f
Fix BART dummies
Rocketknight1 May 17, 2023
692921e
Fix dummies for cross-attention, esp encoder-decoder models
Rocketknight1 May 17, 2023
1b90dad
Fix visionencoderdecoder signature
Rocketknight1 May 17, 2023
3bfe4c3
Fix BLIP serving output
Rocketknight1 May 18, 2023
0efc89b
Small tweak to BART dummies
Rocketknight1 May 18, 2023
1ff2c38
Cleanup the ugly parameter inspection line that I used in a few places
Rocketknight1 May 18, 2023
257f3de
committed a breakpoint again
Rocketknight1 May 18, 2023
599ce59
Move the text_dims check
Rocketknight1 May 18, 2023
1a9d3dd
Remove blip_text serving_output
Rocketknight1 May 18, 2023
f348dbe
Add decoder_input_ids to the default input sig
Rocketknight1 May 18, 2023
cd887ef
Remove all the manual overrides for encoder-decoder model signatures
Rocketknight1 May 18, 2023
60a41fa
Tweak longformer/led input sigs
Rocketknight1 May 24, 2023
05f7584
Tweak default serving output
Rocketknight1 May 24, 2023
52e5b6b
output.keys() -> output
Rocketknight1 May 24, 2023
99da521
make fixup
Rocketknight1 May 24, 2023
131 changes: 98 additions & 33 deletions src/transformers/modeling_tf_utils.py
@@ -42,7 +42,6 @@
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
from .utils import (
DUMMY_INPUTS,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_INDEX_NAME,
@@ -1114,9 +1113,25 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
return {
"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32),
}
dummies = {}
sig = self._prune_signature(self.input_signature)
for key, spec in sig.items():
# 3 is the most correct arbitrary size. I will not be taking questions
dummies[key] = tf.ones(shape=[dim if dim is not None else 3 for dim in spec.shape], dtype=spec.dtype)
if key == "token_type_ids":
# Some models have token_type_ids but with a vocab_size of 1
dummies[key] = tf.zeros_like(dummies[key])
if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
if "encoder_hidden_states" not in dummies:
if self.main_input_name == "input_ids":
dummies["encoder_hidden_states"] = tf.ones(
shape=(3, 3, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
)
else:
raise NotImplementedError(
"Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!"
)
return dummies

@property
def framework(self) -> str:
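For a sense of what the new `dummy_inputs` property above produces, here is a standalone sketch of the same loop, assuming a plain text model whose pruned signature is three `(None, None)` int32 specs; every unknown dimension becomes 3 and `token_type_ids` is zeroed. This mirrors the diffed code but is not the library code itself.

```python
# Sketch of the dummy-generation behaviour above, not the library code itself.
import tensorflow as tf

signature = {
    "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
    "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
    "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}

dummies = {}
for key, spec in signature.items():
    # Unknown (None) dimensions are filled with the arbitrary size 3
    dummies[key] = tf.ones([dim if dim is not None else 3 for dim in spec.shape], dtype=spec.dtype)
    if key == "token_type_ids":
        # Zeroed so the dummies stay valid even when type_vocab_size == 1
        dummies[key] = tf.zeros_like(dummies[key])

print({key: tuple(value.shape) for key, value in dummies.items()})  # each is (3, 3)
```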
@@ -1137,6 +1152,10 @@ def __init__(self, config, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
if not hasattr(self, "serving"): # Don't overwrite existing serving signatures
self.serving = tf.function(
self.eager_serving, input_signature=[self._prune_signature(self.input_signature)]
)
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self.serving.input_signature[0])

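As a usage sketch (not part of the diff): the `serving` tf.function created in `__init__` above is what the SavedModel export path picks up, so models without a hand-written signature can still be exported directly. The checkpoint name below is just an example, and this assumes the usual `saved_model=True` export path.

```python
# Sketch: exporting relies on the serving tf.function created in __init__ above.
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-cased")  # example checkpoint
model.save_pretrained("exported_bert", saved_model=True)  # SavedModel export uses model.serving
```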
@@ -1201,36 +1220,82 @@ def eager_serving(self, inputs):

return self.serving_output(output)

@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
@property
def input_signature(self) -> Dict[str, tf.TensorSpec]:
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected
shape and dtype for model inputs. It is used for both serving and for generating the dummy inputs used to build
the model.
"""
output = self.call(inputs)
model_inputs = list(inspect.signature(self.call).parameters)
sig = {}
if "input_ids" in model_inputs:
(Review thread on this line)
Reviewer (Member): Could we add a `if "decoder_input_ids" in model_inputs:` case? It would automate the encoder-decoder exceptions :D
Rocketknight1 (Author): I thought about this, but there are probably situations where we don't want it in the dummies even if it is a valid input for the model, right?
Rocketknight1 (Author): Let me try it and see what fails
Rocketknight1 (Author): ...nothing failed. Huh. Let me see what dummies I can remove!
Rocketknight1 (Author): Removed a bunch of redundant signatures with no issues. Thanks for the suggestion!

if self.__class__.__name__.endswith("ForMultipleChoice"):
text_dims = 3
else:
text_dims = 2
for input_name in (
"input_ids",
"attention_mask",
"token_type_ids",
"decoder_input_ids",
"decoder_attention_mask",
):
if input_name in model_inputs:
sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
if "pixel_values" in model_inputs:
(Review thread on this line)
amyeroberts (Collaborator, May 12, 2023): One edge-case to consider here is video models which will have an additional dimension
Rocketknight1 (Author): I think in those cases we can override input_signature, but maybe there's some way to adjust this function to figure it out automatically.
amyeroberts (Collaborator): I think overriding is good and we can revisit when there's more video models added 👍

pixel_values_shape = [None, None, None, None]
if hasattr(self.config, "vision_config"):
vision_config = self.config.vision_config
else:
vision_config = self.config
if hasattr(vision_config, "num_channels"):
pixel_values_shape[1] = vision_config.num_channels
else:
raise NotImplementedError(
"Could not infer number of channels from config, please override input_signature to specify input shapes."
)
if hasattr(vision_config, "image_size"):
pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
elif hasattr(vision_config, "input_size"):
pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
else:
raise NotImplementedError(
"Could not infer input image shape from config, please override input_signature to specify input shapes."
)
sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
if "input_features" in model_inputs:
raise NotImplementedError("Audio models need a manually defined input_signature")
return sig
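Since audio models (and, per the review discussion above, future video models) cannot have their shapes inferred here, they are expected to override the property. A hypothetical override might look like the following; the class name and `feature_size` config attribute are illustrative, not from this PR.

```python
# Hypothetical sketch of overriding input_signature for an audio model.
from typing import Dict

import tensorflow as tf
from transformers import TFPreTrainedModel


class TFMyAudioModel(TFPreTrainedModel):
    @property
    def input_signature(self) -> Dict[str, tf.TensorSpec]:
        return {
            "input_features": tf.TensorSpec(
                (None, None, self.config.feature_size), tf.float32, name="input_features"
            ),
            "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
        }
```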

return self.serving_output(output)
def _prune_signature(self, signature):
"""Keeps only the keys of a given input signature that are valid for this model."""
model_inputs = list(inspect.signature(self.call).parameters)
return {key: val for key, val in signature.items() if key in model_inputs}
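The pruning helper is what lets one default signature cover many architectures: keys that a model's `call` does not accept are simply dropped. A toy illustration with a stand-in `call`:

```python
# Toy illustration of the pruning behaviour, using a stand-in call signature.
import inspect

def call(input_ids=None, attention_mask=None, training=False):  # stand-in for a model's call
    pass

signature = {"input_ids": "spec", "attention_mask": "spec", "decoder_input_ids": "spec"}
model_inputs = list(inspect.signature(call).parameters)
pruned = {key: val for key, val in signature.items() if key in model_inputs}
print(list(pruned))  # ['input_ids', 'attention_mask']
```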

def serving_output(self, output):
"""
Prepare the output of the saved model. Each model must implement this function.
Args:
output ([`TFBaseModelOutput`]):
The output returned by the model.
"""
raise NotImplementedError
Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
"""
if not isinstance(output, ModelOutput):
return output
for key in output:
if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
output[key] = None
elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
output[key] = None
elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
output[key] = None
elif key == "cross_attentions" and not (
getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False)
):
output[key] = None
if isinstance(output[key], (tuple, list)):
try:
output[key] = tf.convert_to_tensor(output[key])
except (ValueError, tf.errors.InvalidArgumentError):
pass # Layers may not have the same dimensions
return output

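A rough usage sketch of the new default (assuming a text model whose config leaves `output_hidden_states` and `output_attentions` off, and an example checkpoint name): the generic `serving_output` nulls out the gated fields and converts any stackable tuples to tensors, so per-model overrides are no longer needed.

```python
# Sketch of the default serving_output behaviour; not a test from this PR.
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-cased")  # example checkpoint
raw_output = model(model.dummy_inputs)
served = model.serving_output(raw_output)
# hidden_states / attentions are dropped because the config flags are off
assert served.hidden_states is None and served.attentions is None
```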
def can_generate(self) -> bool:
"""
@@ -1384,7 +1449,7 @@ def prepare_tf_dataset(

if not isinstance(dataset, datasets.Dataset):
raise TypeError("Dataset argument should be a datasets.Dataset!")
model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
model_inputs = list(inspect.signature(self.call).parameters)
model_labels = find_labels(self.__class__)
if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
output_signature, _ = dataset._get_output_signature(
@@ -1496,7 +1561,7 @@ def compute_loss(self, *args, **kwargs):
return self.hf_compute_loss(*args, **kwargs)

def get_label_to_output_name_mapping(self):
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
if self._label_to_output_map is not None:
return self._label_to_output_map
elif "start_positions" in arg_names:
@@ -1519,7 +1584,7 @@ def train_step(self, data):
"""

# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
@@ -1626,7 +1691,7 @@ def test_step(self, data):
that they are available to the model during the forward pass.
"""
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
label_kwargs = find_labels(self.__class__)
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
@@ -1645,7 +1710,7 @@ def test_step(self, data):
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
arg_names = list(inspect.signature(self.call).parameters)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
85 changes: 0 additions & 85 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -49,7 +49,6 @@
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
@@ -826,17 +825,6 @@ def call(

return outputs

def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
)


@add_start_docstrings(
"""
@@ -933,17 +921,6 @@ def call(
attentions=outputs.attentions,
)

def serving_output(self, output: TFAlbertForPreTrainingOutput) -> TFAlbertForPreTrainingOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFAlbertForPreTrainingOutput(
prediction_logits=output.prediction_logits,
sop_logits=output.sop_logits,
hidden_states=hs,
attentions=attns,
)


class TFAlbertSOPHead(tf.keras.layers.Layer):
def __init__(self, config: AlbertConfig, **kwargs):
@@ -1058,13 +1035,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
@@ -1147,13 +1117,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
@@ -1237,13 +1200,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)


@add_start_docstrings(
"""
@@ -1339,15 +1295,6 @@ def call(
attentions=outputs.attentions,
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
)


@add_start_docstrings(
"""
@@ -1370,16 +1317,6 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)

@property
def dummy_inputs(self):
"""
Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32)}

@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
@@ -1457,25 +1394,3 @@ def call(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
}
]
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]) -> TFMultipleChoiceModelOutput:
output = self.call(input_ids=inputs)

return self.serving_output(output)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
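With all of these ALBERT-specific overrides removed, the classes inherit the generic behaviour from the base class above. As a quick sanity sketch (the checkpoint name is just an example), the multiple-choice head should now get its 3-D text signature purely from the `ForMultipleChoice` class-name check in the base-class `input_signature`:

```python
# Sketch: the deleted override is now covered by the base-class input_signature.
from transformers import TFAlbertForMultipleChoice

model = TFAlbertForMultipleChoice.from_pretrained("albert-base-v2")  # example checkpoint
print(model.input_signature["input_ids"].shape)  # (None, None, None): batch, num_choices, seq_len
```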