
[AIR] Avoid checkpoint conversion, move encoding logic to checkpoints #28794

Status: Merged. 71 commits, merged on Oct 27, 2022.

Commits (71, all authored by Yard1; the changes below are shown from 66 commits):
8acd735  Avoid checkpoint conversion (Sep 27, 2022)
f5058a3  Fix legacy session tests (Sep 27, 2022)
442179a  Fix horovod (Sep 27, 2022)
8f891a2  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Sep 27, 2022)
0491fe3  Update checkpoint.py (Sep 27, 2022)
bac4cfb  Update checkpoint.py (Sep 28, 2022)
bd95d91  Encoded attr (Sep 28, 2022)
6e24f54  Lint (Sep 28, 2022)
e031339  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Sep 28, 2022)
f763cd5  Merge branch 'master' into train_avoid_checkpoint_conversion (Sep 29, 2022)
16dfa01  Handle preprocessor (Sep 29, 2022)
19fc16d  Deprecate encode/decode in backend (Sep 29, 2022)
870fa59  Make sure we maintain backwards compat (Sep 29, 2022)
77f8083  Make sure HF has the right checkpoint type (Sep 29, 2022)
4f7189f  Fix (Sep 29, 2022)
e8f6af3  GPU fix (Sep 29, 2022)
11a5edf  Cleanup (Sep 29, 2022)
094ed5f  Update test_gpu.py (Sep 29, 2022)
f887cc4  Fix (Sep 29, 2022)
5286e63  Merge branch 'train_avoid_checkpoint_conversion' of https://github.co… (Sep 29, 2022)
78ae564  Update test_tensorflow_trainer.py (Sep 30, 2022)
c913fe5  Update test_tensorflow_trainer.py (Sep 30, 2022)
3304a21  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Oct 3, 2022)
cb22cbe  Warn during an exception in serialization.py (Oct 3, 2022)
69750b6  Update python/ray/train/_internal/utils.py (Oct 4, 2022)
26f0958  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Oct 4, 2022)
dd1c255  Make get_checkpoint_class private (Oct 4, 2022)
cebf4bf  Rephrase (Oct 4, 2022)
de28c98  Extra tests, fail fast for torch (Oct 4, 2022)
d0cfeeb  Error fixes (Oct 4, 2022)
2121b22  Missed this (Oct 4, 2022)
1616411  Fix (Oct 4, 2022)
6ffd0c9  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 5, 2022)
0d835b4  Apply feedback (Oct 5, 2022)
7341ed7  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 5, 2022)
dd92589  Refactor (Oct 5, 2022)
e34c942  Nit (Oct 5, 2022)
2c567f8  Fix (Oct 5, 2022)
76d4243  Fix (Oct 5, 2022)
9c1f687  Tweak example (Oct 5, 2022)
77a4f04  Fix (Oct 5, 2022)
3f1aaf0  Fix (Oct 5, 2022)
acee9f2  Tweaks (Oct 5, 2022)
172ba5e  Add checkpointing to Torch PBT example (Oct 5, 2022)
30291ad  Revert (Oct 5, 2022)
2110f0d  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 6, 2022)
0ef49d9  Do not print warning yet (Oct 6, 2022)
fe3a584  Fix (Oct 6, 2022)
363f0eb  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 7, 2022)
ce07451  Fix (Oct 7, 2022)
bc825cb  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 7, 2022)
088e711  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 7, 2022)
29d589d  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 11, 2022)
56e4242  Fix CI (Oct 11, 2022)
6f00cfe  Fix (Oct 12, 2022)
747a3d2  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 12, 2022)
00787b4  Break up the tests (Oct 12, 2022)
4d4498c  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 12, 2022)
8f6080d  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 13, 2022)
9e447be  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Oct 13, 2022)
c7ee9bf  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 14, 2022)
a29b877  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Oct 17, 2022)
fedb86a  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 25, 2022)
35f9459  Fixes (Oct 25, 2022)
b540d67  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 26, 2022)
79a8b1f  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Oct 26, 2022)
2542d44  Apply feedback from code review (Oct 26, 2022)
8cdd6e9  Remove serialization error hint (Oct 26, 2022)
b82a8b9  Remove test (Oct 26, 2022)
fe6c7b9  Merge branch 'ray-project:master' into train_avoid_checkpoint_conversion (Oct 26, 2022)
5755809  Merge branch 'master' into train_avoid_checkpoint_conversion (Oct 27, 2022)
python/ray/air/_internal/tensorflow_utils.py (20 additions, 0 deletions)

@@ -64,6 +64,26 @@ def convert_ndarray_batch_to_tf_tensor_batch(
     return batch
 
 
+# This is not foolproof, but it's better than nothing.
+# The place where this is used will be deprecated soon.
+def contains_tensorflow_object(obj):
+    if hasattr(obj, "__module__") and (
+        "keras" in obj.__module__ or "tensorflow" in obj.__module__
+    ):
+        return True
+    elif isinstance(obj, dict):
+        for k, v in obj.items():
+            if contains_tensorflow_object(k):
+                return True
+            if contains_tensorflow_object(v):
+                return True
+    elif isinstance(obj, (list, tuple)):
+        for v in obj:
+            if contains_tensorflow_object(v):
+                return True
+    return False
+
+
 def get_type_spec(
     schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"],
     columns: Union[str, List[str]],
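For context, a quick sketch of what this recursive check catches. This is hypothetical usage, not part of the diff, and assumes TensorFlow is installed:

```python
import tensorflow as tf

from ray.air._internal.tensorflow_utils import contains_tensorflow_object

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# True: the model's __module__ mentions "keras", and nested dicts,
# lists, and tuples are searched recursively.
assert contains_tensorflow_object({"state": [model]})

# False: plain Python values carry no keras/tensorflow module marker.
assert not contains_tensorflow_object({"loss": 0.5})
```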
python/ray/air/checkpoint.py (6 additions, 3 deletions)

@@ -501,6 +501,11 @@ def _get_temporary_checkpoint_dir(self) -> str:
         )
         return os.path.join(tmp_dir_path, checkpoint_dir_name)
 
+    def _save_checkpoint_metadata_in_directory(self, path: str) -> None:
+        checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
+        with open(checkpoint_metadata_path, "wb") as file:
+            pickle.dump(self._metadata, file)
+
     def _to_directory(self, path: str) -> None:
         if self._data_dict or self._obj_ref:
             # This is an object ref or dict

@@ -547,9 +552,7 @@ def _to_directory(self, path: str) -> None:
                 f"No valid location found for checkpoint {self}: {self._uri}"
             )
 
-        checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
-        with open(checkpoint_metadata_path, "wb") as file:
-            pickle.dump(self._metadata, file)
+        self._save_checkpoint_metadata_in_directory(path)
 
     def to_directory(self, path: Optional[str] = None) -> str:
         """Write checkpoint data to directory.
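The effect is that checkpoint metadata is written whenever a checkpoint is materialized to disk, from either code path. A minimal sketch of the round trip, assuming the Ray AIR 2.x public API (the metadata file name and the `_metadata` attribute are internal details):

```python
from ray.air.checkpoint import Checkpoint

ckpt = Checkpoint.from_dict({"epoch": 5})

# to_directory() delegates to _to_directory(), which now also pickles the
# checkpoint's internal metadata into the target directory.
path = ckpt.to_directory("/tmp/demo_checkpoint")

# A checkpoint restored from that directory also reads the metadata file.
restored = Checkpoint.from_directory(path)
assert restored.to_dict()["epoch"] == 5
```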
python/ray/air/examples/horovod/horovod_pytorch_example.py (4 additions, 5 deletions)

@@ -9,9 +9,9 @@
 from torchvision import datasets, transforms
 
 from ray.air import session
-from ray.air.checkpoint import Checkpoint
 from ray.air.config import ScalingConfig
 from ray.train.horovod import HorovodTrainer
+from ray.train.torch.torch_checkpoint import TorchCheckpoint
 import ray.train.torch

@@ -152,12 +152,11 @@ def train_func(config):
             model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
         )
         if save_model_as_dict:
-            checkpoint_dict = dict(model=model.state_dict())
+            checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
         else:
-            checkpoint_dict = dict(model=model)
-            checkpoint_dict = Checkpoint.from_dict(checkpoint_dict)
+            checkpoint = TorchCheckpoint.from_model(model)
         results.append(loss)
-        session.report(dict(loss=loss), checkpoint=checkpoint_dict)
+        session.report(dict(loss=loss), checkpoint=checkpoint)
 
     # Only used for testing.
     return results

Review comment (Contributor), on the TorchCheckpoint.from_state_dict line: Nice!
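For reference, the two `TorchCheckpoint` constructors used above trade off differently. A sketch assuming the Ray 2.x `TorchCheckpoint` API, with `net` as a stand-in model:

```python
import torch

from ray.train.torch import TorchCheckpoint

net = torch.nn.Linear(2, 1)

# from_state_dict() stores only the weights; restoring requires passing
# a module of the matching architecture to load them into.
ckpt_weights = TorchCheckpoint.from_state_dict(net.state_dict())
model_a = ckpt_weights.get_model(torch.nn.Linear(2, 1))

# from_model() serializes the whole module, so it restores standalone.
ckpt_model = TorchCheckpoint.from_model(net)
model_b = ckpt_model.get_model()
```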
python/ray/air/examples/pytorch/torch_linear_example.py (5 additions, 6 deletions)

@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 import ray.train as train
-from ray.train.torch import TorchTrainer
+from ray.train.torch import TorchTrainer, TorchCheckpoint
 from ray.air.config import ScalingConfig

@@ -48,8 +48,7 @@ def validate_epoch(dataloader, model, loss_fn):
     import copy
 
     model_copy = copy.deepcopy(model)
-    result = {"model": model_copy.cpu().state_dict(), "loss": loss}
-    return result
+    return model_copy.cpu().state_dict(), loss

@@ -76,12 +75,12 @@ def train_func(config):
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
     results = []
-
     for _ in range(epochs):
         train_epoch(train_loader, model, loss_fn, optimizer)
-        result = validate_epoch(validation_loader, model, loss_fn)
+        state_dict, loss = validate_epoch(validation_loader, model, loss_fn)
+        result = dict(loss=loss)
         results.append(result)
-        session.report(result)
+        session.report(result, checkpoint=TorchCheckpoint.from_state_dict(state_dict))
 
     return results
python/ray/train/_internal/backend_executor.py (1 addition, 1 deletion)

@@ -346,7 +346,7 @@ def initialize_session(
                     train_func=train_func,
                     dataset_shard=self.dataset_shards[index],
                     checkpoint=checkpoint,
-                    encode_data_fn=self._backend.encode_data,
+                    encode_data_fn=self._backend._encode_data,
                 )
             )
python/ray/train/_internal/checkpoint.py (14 additions, 3 deletions)

@@ -101,13 +101,20 @@
         """Ray Train entrypoint. Perform all processing for a checkpoint."""
         # Get checkpoint from first worker.
         checkpoint_data = checkpoint_results[0].data
+        checkpoint_metadata = checkpoint_results[0].metadata or {}
 
         # Decode checkpoint.
-        checkpoint_data = decode_checkpoint_fn(checkpoint_data)
+        # TODO(ml-team): Remove once we remove Backend.decode_data
+        checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict()
+
+        # This is too risky for now (will be saved to a tmp dir)
+        # if checkpoint_data.uri:
+        #     # TODO: ensure that the dir is created in the proper place
+        #     checkpoint_data = checkpoint_data.to_directory()
+        #     checkpoint_data_dict = {}
 
         score_attr = self._checkpoint_strategy.checkpoint_score_attribute
         if (
             self._checkpoint_strategy.num_to_keep != 0
+            and score_attr not in checkpoint_metadata
             and score_attr not in checkpoint_data
         ):
             raise ValueError(

@@ -122,7 +129,11 @@
             dir_or_data=checkpoint_data,
             checkpoint_id=self._latest_checkpoint_id,
             storage_mode=CheckpointStorage.MEMORY,
-            metrics={score_attr: checkpoint_data.get(score_attr, 0.0)},
+            metrics={
+                score_attr: checkpoint_metadata.get(
+                    score_attr, checkpoint_data.get(score_attr, 0.0)
+                )
+            },
         )
         self.register_checkpoint(checkpoint=tracked_checkpoint)

Review comment (Contributor), on the decode change: By the way, it seems that the semantics for scoring are not consistent across Tune and our different Trainers. Currently we do the following:

1. For DL trainers, we check the checkpoint dict itself for the score_attribute.
2. For non-DL trainers (which go through TuneReportCallback) and for basic Tune, we check the metrics associated with the checkpoint for the score_attribute, not the checkpoint itself.

With the new API, we should push toward using approach 2 for everything in follow-ups.

Reply (Member, Author): Definitely, let me make a note of that.
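The resulting precedence for the checkpoint score, as a minimal self-contained sketch (hypothetical values; the real logic is the `_process_checkpoint` hunk above):

```python
# Metadata auto-filled on the worker takes precedence over the decoded
# checkpoint contents; 0.0 is the final fallback.
checkpoint_metadata = {"loss": 0.12}
checkpoint_data = {"model": ..., "loss": 0.99}

score_attr = "loss"
score = checkpoint_metadata.get(
    score_attr, checkpoint_data.get(score_attr, 0.0)
)
assert score == 0.12
```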
python/ray/train/_internal/session.py (36 additions, 15 deletions)

@@ -2,6 +2,7 @@
 import logging
 import platform
 import queue
+import sys
 import threading
 import time
 from dataclasses import dataclass

@@ -51,7 +52,8 @@ class TrialInfo:
 @dataclass
 class TrainingResult:
     type: TrainingResultType
-    data: Dict
+    data: Union[Dict, Checkpoint]
+    metadata: Optional[Dict] = None
 
 
 # TODO(xwjiang): This needs a better name.

@@ -68,8 +70,9 @@ def __init__(
         trial_info: Optional[TrialInfo] = None,
         dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
         # TODO(xwjiang): Legacy Ray Train trainer clean up!
-        checkpoint: Optional[Union[Dict, Checkpoint]] = None,
-        encode_data_fn: Callable = None,
+        checkpoint: Optional[Checkpoint] = None,
+        # Deprecated
+        encode_data_fn: Optional[Callable] = None,
         detailed_autofilled_metrics: bool = False,
     ):

@@ -80,7 +83,7 @@ def __init__(
         self.world_size = world_size
         self.trial_info = trial_info
         # TODO(xwjiang): Legacy Ray Train trainer clean up!
-        self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint
+        self.loaded_checkpoint = checkpoint
 
         # Function to encode checkpoint dict before sending to the driver.
         if not encode_data_fn:

@@ -240,9 +243,9 @@ def _report_legacy(self, **kwargs):
         if self.ignore_report:
             return
 
-        kwargs = self._encode_data_fn(self._auto_fill_metrics(kwargs))
+        kwargs = self._auto_fill_metrics(kwargs)
 
-        result = TrainingResult(TrainingResultType.REPORT, kwargs)
+        result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs)
 
         # Add result to a thread-safe queue.
         self.result_queue.put(result, block=True)

@@ -269,22 +272,26 @@ def _report_thread_runner_error(self, block=False):
         except queue.Empty:
             pass
 
-    def checkpoint(self, **kwargs):
+    def checkpoint(self, checkpoint: Checkpoint):
         """Adds kwargs to the queue to be consumed by main thread.
 
         Also stores the checkpoint in ``self.loaded_checkpoint``.
         """
 
         # Update session checkpoint to latest checkpoint.
-        self.loaded_checkpoint = kwargs
+        self.loaded_checkpoint = checkpoint
 
         # Only store checkpoints on worker with rank 0.
         if self.world_rank != 0:
-            kwargs = {}
-        else:
-            kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs))
-
-        result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
+            checkpoint = None
+        elif checkpoint:
+            checkpoint = self._encode_data_fn(checkpoint)
+
+        result = TrainingResult(
+            type=TrainingResultType.CHECKPOINT,
+            data=checkpoint,
+            metadata=self._auto_fill_checkpoint_metrics({}),
+        )
         # Add result to a thread-safe queue.
         self.result_queue.put(result, block=True)

@@ -294,9 +301,23 @@ def checkpoint(self, **kwargs):
     def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
         # TODO(xwjiang): tons of optimizations.
+
+        # Special case: early fail for Torch tensors
+        if "torch" in sys.modules:
+            from ray.air._internal.torch_utils import contains_tensor
+
+            if contains_tensor(metrics):
+                raise ValueError(
+                    "Passing objects containing Torch tensors as metrics "
+                    "is not supported as it will throw an exception on "
+                    "deserialization. You can either convert the tensors "
+                    "to Python objects or use a `TorchCheckpoint` as the "
+                    "`checkpoint` argument of `ray.air.session.report` to "
+                    "store your Torch objects."
+                )
+
         if checkpoint:
-            checkpoint_dict = checkpoint.to_dict()
-            self.checkpoint(**checkpoint_dict)
+            self.checkpoint(checkpoint)
         self._report_legacy(**metrics)
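Taken together, the user-facing contract becomes: metrics stay plain Python, tensors travel inside framework-specific checkpoints. A hedged sketch of a training function using the new path (assumes `ray[air]` 2.x with torch installed; `train_loop` is a hypothetical user function):

```python
import torch

from ray.air import session
from ray.train.torch import TorchCheckpoint


def train_loop(config):
    state_dict = torch.nn.Linear(1, 1).state_dict()

    # This would now fail fast with the ValueError added above, because
    # Torch tensors in metrics break on deserialization:
    # session.report({"weights": state_dict})

    # Instead, report tensors through a framework-specific checkpoint.
    session.report(
        {"loss": 0.5},
        checkpoint=TorchCheckpoint.from_state_dict(state_dict),
    )
```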
python/ray/train/_internal/utils.py (18 additions, 0 deletions)

@@ -4,6 +4,7 @@
 import os
 import logging
 from pathlib import Path
+import traceback
 
 from typing import (
     Tuple,

@@ -58,6 +59,23 @@ def check_for_failure(
             return False, exc
         except Exception as exc:
             # Other (e.g. training) errors should be directly raised
+            # If an exception is raised in the serialization module,
+            # we guide the user to look for that.
+            # test_data_parallel_trainer.py::test_serialization_errors
+            # tests that this is being printed.
+            if "serialization.py" in traceback.format_exc():
+                logger.error(
+                    "An exception raised here from the serialization module "
+                    "is most likely caused by an issue with deserialization "
+                    "(e.g. with Torch models or tensors). "
+                    "Ensure that you are reporting a Checkpoint type specific "
+                    "to the framework you are using (e.g. `TorchCheckpoint` if "
+                    "checkpointing a Torch model). Those Checkpoint types "
+                    "contain special serialization/deserialization logic "
+                    "that helps avoid deserialization exceptions. "
+                    "No special handling logic is applied for objects "
+                    "passed in the `metrics` dict in the `report()` method!"
+                )
             raise StartTraceback from exc
 
     return True, None
python/ray/train/backend.py (62 additions, 3 deletions)

@@ -1,16 +1,43 @@
 import logging
-from typing import TypeVar, Dict
+import warnings
+from typing import Type, TypeVar, Dict
 
+from ray.air.checkpoint import Checkpoint
 from ray.train._internal.utils import Singleton
 from ray.train._internal.worker_group import WorkerGroup
-from ray.util.annotations import DeveloperAPI
-
+from ray.util.annotations import Deprecated, DeveloperAPI
 from ray.widgets import make_table_html_repr
 
 EncodedData = TypeVar("EncodedData")
 
 logger = logging.getLogger(__name__)
 
+# This is used in several places to print a warning.
+_encode_decode_deprecation_message = (
+    "``encode_data`` and ``decode_data`` are deprecated in favor of "
+    "framework-specific ``ray.air.Checkpoint`` subclasses (reported "
+    "using ``ray.air.session.report()``) which can implement "
+    "encoding and decoding logic. In the future, ``encode_data`` and "
+    "``decode_data`` will throw an exception if overridden."
+)
+
+
+def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]):
+    return
+    # Do not print warnings in 2.1 yet.
+    # TODO(ml-team): Change this once we have full API parity with framework
+    # checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning
+    # warnings.warn(
+    #     f"You have reported a checkpoint with the `{Checkpoint}` "
+    #     "type, but the intended checkpoint type for the Trainer "
+    #     f"you are using is `{expected_checkpoint_cls}`. "
+    #     "Not using the intended checkpoint type may cause "
+    #     "exceptions or other issues, especially during "
+    #     "serialization and deserialization. The checkpoint "
+    #     "type will be changed automatically. "
+    #     "This behavior may change in the future."
+    # )
 
 
 @DeveloperAPI
 class BackendConfig:

Review comment (Contributor), on _warn_about_bad_checkpoint_type: Let's just remove this for now or move it to session.report? We want the warning to be printed in session.report anyway (we are going to be hard-deprecating train.report).

Reply (Member, Author): I'd put it in session.report, but then the logic to check whether the checkpoint is good or bad would need to be moved there too, and that would make it nasty to work with.

@@ -46,6 +73,37 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
         """Logic for shutting down the backend."""
         pass
 
+    @classmethod
+    def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
+        """Temporary method until ``encode_data`` is deprecated."""
+        if cls.encode_data != Backend.encode_data:
+            warnings.warn(
+                _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
+            )
+            # We wrap the return of encode_data in a dict in case it is
+            # not a dict itself.
+            checkpoint = checkpoint.from_dict(
+                {"encoded_data": cls.encode_data(checkpoint.to_dict())}
+            )
+        return checkpoint
+
+    @classmethod
+    def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
+        """Temporary method until ``decode_data`` is deprecated."""
+        if cls.decode_data != Backend.decode_data:
+            warnings.warn(
+                _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
+            )
+            checkpoint_dict = checkpoint.to_dict()
+            # If "encoded_data" is not in the dict, then the data was
+            # not encoded, but the user may want to just do decoding
+            # anyway.
+            checkpoint = checkpoint.from_dict(
+                cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict))
+            )
+        return checkpoint
+
+    @Deprecated(message=_encode_decode_deprecation_message)
     @staticmethod
     def encode_data(data_dict: Dict) -> EncodedData:
         """Logic to encode a data dict before sending to the driver.

@@ -56,6 +114,7 @@ def encode_data(data_dict: Dict) -> EncodedData:
 
         return data_dict
 
+    @Deprecated(message=_encode_decode_deprecation_message)
     @staticmethod
     def decode_data(encoded_data: EncodedData) -> Dict:
         """Logic to decode an encoded data dict.
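To make the compatibility shim concrete, here is a sketch of a legacy backend exercising it. `MyLegacyBackend` and its encoding scheme are hypothetical; only `Backend` and `Checkpoint` are real APIs here:

```python
from ray.air.checkpoint import Checkpoint
from ray.train.backend import Backend


class MyLegacyBackend(Backend):
    # Legacy-style hooks that predate framework-specific checkpoints.
    @staticmethod
    def encode_data(data_dict):
        return {"blob": list(data_dict.items())}

    @staticmethod
    def decode_data(encoded_data):
        return dict(encoded_data["blob"])


# _encode_data detects the override, emits a DeprecationWarning, and wraps
# the encoded payload under "encoded_data" so the result stays a Checkpoint.
ckpt = MyLegacyBackend._encode_data(Checkpoint.from_dict({"epoch": 1}))

# _decode_data unwraps "encoded_data" (falling back to the raw dict when
# the payload was never encoded) and rebuilds the checkpoint dict.
restored = MyLegacyBackend._decode_data(ckpt)
assert restored.to_dict()["epoch"] == 1
```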