Skip to content

Commit

Permalink
[AIR] Allow user to pass model to TensorflowCheckpoint.get_model (#…
Browse files Browse the repository at this point in the history
…31203)

When you resume training a TensorFlow model, you may be forced to write an unnecessary lambda just to load the model weights. This is because TensorflowCheckpoint.get_model expects a model definition, even though you may have already constructed your model. This PR improves the UX by letting users directly pass a model to get_model.

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
  • Loading branch information
bveeramani authored Jan 4, 2023
1 parent 892b4f0 commit c83a8c7
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 51 deletions.
123 changes: 77 additions & 46 deletions python/ray/train/tensorflow/tensorflow_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, Callable, Optional, Type, Union
from typing import TYPE_CHECKING, Callable, Optional, Union

from enum import Enum
from os import path
Expand Down Expand Up @@ -56,7 +56,7 @@ def from_model(
model.
The checkpoint created with this method needs to be paired with
`model_definition` when used.
`model` when used.
Args:
model: The Keras model, whose weights are stored in the checkpoint.
Expand Down Expand Up @@ -86,7 +86,7 @@ def from_h5(
model from H5 format.
The checkpoint generated by this method contains all the information needed.
Thus no `model_definition` is needed to be supplied when using this checkpoint.
Thus no `model` is needed to be supplied when using this checkpoint.
`file_path` must maintain validity even after this function returns.
Some new files/directories may be added to the parent directory of `file_path`,
Expand Down Expand Up @@ -155,7 +155,7 @@ def from_saved_model(
model from SavedModel format.
The checkpoint generated by this method contains all the information needed.
Thus no `model_definition` is needed to be supplied when using this checkpoint.
Thus no `model` is needed to be supplied when using this checkpoint.
`dir_path` must maintain validity even after this function returns.
Some new files/directories may be added to `dir_path`, as a side effect
Expand Down Expand Up @@ -210,56 +210,87 @@ def from_saved_model(

def get_model(
self,
model_definition: Optional[
Union[Callable[[], tf.keras.Model], Type[tf.keras.Model]]
] = None,
model: Optional[Union[tf.keras.Model, Callable[[], tf.keras.Model]]] = None,
model_definition: Optional[Callable[[], tf.keras.Model]] = None,
) -> tf.keras.Model:
"""Retrieve the model stored in this checkpoint.
Args:
model_definition: This arg is expected only if the original checkpoint
model: This arg is expected only if the original checkpoint
was created via `TensorflowCheckpoint.from_model`.
model_definition: This parameter is deprecated. Use `model` instead.
Returns:
The Tensorflow Keras model stored in the checkpoint.
"""
if model_definition is not None:
warnings.warn(
"The `model_definition` parameter is deprecated. Use the `model` "
"parameter instead.",
DeprecationWarning,
)
model = model_definition
if model is not None and self._flavor is not self.Flavor.MODEL_WEIGHTS:
warnings.warn(
"TensorflowCheckpoint was created from "
"TensorflowCheckpoint.from_saved_model` or "
"`TensorflowCheckpoint.from_h5`, which already contains all the "
"information needed. This means: "
"If you are using BatchPredictor, you should do "
"`BatchPredictor.from_checkpoint(checkpoint, TensorflowPredictor)`"
" by removing kwargs `model=`. "
"If you are using TensorflowPredictor directly, you should do "
"`TensorflowPredictor.from_checkpoint(checkpoint)` by "
"removing kwargs `model=`."
)
if self._flavor is self.Flavor.MODEL_WEIGHTS and model is None:
raise ValueError(
"Expecting `model` argument when checkpoint is "
"saved through `TensorflowCheckpoint.from_model()`."
)

if self._flavor is self.Flavor.MODEL_WEIGHTS:
if not model_definition:
raise ValueError(
"Expecting `model_definition` argument when checkpoint is "
"saved through `TensorflowCheckpoint.from_model()`."
)
model_weights, _ = _load_checkpoint_dict(self, "TensorflowTrainer")
model = model_definition()
model.set_weights(model_weights)
return model
model = self._get_model_from_weights(model)
elif self._flavor is self.Flavor.H5:
model = self._get_model_from_h5()
elif self._flavor == self.Flavor.SAVED_MODEL:
model = self._get_model_from_saved_model()
else:
if model_definition:
warnings.warn(
"TensorflowCheckpoint was created from "
"TensorflowCheckpoint.from_saved_model` or "
"`TensorflowCheckpoint.from_h5`, which already contains all the "
"information needed. This means: "
"If you are using BatchPredictor, you should do "
"`BatchPredictor.from_checkpoint(checkpoint, TensorflowPredictor)`"
" by removing kwargs `model_definition=`. "
"If you are using TensorflowPredictor directly, you should do "
"`TensorflowPredictor.from_checkpoint(checkpoint)` by "
"removing kwargs `model_definition=`."
)
with self.as_directory() as checkpoint_dir:
if self._flavor == self.Flavor.H5:
return keras.models.load_model(
os.path.join(checkpoint_dir, self._h5_file_path)
)
elif self._flavor == self.Flavor.SAVED_MODEL:
return keras.models.load_model(checkpoint_dir)
else:
raise RuntimeError(
"Avoid directly using `from_dict` or "
"`from_directory` directly. Make sure "
"that the checkpoint was generated by "
"`TensorflowCheckpoint.from_model`, "
"`TensorflowCheckpoint.from_saved_model` or "
"`TensorflowCheckpoint.from_h5`."
)
raise RuntimeError(
"Avoid directly using `from_dict` or "
"`from_directory` directly. Make sure "
"that the checkpoint was generated by "
"`TensorflowCheckpoint.from_model`, "
"`TensorflowCheckpoint.from_saved_model` or "
"`TensorflowCheckpoint.from_h5`."
)

return model

def _get_model_from_weights(
self, model: Union[tf.keras.Model, Callable[[], tf.keras.Model]]
) -> tf.keras.Model:
assert self._flavor is self.Flavor.MODEL_WEIGHTS
assert model is not None

if callable(model):
model = model()

model_weights, _ = _load_checkpoint_dict(self, "TensorflowTrainer")
model.set_weights(model_weights)

return model

def _get_model_from_h5(self) -> tf.keras.Model:
assert self._flavor is self.Flavor.H5

with self.as_directory() as checkpoint_dir:
return keras.models.load_model(
os.path.join(checkpoint_dir, self._h5_file_path)
)

def _get_model_from_saved_model(self) -> tf.keras.Model:
assert self._flavor is self.Flavor.SAVED_MODEL

with self.as_directory() as checkpoint_dir:
return keras.models.load_model(checkpoint_dir)
11 changes: 9 additions & 2 deletions python/ray/train/tests/test_tensorflow_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def get_model():
)


def test_model_definition_raises_deprecation_warning():
model = get_model()
checkpoint = TensorflowCheckpoint.from_model(model)
with pytest.deprecated_call():
checkpoint.get_model(model_definition=get_model)


def compare_weights(w1: List[ndarray], w2: List[ndarray]) -> bool:
if not len(w1) == len(w2):
return False
Expand All @@ -58,7 +65,7 @@ def test_from_model(self):
checkpoint = TensorflowCheckpoint.from_model(
self.model, preprocessor=DummyPreprocessor(1)
)
loaded_model = checkpoint.get_model(model_definition=get_model)
loaded_model = checkpoint.get_model(model=get_model)
preprocessor = checkpoint.get_preprocessor()

assert compare_weights(loaded_model.get_weights(), self.model.get_weights())
Expand Down Expand Up @@ -92,7 +99,7 @@ def test_from_saved_model_warning_with_model_definition(self):
preprocessor=DummyPreprocessor(1),
)
with pytest.warns(None):
loaded_model = checkpoint.get_model(model_definition=get_model)
loaded_model = checkpoint.get_model(model=get_model)
preprocessor = checkpoint.get_preprocessor()
assert compare_weights(self.model.get_weights(), loaded_model.get_weights())
assert preprocessor.multiplier == 1
Expand Down
5 changes: 2 additions & 3 deletions python/ray/train/tests/test_tensorflow_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,14 @@ def test_tensorflow_checkpoint():

checkpoint = TensorflowCheckpoint.from_model(model, preprocessor=preprocessor)
assert (
checkpoint.get_model(model_definition=build_raw_model).get_weights()
== model.get_weights()
checkpoint.get_model(model=build_raw_model).get_weights() == model.get_weights()
)

with checkpoint.as_directory() as path:
checkpoint = TensorflowCheckpoint.from_directory(path)
checkpoint_preprocessor = checkpoint.get_preprocessor()
assert (
checkpoint.get_model(model_definition=build_raw_model).get_weights()
checkpoint.get_model(model=build_raw_model).get_weights()
== model.get_weights()
)
assert checkpoint_preprocessor == preprocessor
Expand Down

0 comments on commit c83a8c7

Please sign in to comment.