Revert "[AIR] Avoid checkpoint conversion, move encoding logic to che…
Browse files Browse the repository at this point in the history
…ckpoints (ray-project#28794)"

This reverts commit 4c20503.
matthewdeng committed Oct 27, 2022
1 parent 6b9a56d commit 6618d53
Showing 24 changed files with 174 additions and 404 deletions.
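
For context on the revert: ray-project#28794 routed checkpoint encoding through framework-specific `Checkpoint` subclasses, while this revert returns Ray Train to plain dict-based checkpoints whose serialization is handled by the `Backend.encode_data` / `decode_data` hooks. A minimal, hedged sketch of the restored reporting pattern (the dict contents are placeholders):

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
    # ... training elided ...
    loss = 0.123  # placeholder metric

    # With this revert, checkpoints reported from workers are plain dict-backed
    # Checkpoints again; Backend.encode_data / decode_data handle any
    # framework-specific serialization on the way to and from the driver.
    session.report(
        {"loss": loss},
        checkpoint=Checkpoint.from_dict({"model": {"weights": None}, "loss": loss}),
    )
```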
20 changes: 0 additions & 20 deletions python/ray/air/_internal/tensorflow_utils.py
@@ -64,26 +64,6 @@ def convert_ndarray_batch_to_tf_tensor_batch(
return batch


# This is not foolproof, but it's better than nothing
# The place it is used in will be deprecated soon
def contains_tensorflow_object(obj):
if hasattr(obj, "__module__") and (
"keras" in obj.__module__ or "tensorflow" in obj.__module__
):
return True
elif isinstance(obj, dict):
for k, v in obj.items():
if contains_tensorflow_object(k):
return True
if contains_tensorflow_object(v):
return True
elif isinstance(obj, (list, tuple)):
for v in obj:
if contains_tensorflow_object(v):
return True
return False


def get_type_spec(
schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"],
columns: Union[str, List[str]],
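
For reference, a small usage sketch of the helper deleted above: it recursively scans nested dicts, lists, and tuples for anything whose `__module__` mentions keras or tensorflow. The import only resolves against the pre-revert tree, since this commit removes the function:

```python
import tensorflow as tf

# Pre-revert import: this helper is deleted by the commit above.
from ray.air._internal.tensorflow_utils import contains_tensorflow_object

assert contains_tensorflow_object({"layer": tf.keras.layers.Dense(1)})
assert not contains_tensorflow_object({"loss": 0.5})
```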
9 changes: 3 additions & 6 deletions python/ray/air/checkpoint.py
@@ -501,11 +501,6 @@ def _get_temporary_checkpoint_dir(self) -> str:
)
return os.path.join(tmp_dir_path, checkpoint_dir_name)

def _save_checkpoint_metadata_in_directory(self, path: str) -> None:
checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
with open(checkpoint_metadata_path, "wb") as file:
pickle.dump(self._metadata, file)

def _to_directory(self, path: str) -> None:
if self._data_dict or self._obj_ref:
# This is a object ref or dict
@@ -552,7 +547,9 @@ def _to_directory(self, path: str) -> None:
f"No valid location found for checkpoint {self}: {self._uri}"
)

self._save_checkpoint_metadata_in_directory(path)
checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME)
with open(checkpoint_metadata_path, "wb") as file:
pickle.dump(self._metadata, file)

def to_directory(self, path: Optional[str] = None) -> str:
"""Write checkpoint data to directory.
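
The hunk above folds the metadata write back into `_to_directory`: the checkpoint's `_metadata` is pickled into the checkpoint directory alongside its contents. A rough sketch of the resulting on-disk round trip (the metadata file name is an internal constant, so this just scans for it):

```python
import os
import pickle

from ray.air.checkpoint import Checkpoint

checkpoint = Checkpoint.from_dict({"model": {"weights": None}, "loss": 0.5})
path = checkpoint.to_directory()  # writes the checkpoint data plus a metadata file

# _CHECKPOINT_METADATA_FILE_NAME is internal, so look for the pickled metadata
# file by name rather than importing the constant.
for name in os.listdir(path):
    if "metadata" in name:
        with open(os.path.join(path, name), "rb") as f:
            print(name, pickle.load(f))
```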
2 changes: 1 addition & 1 deletion python/ray/train/_internal/backend_executor.py
@@ -346,7 +346,7 @@ def initialize_session(
train_func=train_func,
dataset_shard=self.dataset_shards[index],
checkpoint=checkpoint,
encode_data_fn=self._backend._encode_data,
encode_data_fn=self._backend.encode_data,
)
)

12 changes: 3 additions & 9 deletions python/ray/train/_internal/checkpoint.py
@@ -101,15 +101,13 @@ def _process_checkpoint(
"""Ray Train entrypoint. Perform all processing for a checkpoint."""
# Get checkpoint from first worker.
checkpoint_data = checkpoint_results[0].data
checkpoint_metadata = checkpoint_results[0].metadata or {}

# TODO(ml-team): Remove once we remove Backend.decode_data
checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict()
# Decode checkpoint.
checkpoint_data = decode_checkpoint_fn(checkpoint_data)

score_attr = self._checkpoint_strategy.checkpoint_score_attribute
if (
self._checkpoint_strategy.num_to_keep != 0
and score_attr not in checkpoint_metadata
and score_attr not in checkpoint_data
):
raise ValueError(
@@ -124,11 +122,7 @@
dir_or_data=checkpoint_data,
checkpoint_id=self._latest_checkpoint_id,
storage_mode=CheckpointStorage.MEMORY,
metrics={
score_attr: checkpoint_metadata.get(
score_attr, checkpoint_data.get(score_attr, 0.0)
)
},
metrics={score_attr: checkpoint_data.get(score_attr, 0.0)},
)
self.register_checkpoint(checkpoint=tracked_checkpoint)

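
Because the score is now looked up in the decoded checkpoint dict itself rather than in separate metadata, a training loop that enables checkpoint scoring has to include the score attribute in the checkpoint it reports. A hedged sketch (key names are illustrative):

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import CheckpointConfig, RunConfig

# checkpoint_score_attribute="loss" means the key "loss" must appear in the
# reported checkpoint dict, otherwise _process_checkpoint raises ValueError.
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
    )
)


def train_func(config):
    loss = 0.1  # placeholder
    session.report(
        {"loss": loss},
        checkpoint=Checkpoint.from_dict({"model": {"weights": None}, "loss": loss}),
    )
```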
51 changes: 15 additions & 36 deletions python/ray/train/_internal/session.py
@@ -2,7 +2,6 @@
import logging
import platform
import queue
import sys
import threading
import time
from dataclasses import dataclass
@@ -52,8 +51,7 @@ class TrialInfo:
@dataclass
class TrainingResult:
type: TrainingResultType
data: Union[Dict, Checkpoint]
metadata: Optional[Dict] = None
data: Dict


# TODO(xwjiang): This needs a better name.
@@ -70,9 +68,8 @@ def __init__(
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
# TODO(xwjiang): Legacy Ray Train trainer clean up!
checkpoint: Optional[Checkpoint] = None,
# Deprecated
encode_data_fn: Optional[Callable] = None,
checkpoint: Optional[Union[Dict, Checkpoint]] = None,
encode_data_fn: Callable = None,
detailed_autofilled_metrics: bool = False,
):

@@ -83,7 +80,7 @@ def __init__(
self.world_size = world_size
self.trial_info = trial_info
# TODO(xwjiang): Legacy Ray Train trainer clean up!
self.loaded_checkpoint = checkpoint
self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint

# Function to encode checkpoint dict before sending to the driver.
if not encode_data_fn:
@@ -243,9 +240,9 @@ def _report_legacy(self, **kwargs):
if self.ignore_report:
return

kwargs = self._auto_fill_metrics(kwargs)
kwargs = self._encode_data_fn(self._auto_fill_metrics(kwargs))

result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs)
result = TrainingResult(TrainingResultType.REPORT, kwargs)

# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)
@@ -272,26 +269,22 @@ def _report_thread_runner_error(self, block=False):
except queue.Empty:
pass

def checkpoint(self, checkpoint: Checkpoint):
def checkpoint(self, **kwargs):
"""Adds kwargs to the queue to be consumed by main thread.
Also stores the checkpoint in ``self.loaded_checkpoint``.
"""

# Update session checkpoint to latest checkpoint.
self.loaded_checkpoint = checkpoint
self.loaded_checkpoint = kwargs

# Only store checkpoints on worker with rank 0.
if self.world_rank != 0:
checkpoint = None
elif checkpoint:
checkpoint = self._encode_data_fn(checkpoint)

result = TrainingResult(
type=TrainingResultType.CHECKPOINT,
data=checkpoint,
metadata=self._auto_fill_checkpoint_metrics({}),
)
kwargs = {}
else:
kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs))

result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
# Add result to a thread-safe queue.
self.result_queue.put(result, block=True)

@@ -301,23 +294,9 @@ def checkpoint(self, checkpoint: Checkpoint):

def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
# TODO(xwjiang): tons of optimizations.

# Special case: early fail for Torch tensors
if "torch" in sys.modules:
from ray.air._internal.torch_utils import contains_tensor

if contains_tensor(metrics):
raise ValueError(
"Passing objects containg Torch tensors as metrics "
"is not supported as it will throw an exception on "
"deserialization. You can either convert the tensors "
"to Python objects or use a `TorchCheckpoint` as the "
"`checkpoint` argument of `ray.air.session.report` to "
"store your Torch objects."
)

if checkpoint:
self.checkpoint(checkpoint)
checkpoint_dict = checkpoint.to_dict()
self.checkpoint(**checkpoint_dict)
self._report_legacy(**metrics)


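
The session hunks above go back to passing plain dicts through the result queue: the rank-0 worker encodes the auto-filled checkpoint kwargs and enqueues a TrainingResult that the main thread consumes. A stripped-down, Ray-independent sketch of that producer/consumer pattern (names are illustrative):

```python
import queue
import threading
from dataclasses import dataclass
from enum import Enum
from typing import Dict


class TrainingResultType(Enum):
    REPORT = "report"
    CHECKPOINT = "checkpoint"


@dataclass
class TrainingResult:
    type: TrainingResultType
    data: Dict


result_queue: "queue.Queue[TrainingResult]" = queue.Queue()


def report_checkpoint(world_rank: int) -> None:
    # Mirrors the reverted session logic: only rank 0 contributes checkpoint
    # data; every other rank enqueues an empty dict so the driver still sees
    # one result per worker.
    data = {"model": {"weights": None}, "loss": 0.1} if world_rank == 0 else {}
    result_queue.put(TrainingResult(TrainingResultType.CHECKPOINT, data), block=True)


workers = [threading.Thread(target=report_checkpoint, args=(rank,)) for rank in range(2)]
for w in workers:
    w.start()
for w in workers:
    w.join()

while not result_queue.empty():
    print(result_queue.get())
```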
65 changes: 3 additions & 62 deletions python/ray/train/backend.py
@@ -1,43 +1,16 @@
import logging
import warnings
from typing import Type, TypeVar, Dict
from typing import TypeVar, Dict

from ray.air.checkpoint import Checkpoint
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import Deprecated, DeveloperAPI
from ray.util.annotations import DeveloperAPI

from ray.widgets import make_table_html_repr

EncodedData = TypeVar("EncodedData")

logger = logging.getLogger(__name__)

# This is used in several places to print a warning.
_encode_decode_deprecation_message = (
"``encode_data`` and ``decode_data`` are deprecated in favor of "
"framework-specific ``ray.air.Checkpoint`` subclasses (reported "
"using ``ray.air.session.report()``) which can implement "
"encoding and decoding logic. In the future, ``encode_data`` and "
"``decode_data`` will throw an exception if overriden."
)


def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]):
return
# Do not print warnings in 2.1 yet.
# TODO(ml-team): Change this once we have full API parity with framework
# checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning
# warnings.warn(
# f"You have reported a checkpoint with the `{Checkpoint}` "
# "type, but the intended checkpoint type for the Trainer "
# f"you are using is `{expected_checkpoint_cls}`. "
# "Not using the intended checkpoint type may cause "
# "exceptions or other issues, especially during "
# "serialization and deserialization. The checkpoint "
# "type will be changed automatically. "
# "This behavior may change in the future."
# )


@DeveloperAPI
class BackendConfig:
@@ -73,37 +46,6 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
"""Logic for shutting down the backend."""
pass

@classmethod
def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``encode_data`` is deprecated."""
if cls.encode_data != Backend.encode_data:
warnings.warn(
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
)
# We wrap the return of encode_data in dict in case it is
# not a dict itself.
checkpoint = checkpoint.from_dict(
{"encoded_data": cls.encode_data(checkpoint.to_dict())}
)
return checkpoint

@classmethod
def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``decode_data`` is deprecated."""
if cls.decode_data != Backend.decode_data:
warnings.warn(
_encode_decode_deprecation_message, DeprecationWarning, stacklevel=2
)
checkpoint_dict = checkpoint.to_dict()
# If "encoded_data" is not in the dict, then the data was
# not encoded, but the user may want to just do decoding
# anyway.
checkpoint = checkpoint.from_dict(
cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict))
)
return checkpoint

@Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def encode_data(data_dict: Dict) -> EncodedData:
"""Logic to encode a data dict before sending to the driver.
@@ -114,7 +56,6 @@ def encode_data(data_dict: Dict) -> EncodedData:

return data_dict

@Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def decode_data(encoded_data: EncodedData) -> Dict:
"""Logic to decode an encoded data dict.
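
With the deprecation wrappers removed, `encode_data` and `decode_data` are again the plain static hooks a Backend can override to transform the checkpoint dict on its way between workers and the driver. A hedged sketch of a custom backend (the compression scheme is just an example, not anything Ray ships):

```python
import pickle
import zlib
from typing import Dict

from ray.train.backend import Backend, EncodedData


class CompressedBackend(Backend):
    """Example backend that compresses checkpoint dicts before they leave the worker."""

    @staticmethod
    def encode_data(data_dict: Dict) -> EncodedData:
        # Called on the worker before the dict is sent to the driver.
        return zlib.compress(pickle.dumps(data_dict))

    @staticmethod
    def decode_data(encoded_data: EncodedData) -> Dict:
        # Called on the driver to restore the original dict.
        return pickle.loads(zlib.decompress(encoded_data))
```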
6 changes: 1 addition & 5 deletions python/ray/train/data_parallel_trainer.py
@@ -8,7 +8,6 @@
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint
@@ -42,10 +41,7 @@ def __init__(
)

def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint):
if isinstance(checkpoint.dir_or_data, dict):
checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor
else:
save_preprocessor_to_dir(self.preprocessor, checkpoint.dir_or_data)
checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor
super(_DataParallelCheckpointManager, self)._process_persistent_checkpoint(
checkpoint=checkpoint
)
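
The restored branch assumes dict checkpoints, so the fitted preprocessor is stored directly under PREPROCESSOR_KEY in the checkpoint dict. A hedged sketch of pulling it back out of such a checkpoint (in real code the checkpoint would come from `result.checkpoint` after `trainer.fit()`):

```python
from ray.air.checkpoint import Checkpoint
from ray.air.constants import PREPROCESSOR_KEY

# Stand-in for a checkpoint produced by a trainer with a preprocessor attached.
checkpoint = Checkpoint.from_dict({"model": {"weights": None}, PREPROCESSOR_KEY: None})

preprocessor = checkpoint.to_dict().get(PREPROCESSOR_KEY)
print(preprocessor)
```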
9 changes: 5 additions & 4 deletions python/ray/train/examples/horovod/horovod_pytorch_example.py
@@ -9,9 +9,9 @@
from torchvision import datasets, transforms

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.horovod import HorovodTrainer
from ray.train.torch.torch_checkpoint import TorchCheckpoint
import ray.train.torch


@@ -152,11 +152,12 @@ def train_func(config):
model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
)
if save_model_as_dict:
checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
checkpoint_dict = dict(model=model.state_dict())
else:
checkpoint = TorchCheckpoint.from_model(model)
checkpoint_dict = dict(model=model)
checkpoint_dict = Checkpoint.from_dict(checkpoint_dict)
results.append(loss)
session.report(dict(loss=loss), checkpoint=checkpoint)
session.report(dict(loss=loss), checkpoint=checkpoint_dict)

# Only used for testing.
return results
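
With the example back on plain Checkpoint.from_dict, resuming inside the training function reads the same dict shape back out. A hedged sketch (key names follow the example above):

```python
from ray.air import session


def train_func(config):
    checkpoint = session.get_checkpoint()
    if checkpoint:
        # The dict has the shape the example reported: {"model": state_dict}
        # when save_model_as_dict is set, otherwise {"model": model}.
        model_data = checkpoint.to_dict()["model"]
    # ... continue or restart training from model_data ...
```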
11 changes: 6 additions & 5 deletions python/ray/train/examples/pytorch/torch_linear_example.py
@@ -5,7 +5,7 @@
import torch
import torch.nn as nn
import ray.train as train
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig


@@ -48,7 +48,8 @@ def validate_epoch(dataloader, model, loss_fn):
import copy

model_copy = copy.deepcopy(model)
return model_copy.cpu().state_dict(), loss
result = {"model": model_copy.cpu().state_dict(), "loss": loss}
return result


def train_func(config):
@@ -75,12 +76,12 @@ def train_func(config):
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

results = []

for _ in range(epochs):
train_epoch(train_loader, model, loss_fn, optimizer)
state_dict, loss = validate_epoch(validation_loader, model, loss_fn)
result = dict(loss=loss)
result = validate_epoch(validation_loader, model, loss_fn)
results.append(result)
session.report(result, checkpoint=TorchCheckpoint.from_state_dict(state_dict))
session.report(result)

return results

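
For completeness, a hedged sketch of how this training function is typically launched with the TorchTrainer imported above (the config keys are illustrative and may not match the full example file):

```python
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

# train_func is the training function defined in the example above.
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"lr": 1e-2, "batch_size": 4, "epochs": 3},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.metrics)
```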