Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] Maintain checkpoint type information during serialization #28387

Merged
merged 29 commits into from
Sep 26, 2022

Conversation

bveeramani
Copy link
Member

@bveeramani bveeramani commented Sep 8, 2022

Signed-off-by: Balaji Veeramani balaji@anyscale.com

Depends on:

Why are these changes needed?

These changes are needed to fix the errors described in #28134.

Related issue number

Closes #28134

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@Yard1
Copy link
Member

Yard1 commented Sep 12, 2022

Ah, I am a bit worried about this dict subclass. While it is fine for serialization, I don't think we should expose it to the user, as there are multiple points where it would become a regular dict again. Maybe we can have a reserved key instead of subclassing, or have it be purely internal and only used in setstate?

Copy link
Contributor

@krfricke krfricke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, just one doc nit

python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
@xwjiang2010
Copy link
Contributor

xwjiang2010 commented Sep 13, 2022

Can this be used beyond just the type of checkpoint?
Can I have something like below?

class TensorflowCheckpoint(Checkpoint):

    class Flavor(Enum):
        # Various flavors with which TensorflowCheckpoint is generated.
        # This is necessary metadata to decide how to load model from a checkpoint.
        MODEL_WEIGHTS = 1
        SAVED_MODEL = 2
        H5 = 3

    def __init__(self,
        local_path: Optional[str] = None,
        data_dict: Optional[dict] = None,
        uri: Optional[str] = None,
        obj_ref: Optional["ray.ObjectRef"] = None
    ):
        super().__init__(local_path, data_dict, uri, obj_ref)
        self._flavor = None  # set when from_saved_model, from_h5, from_model etc
tf_checkpoint = TensorflowCheckpoint.from_saved_model("my_model")
session.report(checkpoint=tf_checkpoint)

batch_predictor = BatchPredictor(result.checkpoint, TensorflowPredictor)
...

And not only result.checkpoint is a TensorflowCheckpoint, but it also has the correct "flavor" set?

@krfricke @amogkam @bveeramani

For reference, this is the PR: https://github.com/ray-project/ray/pull/28474/files

@amogkam
Copy link
Contributor

amogkam commented Sep 13, 2022

Also, we should make sure that this only triggers in the base checkpoint class - if I do TorchCheckpoint.from_dict() I would want to get a TorchCheckpoint back in any case (e.g. if I pass a regular Checkpoint-CheckpointDict).

we should raise an error here if cls is incompatible with what's saved in the checkpoint metadata.

We should prevent users from restoring a TensorflowCheckpoint as a TorchCheckpoint for example.

@bveeramani
Copy link
Member Author

Can this be used beyond just the type of checkpoint?
Can I have something like below?

@xwjiang2010 I don't think you can store arbitrary metadata like Flavor in the current implementation, but I think we could extend Checkpoint to support that functionality. @krfricke @amogkam thoughts?

@xwjiang2010
Copy link
Contributor

Can this be used beyond just the type of checkpoint?
Can I have something like below?

@xwjiang2010 I don't think you can store arbitrary metadata like Flavor in the current implementation, but I think we could extend Checkpoint to support that functionality. @krfricke @amogkam thoughts?

Thanks @bveeramani ! Yes, I am proposing to extend it. I think this is a reasonable developer expectation.

Something like
checkpoint._invariant_field = ["_flavor"] # The fields that should maintain invariant across ser/deser.

bveeramani and others added 11 commits September 13, 2022 17:28
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
Copy link
Contributor

@amogkam amogkam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @bveeramani! Looks great so far.

Primary things are:

  1. Let's add docstrings to all the Developer APIs
  2. Can we add more test coverage? In particular what is the type of the checkpoint if you save a StubCheckpoint but load back via Checkpoint.from_*? Also can we test that the metadata is saved even in the dict/bytes->dir and dir->dict/bytes workflows?

python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
python/ray/air/checkpoint.py Show resolved Hide resolved
python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
python/ray/air/checkpoint.py Show resolved Hide resolved
python/ray/air/checkpoint.py Show resolved Hide resolved
python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
Copy link
Contributor

@krfricke krfricke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

python/ray/air/tests/test_checkpoints.py Outdated Show resolved Hide resolved
python/ray/air/tests/test_checkpoints.py Outdated Show resolved Hide resolved
Copy link
Contributor

@amogkam amogkam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @bveeramani- lgtm! just some minor comments on private vs. developer vs. public api based on our offline conversation

python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
python/ray/air/checkpoint.py Show resolved Hide resolved
python/ray/air/checkpoint.py Show resolved Hide resolved
python/ray/air/checkpoint.py Outdated Show resolved Hide resolved
@amogkam
Copy link
Contributor

amogkam commented Sep 26, 2022

@bveeramani can you merge in master to fix the CI?

@amogkam
Copy link
Contributor

amogkam commented Sep 26, 2022

failing tests are unrelated, going to merge

@amogkam amogkam merged commit 5034544 into ray-project:master Sep 26, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
…project#28387)

These changes are needed to fix the errors described in ray-project#28134.

Signed-off-by: Balaji Veeramani <bveeramani@berkeley.edu>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[AIR] Maintain checkpoint subclass information during serialization
5 participants