-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AIR] Avoid checkpoint conversion, move encoding logic to checkpoints #28794
Changes from 66 commits
8acd735
f5058a3
442179a
8f891a2
0491fe3
bac4cfb
bd95d91
6e24f54
e031339
f763cd5
16dfa01
19fc16d
870fa59
77f8083
4f7189f
e8f6af3
11a5edf
094ed5f
f887cc4
5286e63
78ae564
c913fe5
3304a21
cb22cbe
69750b6
26f0958
dd1c255
cebf4bf
de28c98
d0cfeeb
2121b22
1616411
6ffd0c9
0d835b4
7341ed7
dd92589
e34c942
2c567f8
76d4243
9c1f687
77a4f04
3f1aaf0
acee9f2
172ba5e
30291ad
2110f0d
0ef49d9
fe3a584
363f0eb
ce07451
bc825cb
088e711
29d589d
56e4242
6f00cfe
747a3d2
00787b4
4d4498c
8f6080d
9e447be
c7ee9bf
a29b877
fedb86a
35f9459
b540d67
79a8b1f
2542d44
8cdd6e9
b82a8b9
fe6c7b9
5755809
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,13 +101,20 @@ def _process_checkpoint( | |
"""Ray Train entrypoint. Perform all processing for a checkpoint.""" | ||
# Get checkpoint from first worker. | ||
checkpoint_data = checkpoint_results[0].data | ||
checkpoint_metadata = checkpoint_results[0].metadata or {} | ||
|
||
# Decode checkpoint. | ||
checkpoint_data = decode_checkpoint_fn(checkpoint_data) | ||
# TODO(ml-team): Remove once we remove Backend.decode_data | ||
checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. btw it seems that the semantics for scoring is not consistent across across Tune and our different Trainers. Currently we do the following:
With the new API, we should push to using approach 2 for everything in followups. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Definitely, let me make a note of that |
||
# This is too risky for now (will be saved to a tmp dir) | ||
Yard1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# if checkpoint_data.uri: | ||
# # TODO: ensure that the dir is created in the proper place | ||
# checkpoint_data = checkpoint_data.to_directory() | ||
# checkpoint_data_dict = {} | ||
Yard1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
score_attr = self._checkpoint_strategy.checkpoint_score_attribute | ||
if ( | ||
self._checkpoint_strategy.num_to_keep != 0 | ||
and score_attr not in checkpoint_metadata | ||
and score_attr not in checkpoint_data | ||
): | ||
raise ValueError( | ||
|
@@ -122,7 +129,11 @@ def _process_checkpoint( | |
dir_or_data=checkpoint_data, | ||
checkpoint_id=self._latest_checkpoint_id, | ||
storage_mode=CheckpointStorage.MEMORY, | ||
metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, | ||
metrics={ | ||
score_attr: checkpoint_metadata.get( | ||
score_attr, checkpoint_data.get(score_attr, 0.0) | ||
) | ||
}, | ||
) | ||
self.register_checkpoint(checkpoint=tracked_checkpoint) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,43 @@ | ||
import logging | ||
from typing import TypeVar, Dict | ||
import warnings | ||
from typing import Type, TypeVar, Dict | ||
|
||
from ray.air.checkpoint import Checkpoint | ||
from ray.train._internal.utils import Singleton | ||
from ray.train._internal.worker_group import WorkerGroup | ||
from ray.util.annotations import DeveloperAPI | ||
|
||
from ray.util.annotations import Deprecated, DeveloperAPI | ||
from ray.widgets import make_table_html_repr | ||
|
||
EncodedData = TypeVar("EncodedData") | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# This is used in several places to print a warning. | ||
_encode_decode_deprecation_message = ( | ||
"``encode_data`` and ``decode_data`` are deprecated in favor of " | ||
"framework-specific ``ray.air.Checkpoint`` subclasses (reported " | ||
"using ``ray.air.session.report()``) which can implement " | ||
"encoding and decoding logic. In the future, ``encode_data`` and " | ||
"``decode_data`` will throw an exception if overriden." | ||
) | ||
|
||
|
||
def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]): | ||
return | ||
# Do not print warnings in 2.1 yet. | ||
# TODO(ml-team): Change this once we have full API parity with framework | ||
# checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning | ||
# warnings.warn( | ||
# f"You have reported a checkpoint with the `{Checkpoint}` " | ||
# "type, but the intended checkpoint type for the Trainer " | ||
# f"you are using is `{expected_checkpoint_cls}`. " | ||
# "Not using the intended checkpoint type may cause " | ||
# "exceptions or other issues, especially during " | ||
# "serialization and deserialization. The checkpoint " | ||
# "type will be changed automatically. " | ||
# "This behavior may change in the future." | ||
# ) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's just remove this for now or move to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd put it in session.report but then the logic to check whether the checkpoint is good or bad would need to be moved there too, and that would make it nasty to work with. |
||
|
||
|
||
@DeveloperAPI | ||
class BackendConfig: | ||
|
@@ -46,6 +73,37 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig): | |
"""Logic for shutting down the backend.""" | ||
pass | ||
|
||
@classmethod | ||
def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint: | ||
"""Temporary method until ``encode_data`` is deprecated.""" | ||
if cls.encode_data != Backend.encode_data: | ||
warnings.warn( | ||
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2 | ||
) | ||
# We wrap the return of encode_data in dict in case it is | ||
# not a dict itself. | ||
checkpoint = checkpoint.from_dict( | ||
{"encoded_data": cls.encode_data(checkpoint.to_dict())} | ||
) | ||
return checkpoint | ||
|
||
@classmethod | ||
def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint: | ||
"""Temporary method until ``decode_data`` is deprecated.""" | ||
if cls.decode_data != Backend.decode_data: | ||
warnings.warn( | ||
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2 | ||
) | ||
checkpoint_dict = checkpoint.to_dict() | ||
# If "encoded_data" is not in the dict, then the data was | ||
# not encoded, but the user may want to just do decoding | ||
# anyway. | ||
checkpoint = checkpoint.from_dict( | ||
cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict)) | ||
) | ||
return checkpoint | ||
|
||
@Deprecated(message=_encode_decode_deprecation_message) | ||
@staticmethod | ||
def encode_data(data_dict: Dict) -> EncodedData: | ||
"""Logic to encode a data dict before sending to the driver. | ||
|
@@ -56,6 +114,7 @@ def encode_data(data_dict: Dict) -> EncodedData: | |
|
||
return data_dict | ||
|
||
@Deprecated(message=_encode_decode_deprecation_message) | ||
@staticmethod | ||
def decode_data(encoded_data: EncodedData) -> Dict: | ||
"""Logic to decode an encoded data dict. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!