Save output schema of model and add save method to Encoder (#886)
* Add save method to `Encoder`

* Save output schema when calling `model.save`

* Add parameters to save method docstring

* Only replace inferred schema if it's not already defined

* Remove properties and set is_ragged based on second dimension.

* Use getattr to check for schema in BaseModel.fit

Co-authored-by: rnyak <ronayak@hotmail.com>
oliverholworthy and rnyak authored Dec 1, 2022
1 parent 5e4d8a1 commit e08a72c
Showing 3 changed files with 101 additions and 5 deletions.
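
A minimal end-to-end sketch of the behaviour this commit adds (the model architecture, synthetic dataset, and output path are illustrative assumptions, not taken from the repository): after `model.save(path)`, the input schema and the inferred output schema are written to a `.merlin` directory next to the SavedModel, as exercised in the updated test below.

import merlin.models.tf as mm
from merlin.datasets.synthetic import generate_data
from merlin.schema import Tags

# Synthetic data; the "e-commerce" preset is assumed to provide a binary "click" target.
train, valid = generate_data("e-commerce", num_rows=1000, set_sizes=(0.8, 0.2))

model = mm.Model(
    mm.InputBlockV2(train.schema.excluding_by_tag(Tags.TARGET)),
    mm.MLPBlock([64, 32]),
    mm.BinaryClassificationTask("click"),
)
model.compile(optimizer="adam")
model.fit(train, batch_size=128, epochs=1)

# Writes the SavedModel plus .merlin/input_schema.json and .merlin/output_schema.json
model.save("saved_model")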
34 changes: 33 additions & 1 deletion merlin/models/tf/core/encoder.py
@@ -14,18 +14,20 @@
# limitations under the License.
#

import os
from typing import Dict, Optional, Union

import numpy as np
import tensorflow as tf
from packaging import version

import merlin.io
from merlin.models.io import save_merlin_metadata
from merlin.models.tf.core import combinators
from merlin.models.tf.core.prediction import TopKPrediction
from merlin.models.tf.inputs.base import InputBlockV2
from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable
from merlin.models.tf.models.base import BaseModel
from merlin.models.tf.models.base import BaseModel, get_output_schema
from merlin.models.tf.outputs.topk import TopKOutput
from merlin.models.tf.utils import tf_utils
from merlin.schema import ColumnSchema, Schema, Tags
@@ -206,6 +208,36 @@ def _set_save_spec(self, inputs, args=None, kwargs=None):
_arg_spec = self._saved_model_arg_spec
self._saved_model_arg_spec = ([_arg_spec[0][0]], _arg_spec[1])

def save(
self,
export_path: Union[str, os.PathLike],
include_optimizer=True,
save_traces=True,
) -> None:
"""Saves the model to export_path as a Tensorflow Saved Model.
Along with merlin model metadata.
Parameters
----------
export_path : Union[str, os.PathLike]
Path where model will be saved to
include_optimizer : bool, optional
If False, do not save the optimizer state, by default True
save_traces : bool, optional
When enabled, will store the function traces for each layer. This
can be disabled, so that only the configs of each layer are
stored, by default True
"""
super().save(
export_path,
include_optimizer=include_optimizer,
save_traces=save_traces,
save_format="tf",
)
input_schema = self.schema
output_schema = get_output_schema(export_path)
save_merlin_metadata(export_path, self, input_schema, output_schema)

@property
def to_call(self):
if self.pre:
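A short usage sketch for the new `Encoder.save` (the retrieval model and its `query_encoder` attribute are assumptions for illustration, not part of this commit): the method saves the encoder as a TensorFlow SavedModel with `save_format="tf"`, then writes the Merlin input and output schemas next to it.

# Illustrative: persist the query tower of a trained retrieval model.
# `model` is assumed to be a trained mm.TwoTowerModelV2 exposing a `query_encoder` Encoder.
query_tower = model.query_encoder
query_tower.save("query_tower")
# "query_tower/" now contains the SavedModel and .merlin/{input,output}_schema.json
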
52 changes: 49 additions & 3 deletions merlin/models/tf/models/base.py
@@ -89,6 +89,39 @@ def on_train_batch_end(self, batch, logs=None):
self._is_first_batch = False


def get_output_schema(export_path: str) -> Schema:
"""Compute Output Schema
Parameters
----------
export_path : str
Path to saved model directory
Returns
-------
Schema
Output Schema representing model outputs
"""
model = tf.keras.models.load_model(export_path)
signature = model.signatures["serving_default"]

output_schema = Schema()
for output_name, output_spec in signature.structured_outputs.items():
col_schema = ColumnSchema(output_name, dtype=output_spec.dtype.as_numpy_dtype)
shape = output_spec.shape
if shape.rank > 1 and (shape[1] is None or shape[1] > 1):
is_ragged = shape[1] is None
col_schema = ColumnSchema(
output_name,
dtype=output_spec.dtype.as_numpy_dtype,
is_list=True,
is_ragged=is_ragged,
)
output_schema.column_schemas[output_name] = col_schema

return output_schema


@tf.keras.utils.register_keras_serializable(package="merlin_models")
class ModelBlock(Block, tf.keras.Model):
"""Block that extends `tf.keras.Model` to make it saveable."""
@@ -882,8 +915,8 @@ def fit(
x = _maybe_convert_merlin_dataset(x, batch_size, **kwargs)

# Bind schema from dataset to model in case we can't infer it from the inputs
if isinstance(x, Loader):
self.schema = x.schema
if isinstance(x, Loader) and getattr(self, "schema", None) is None:
self.schema = x.schema.excluding_by_tag(Tags.TARGET)

validation_data = _maybe_convert_merlin_dataset(
validation_data, batch_size, shuffle=shuffle, **kwargs
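
The change in `fit` binds the loader's schema only when the model does not already have one (the `getattr` guard covers models where `schema` was never set), and it strips target columns before binding. A small sketch of `excluding_by_tag` with made-up column names:

from merlin.schema import ColumnSchema, Schema, Tags

schema = Schema(
    [
        ColumnSchema("user_id", tags=[Tags.USER, Tags.CATEGORICAL]),
        ColumnSchema("item_id", tags=[Tags.ITEM, Tags.CATEGORICAL]),
        ColumnSchema("click", tags=[Tags.TARGET]),
    ]
)

feature_schema = schema.excluding_by_tag(Tags.TARGET)
print(feature_schema.column_names)  # ['user_id', 'item_id'] -- the target column is removed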
@@ -1107,14 +1140,27 @@ def save(
) -> None:
"""Saves the model to export_path as a Tensorflow Saved Model.
Along with merlin model metadata.
Parameters
----------
export_path : Union[str, os.PathLike]
Path where model will be saved to
include_optimizer : bool, optional
If False, do not save the optimizer state, by default True
save_traces : bool, optional
When enabled, will store the function traces for each layer. This
can be disabled, so that only the configs of each layer are
stored, by default True
"""
super().save(
export_path,
include_optimizer=include_optimizer,
save_traces=save_traces,
save_format="tf",
)
save_merlin_metadata(export_path, self, self.schema, None)
input_schema = self.schema
output_schema = get_output_schema(export_path)
save_merlin_metadata(export_path, self, input_schema, output_schema)

@classmethod
def load(cls, export_path: Union[str, os.PathLike]) -> "Model":
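Once a model has been saved this way, the persisted schemas can be read back without rebuilding the model, and the serving signature can be inspected directly, mirroring the assertions in the updated test below (the "saved_model" path is illustrative):

import tensorflow as tf
from merlin.models.utils import schema_utils

input_schema = schema_utils.tensorflow_metadata_json_to_schema("saved_model/.merlin/input_schema.json")
output_schema = schema_utils.tensorflow_metadata_json_to_schema("saved_model/.merlin/output_schema.json")
print(input_schema.column_names, output_schema.column_names)

reloaded = tf.keras.models.load_model("saved_model")
print(reloaded.signatures["serving_default"].structured_outputs.keys())
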
20 changes: 19 additions & 1 deletion tests/unit/tf/models/test_base.py
@@ -25,6 +25,7 @@
from merlin.datasets.synthetic import generate_data
from merlin.io.dataset import Dataset
from merlin.models.tf.utils import testing_utils, tf_utils
from merlin.models.utils import schema_utils
from merlin.schema import ColumnSchema, Schema, Tags


@@ -689,10 +690,27 @@ def test_save_and_load(tmpdir):
)
model.save(tmpdir)
reloaded_model = mm.Model.load(tmpdir)

saved_input_schema = schema_utils.tensorflow_metadata_json_to_schema(
f"{tmpdir}/.merlin/input_schema.json"
)
saved_output_schema = schema_utils.tensorflow_metadata_json_to_schema(
f"{tmpdir}/.merlin/output_schema.json"
)

signature_input_keys = set(
reloaded_model.signatures["serving_default"].structured_input_signature[1].keys()
)
assert signature_input_keys == {"user_age"}
signature_output_keys = set(
reloaded_model.signatures["serving_default"].structured_outputs.keys()
)
assert signature_input_keys == {"user_age"} == set(saved_input_schema.column_names)
assert (
signature_output_keys
== {"click/binary_classification_task"}
== set(saved_output_schema.column_names)
)

test_case = TestCase()
test_case.assertAllClose(
model.predict(dataset, batch_size=10), reloaded_model.predict(dataset, batch_size=10)
