[AIR] Avoid checkpoint conversion, move encoding logic to checkpoints #28794
Conversation
Add some descriptions?
Test failures are unrelated.
LGTM. Can we run a small set of release tests for confirmation that this works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Yard1

How much of the code in this PR can be cleaned up if we sequence the `ray.train.*` hard deprecations before this? I think we should do that deprecation before this PR since we are no longer getting this in for 2.1.

I'm wary about adding more tech debt in order to be backwards compatible with tech debt that we want to deprecate anyways. Then we can simplify this PR even more.
```diff
@@ -149,12 +149,11 @@ def train_func(config):
             model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
         )
         if save_model_as_dict:
-            checkpoint_dict = dict(model=model.state_dict())
+            checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
```
Nice!
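For reference, here is a minimal sketch of the reporting pattern this diff is moving toward; the toy model, metric names, and training loop are placeholders for illustration, not taken from the actual example:

```python
import torch

from ray.air import session
from ray.train.torch import TorchCheckpoint


def train_func(config):
    # Toy model and optimizer, just to have a state dict to report.
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(config["epochs"]):
        loss = model(torch.randn(8, 4)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Report a framework-specific checkpoint instead of a raw dict; with this
        # change, encoding/decoding is owned by the checkpoint class itself.
        session.report(
            {"epoch": epoch, "loss": loss.item()},
            checkpoint=TorchCheckpoint.from_state_dict(model.state_dict()),
        )
```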
# "serialization and deserialization. The checkpoint " | ||
# "type will be changed automatically. " | ||
# "This behavior may change in the future." | ||
# ) |
Let's just remove this for now or move it to `session.report`? We want the warning to be printed in `session.report` anyways (we are going to be hard deprecating `train.report`).
I'd put it in `session.report`, but then the logic to check whether the checkpoint is good or bad would need to be moved there too, and that would make it nasty to work with.
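Roughly what the check-plus-warning could look like if it lived alongside the reporting call; a hypothetical sketch only, the helper name and wiring are made up, not the code in this PR:

```python
import warnings

from ray.air.checkpoint import Checkpoint


def _maybe_convert_checkpoint(checkpoint: Checkpoint, expected_cls) -> Checkpoint:
    """Hypothetical helper: ensure the framework-specific checkpoint class is used."""
    if expected_cls is None or isinstance(checkpoint, expected_cls):
        return checkpoint
    warnings.warn(
        f"The reported checkpoint is not a {expected_cls.__name__}, which is "
        "required for correct serialization and deserialization. The checkpoint "
        "type will be changed automatically. This behavior may change in the future."
    )
    # Re-wrap the same underlying data with the expected checkpoint class.
    return expected_cls.from_dict(checkpoint.to_dict())
```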
```python
    def __getstate__(self) -> dict:
        if self._data_dict:
            state = self.__dict__.copy()
```
are these copies necessary?
we don't want to modify the underlying state when serializing (this is a shallow copy so it should be cheap)
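To make the shallow-copy point concrete, here is a small self-contained sketch (a toy stand-in, not the actual Checkpoint implementation): the copied `__dict__` can be modified before pickling without touching the live object.

```python
import pickle


class SketchCheckpoint:
    """Toy stand-in to illustrate why __getstate__ copies self.__dict__."""

    def __init__(self, data_dict=None):
        self._data_dict = data_dict

    def __getstate__(self) -> dict:
        if self._data_dict:
            # Shallow copy: a new dict object holding the same value references,
            # so it is cheap, and replacing entries below does not mutate the
            # live instance.
            state = self.__dict__.copy()
            state["_data_dict"] = {"encoded": True, **self._data_dict}  # placeholder transform
            return state
        return self.__dict__


ckpt = SketchCheckpoint({"model": [1, 2, 3]})
blob = pickle.dumps(ckpt)
assert ckpt._data_dict == {"model": [1, 2, 3]}  # the original object is untouched
```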
```python
    def __setstate__(self, state: dict):
        if "_data_dict" in state:
            state = state.copy()
```
ditto here: do we need to do this copy?
As above
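Continuing the same toy sketch for the deserialization side: `__setstate__` copies the incoming `state` before modifying it, so the dict handed over by pickle (or by whoever constructed it) is never mutated as a side effect.

```python
    # Continuation of the hypothetical SketchCheckpoint class above.
    def __setstate__(self, state: dict):
        if "_data_dict" in state:
            # Copy before editing so the caller's dict is left untouched.
            state = state.copy()
            decoded = dict(state["_data_dict"])
            decoded.pop("encoded", None)  # undo the placeholder transform from __getstate__
            state["_data_dict"] = decoded
        self.__dict__.update(state)
```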
Yeah I am fine with delaying this until we deprecate more things!
@amogkam updated to account for deprecations (let's wait with merging before we deduplicate examples though)
Thanks @Yard1, looks good overall!
```diff
-    # Decode checkpoint.
-    checkpoint_data = decode_checkpoint_fn(checkpoint_data)
+    # TODO(ml-team): Remove once we remove Backend.decode_data
+    checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict()
```
btw it seems that the semantics for scoring are not consistent across Tune and our different Trainers.
Currently we do the following:
- For DL trainers, we check the checkpoint dict itself for score_attribute.
- For non-DL trainers (which go through `TuneReportCallback`) and for basic Tune, we check the metrics associated with the checkpoint for score_attribute, not the checkpoint itself.

With the new API, we should push to using approach 2 for everything in followups.
Definitely, let me make a note of that
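A tiny illustration of the two conventions described above (the values and key names are purely hypothetical):

```python
score_attribute = "score"

# Approach 1 (current DL trainers): the score is read from the checkpoint dict itself.
checkpoint_dict = {"model_state": "...", "score": 0.91}
score_from_checkpoint = checkpoint_dict[score_attribute]

# Approach 2 (TuneReportCallback / basic Tune, and the proposed target for followups):
# the score is read from the metrics reported alongside the checkpoint.
metrics = {"score": 0.91, "epoch": 3}
score_from_metrics = metrics[score_attribute]

assert score_from_checkpoint == score_from_metrics
```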
…ckpoints (ray-project#28794)" This reverts commit 4c20503.
…ckpoints (#28794)" (#29784) This added dependencies from the HorovodConfig on TensorFlow and Torch. If either of these is not installed, (e.g. if the user is using Horovod with Torch and does not have TensorFlow installed), then they will run into a `ModuleNotFoundError`. https://github.com/ray-project/ray/blob/6b9a56d28e1029741feaa864257d75824fe36622/python/ray/train/horovod/config.py#L16-L17 Reverting this for now.
…c to checkpoints (ray-project#28794)" (ray-project#29784)" This reverts commit 57ea8bd.
…ray-project#28794) This PR avoids always converting to dictionary when reporting a checkpoint in Train and uses Checkpoints instead of dicts to transfer data. This is a low hanging fruit change for better consistency and performance with non-dict checkpoints. In order to facilitate that, the data encoding logic in Ray Train has been modified. Encoding and decoding is now done in the checkpoint classes. I believe this is the cleanest solution as it is both generic and inherently tied to the checkpoint itself - however, this has the downside of requiring users to use correct checkpoint classes for torch and horovod. In order to maintain backwards compatibility, the checkpoint class is automatically changed in session.py if a torch checkpoint is required (which has extra encoding and decoding logic to deal with serialization issues). Warnings are printed where necessary. Signed-off-by: Antoni Baum <antoni.baum@protonmail.com> Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com> Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
…ckpoints (ray-project#28794)" (ray-project#29784) This added dependencies from the HorovodConfig on TensorFlow and Torch. If either of these is not installed, (e.g. if the user is using Horovod with Torch and does not have TensorFlow installed), then they will run into a `ModuleNotFoundError`. https://github.com/ray-project/ray/blob/6b9a56d28e1029741feaa864257d75824fe36622/python/ray/train/horovod/config.py#L16-L17 Reverting this for now. Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
ray-project#28794 fixed to avoid the issue discovered in ray-project#29784 (comment) Signed-off-by: Antoni Baum <antoni.baum@protonmail.com> Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
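For background on the revert: module-level imports in `horovod/config.py` made TensorFlow and Torch hard dependencies of the Horovod backend. One common way to avoid such a hard dependency (shown below as a hypothetical sketch, not necessarily how the re-landed PR solved it) is to defer the framework imports to the code paths that actually need them:

```python
def _get_torch_checkpoint_cls():
    # Importing inside the function means this only fails if the torch code
    # path is actually exercised, not on `import ray.train.horovod`.
    from ray.train.torch import TorchCheckpoint

    return TorchCheckpoint


def _get_tensorflow_checkpoint_cls():
    from ray.train.tensorflow import TensorflowCheckpoint

    return TensorflowCheckpoint
```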
Signed-off-by: Antoni Baum antoni.baum@protonmail.com
Why are these changes needed?
This PR avoids always converting to a dictionary when reporting a checkpoint in Train and uses Checkpoints instead of dicts to transfer data. This is a low-hanging-fruit change for better consistency and performance with non-dict checkpoints. In order to facilitate that, the data encoding logic in Ray Train has been modified. Encoding and decoding are now done in the checkpoint classes. I believe this is the cleanest solution, as it is both generic and inherently tied to the checkpoint itself; however, it has the downside of requiring users to use the correct checkpoint classes for Torch and Horovod. In order to maintain backwards compatibility, the checkpoint class is automatically changed in `session.py` if a torch checkpoint is required (which has extra encoding and decoding logic to deal with serialization issues). Warnings are printed where necessary. The old way of encoding/decoding checkpoints, with the logic defined in Backend classes, is soft deprecated but will still be used (while printing a deprecation warning).

The only breaking change is that passing torch tensors/data in `train.report`/`session.report` is not allowed anymore. Considering that with the switch to the session API we do not expect users to return models/tensors but just metrics, and that the `train.report` API is soon to be hard deprecated, I do not think this is an issue worth making a special case for. Happy to make it fully backwards compatible though! Finally, as a side effect, HuggingFaceTrainer will now return a HuggingFaceCheckpoint instead of a base Checkpoint (cc @bveeramani).

This PR changes some of the tests to use the correct framework checkpoints; a followup PR will change examples and documentation (so that we e.g. use TorchCheckpoint with TorchTrainer training UDFs).
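As a rough illustration of "encoding and decoding is done in the checkpoint classes" (the method names and serialization below are hypothetical, not necessarily the hooks added by this PR), the idea is that a framework-specific checkpoint owns its own encode/decode step instead of relying on `Backend.encode_data` / `Backend.decode_data`:

```python
import io
import pickle


class SketchFrameworkCheckpoint:
    """Hypothetical checkpoint that knows how to (de)serialize its own data."""

    def __init__(self, data_dict: dict):
        self._data_dict = data_dict

    def _encode(self) -> bytes:
        # Framework-specific serialization lives with the checkpoint class; for a
        # torch checkpoint this is where torch.save-style handling would go.
        buffer = io.BytesIO()
        pickle.dump(self._data_dict, buffer)
        return buffer.getvalue()

    @classmethod
    def _decode(cls, blob: bytes) -> "SketchFrameworkCheckpoint":
        # The matching decode step also belongs to the checkpoint class, which is
        # what lets Train pass checkpoints around without converting them to dicts.
        return cls(pickle.loads(blob))
```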
Release tests: https://buildkite.com/ray-project/release-tests-pr/builds/19121
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.