
[AIR] Avoid checkpoint conversion, move encoding logic to checkpoints #28794

Merged
merged 71 commits into ray-project:master from Yard1:train_avoid_checkpoint_conversion
Oct 27, 2022

Conversation

Yard1
Member

@Yard1 Yard1 commented Sep 26, 2022

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>

Why are these changes needed?

This PR avoids always converting to a dictionary when reporting a checkpoint in Train, and uses Checkpoints instead of dicts to transfer data. This is a low-hanging-fruit change for better consistency and performance with non-dict checkpoints. To facilitate that, the data encoding logic in Ray Train has been modified: encoding and decoding are now done in the checkpoint classes. I believe this is the cleanest solution, as it is both generic and inherently tied to the checkpoint itself - however, it has the downside of requiring users to use the correct checkpoint classes for Torch and Horovod. To maintain backwards compatibility, the checkpoint class is automatically changed in session.py if a Torch checkpoint is required (that class has extra encoding and decoding logic to deal with serialization issues). Warnings are printed where necessary.
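For illustration, a minimal sketch of the new shape (hypothetical stand-in code, not the PR's actual implementation - the class layout and method names like _encode/_decode are assumed):

import io

class Checkpoint:
    """Toy base checkpoint carrying an in-memory dict."""

    def __init__(self, data_dict=None):
        self._data_dict = data_dict or {}

    @classmethod
    def from_dict(cls, data: dict) -> "Checkpoint":
        return cls(data_dict=data)

    def to_dict(self) -> dict:
        return dict(self._data_dict)

class TorchCheckpoint(Checkpoint):
    """The checkpoint itself owns encode/decode, so Train no longer
    needs Backend-level conversion hooks."""

    def _encode(self) -> bytes:
        # torch.save handles tensors/models that plain pickle can choke on.
        import torch

        buffer = io.BytesIO()
        torch.save(self._data_dict, buffer)
        return buffer.getvalue()

    @classmethod
    def _decode(cls, blob: bytes) -> "TorchCheckpoint":
        import torch

        return cls.from_dict(torch.load(io.BytesIO(blob)))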

The old way of encoding/decoding checkpoints, with the logic being defined in Backend classes, is soft-deprecated but will still be used (while printing out a deprecation warning).
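As a rough sketch of what such a soft deprecation can look like (only Backend.decode_data is named in this thread; the shim below is an assumed illustration, not the PR's code):

import warnings

class Backend:
    """Minimal stub of Ray Train's Backend with the legacy decode hook."""

    def decode_data(self, data_dict: dict) -> dict:
        return data_dict

def get_decode_fn(backend: Backend):
    # Use the legacy Backend-level hook only if a subclass overrode it,
    # and warn that the logic now belongs to the Checkpoint classes.
    if type(backend).decode_data is not Backend.decode_data:
        warnings.warn(
            "Backend.decode_data is deprecated; encode/decode logic now "
            "lives on the Checkpoint classes.",
            DeprecationWarning,
        )
        return backend.decode_data
    return lambda data_dict: data_dict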

The only breaking change is that passing torch tensors/data in train.report/session.report is no longer allowed. Considering that with the switch to the session API we expect users to return just metrics rather than models/tensors, and that the train.report API is soon to be hard deprecated, I do not think this is an issue worth making a special case for. Happy to make it fully backwards compatible, though!
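Concretely, a training function would report metrics plus a framework checkpoint rather than raw tensors - a sketch with placeholder model and metric values (TorchCheckpoint.from_state_dict and session.report are the APIs used elsewhere in this PR):

import torch.nn as nn
from ray.air import session
from ray.train.torch import TorchCheckpoint

def train_func(config):
    model = nn.Linear(4, 1)  # placeholder model
    # ... training loop elided ...
    # Metrics stay plain; tensors travel inside the checkpoint.
    session.report(
        {"loss": 0.123},
        checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()),
    )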

Finally, as a side effect, HuggingFaceTrainer will now return a HuggingFaceCheckpoint instead of a base Checkpoint (cc @bveeramani).

This PR changes some of the tests to use the correct framework checkpoints - a follow-up PR will change examples and documentation (so that we, e.g., use TorchCheckpoint with TorchTrainer training UDFs).

Release tests: https://buildkite.com/ray-project/release-tests-pr/builds/19121

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
@Yard1 Yard1 force-pushed the train_avoid_checkpoint_conversion branch from 3fbf793 to 8acd735 Compare September 27, 2022 17:46
@xwjiang2010
Contributor

Add some descriptions?

@Yard1
Member Author

Yard1 commented Sep 28, 2022

Test failures are unrelated.

@Yard1 Yard1 added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Sep 28, 2022
@Yard1 Yard1 removed the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Sep 29, 2022
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
@Yard1 Yard1 changed the title [AIR] Avoid checkpoint conversion [AIR] Avoid checkpoint conversion, move encoding logic to checkpoints Sep 29, 2022
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Contributor

@krfricke krfricke left a comment

LGTM. Can we run a small set of release tests to confirm that this works?

Contributor

@amogkam amogkam left a comment

Thanks @Yard1

How much of the code in this PR can be cleaned up if we sequence the ray.train.* hard deprecations before this? I think we should do that deprecation before this PR since we are no longer getting this in for 2.1.

I'm wary about adding more tech debt in order to be backwards compatible with tech debt that we want to deprecate anyways. Then we can simplify this PR even more.

@@ -149,12 +149,11 @@ def train_func(config):
         model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
     )
     if save_model_as_dict:
-        checkpoint_dict = dict(model=model.state_dict())
+        checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())

Contributor

Nice!

# "serialization and deserialization. The checkpoint "
# "type will be changed automatically. "
# "This behavior may change in the future."
# )
Contributor

Let's just remove this for now or move to session.report? We want the warning to be printed in session.report anyways (we are going to be hard deprecating train.report).

Member Author

I'd put it in session.report but then the logic to check whether the checkpoint is good or bad would need to be moved there too, and that would make it nasty to work with.


def __getstate__(self) -> dict:
    if self._data_dict:
        state = self.__dict__.copy()

Contributor

are these copies necessary?

Member Author

we don't want to modify the underlying state when serializing (this is a shallow copy so it should be cheap)


def __setstate__(self, state: dict):
    if "_data_dict" in state:
        state = state.copy()

Contributor

ditto here: do we need to do this copy?

Member Author

As above
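To spell out the copy rationale with a self-contained toy (attribute names mirror the snippets above; the encode/decode step is invented for illustration):

import pickle

class DictCheckpoint:
    def __init__(self, data_dict: dict):
        self._data_dict = data_dict

    def __getstate__(self) -> dict:
        if self._data_dict:
            # Shallow copy: we can swap in an encoded payload without
            # mutating the live object being pickled.
            state = self.__dict__.copy()
            state["_data_dict"] = {"_encoded": True, **self._data_dict}
            return state
        return self.__dict__

    def __setstate__(self, state: dict):
        if "_data_dict" in state:
            # Copy again so decoding doesn't mutate the caller's state dict.
            state = state.copy()
            decoded = dict(state["_data_dict"])
            decoded.pop("_encoded", None)
            state["_data_dict"] = decoded
        self.__dict__.update(state)

# A pickle round trip leaves the original object untouched.
ckpt = DictCheckpoint({"weights": [1, 2, 3]})
restored = pickle.loads(pickle.dumps(ckpt))
assert ckpt._data_dict == {"weights": [1, 2, 3]}
assert restored._data_dict == {"weights": [1, 2, 3]}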

@Yard1
Member Author

Yard1 commented Oct 17, 2022

Yeah I am fine with delaying this until we deprecate more things!

Yard1 added 2 commits October 25, 2022 20:19
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
@Yard1 Yard1 requested a review from amogkam October 25, 2022 20:50
@Yard1
Member Author

Yard1 commented Oct 25, 2022

@amogkam updated to account for deprecations (let's hold off on merging until we deduplicate the examples, though)

Contributor

@amogkam amogkam left a comment

Thanks @Yard1, looks good overall!

-# Decode checkpoint.
-checkpoint_data = decode_checkpoint_fn(checkpoint_data)
+# TODO(ml-team): Remove once we remove Backend.decode_data
+checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict()

Contributor

Btw, it seems that the semantics for scoring are not consistent across Tune and our different Trainers.

Currently we do the following:

  1. For DL trainers, we check the checkpoint dict itself for score_attribute.
  2. For non-DL trainers (which go through TuneReportCallback) and for basic Tune, we check the metrics associated with the checkpoint for score_attribute, not the checkpoint itself.

With the new API, we should push toward using approach 2 for everything in follow-ups.
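A tiny sketch of approach 2 - rank by the metrics attached to a checkpoint, never by the checkpoint contents (all names here are illustrative, not Ray internals):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class TrackedCheckpoint:
    data: Any  # opaque payload; scoring never looks inside it
    metrics: Dict[str, float] = field(default_factory=dict)

def best_checkpoint(checkpoints, score_attribute: str, mode: str = "max"):
    key = lambda c: c.metrics[score_attribute]
    return max(checkpoints, key=key) if mode == "max" else min(checkpoints, key=key)

ckpts = [
    TrackedCheckpoint(data=b"...", metrics={"accuracy": 0.71}),
    TrackedCheckpoint(data=b"...", metrics={"accuracy": 0.85}),
]
assert best_checkpoint(ckpts, "accuracy").metrics["accuracy"] == 0.85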

Member Author

Definitely, let me make a note of that

Yard1 added 3 commits October 26, 2022 21:03
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
@amogkam amogkam merged commit 4c20503 into ray-project:master Oct 27, 2022
@Yard1 Yard1 deleted the train_avoid_checkpoint_conversion branch October 27, 2022 18:24
matthewdeng added a commit to matthewdeng/ray that referenced this pull request Oct 27, 2022
amogkam pushed a commit that referenced this pull request Oct 27, 2022
…ckpoints (#28794)" (#29784)

This added dependencies on TensorFlow and Torch to the HorovodConfig. If either of these is not installed (e.g. if the user is using Horovod with Torch and does not have TensorFlow installed), they will run into a `ModuleNotFoundError`.

https://github.com/ray-project/ray/blob/6b9a56d28e1029741feaa864257d75824fe36622/python/ray/train/horovod/config.py#L16-L17

Reverting this for now.
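The usual remedy, sketched below under assumed structure (the real fix landed in the re-land; see python/ray/train/horovod/config.py), is to import optional frameworks lazily so a missing one no longer raises at import time:

import importlib

def _try_import(module_name: str):
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None

def detect_frameworks():
    # Only installed frameworks are returned; a user running Horovod
    # with Torch but without TensorFlow no longer hits ModuleNotFoundError.
    available = {}
    for name in ("torch", "tensorflow"):
        mod = _try_import(name)
        if mod is not None:
            available[name] = mod
    return available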
Yard1 added a commit to Yard1/ray that referenced this pull request Oct 27, 2022
amogkam pushed a commit that referenced this pull request Oct 28, 2022
#28794 fixed to avoid the issue discovered in #29784 (comment)

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022