diff --git a/python/ray/air/_internal/tensorflow_utils.py b/python/ray/air/_internal/tensorflow_utils.py index 61d38292e6d9..bbb08efc2b83 100644 --- a/python/ray/air/_internal/tensorflow_utils.py +++ b/python/ray/air/_internal/tensorflow_utils.py @@ -64,26 +64,6 @@ def convert_ndarray_batch_to_tf_tensor_batch( return batch -# This is not foolproof, but it's better than nothing -# The place it is used in will be deprecated soon -def contains_tensorflow_object(obj): - if hasattr(obj, "__module__") and ( - "keras" in obj.__module__ or "tensorflow" in obj.__module__ - ): - return True - elif isinstance(obj, dict): - for k, v in obj.items(): - if contains_tensorflow_object(k): - return True - if contains_tensorflow_object(v): - return True - elif isinstance(obj, (list, tuple)): - for v in obj: - if contains_tensorflow_object(v): - return True - return False - - def get_type_spec( schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"], columns: Union[str, List[str]], diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index 01d62f04b3dc..9ef8a88780a4 100644 --- a/python/ray/air/checkpoint.py +++ b/python/ray/air/checkpoint.py @@ -501,11 +501,6 @@ def _get_temporary_checkpoint_dir(self) -> str: ) return os.path.join(tmp_dir_path, checkpoint_dir_name) - def _save_checkpoint_metadata_in_directory(self, path: str) -> None: - checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME) - with open(checkpoint_metadata_path, "wb") as file: - pickle.dump(self._metadata, file) - def _to_directory(self, path: str) -> None: if self._data_dict or self._obj_ref: # This is a object ref or dict @@ -552,7 +547,9 @@ def _to_directory(self, path: str) -> None: f"No valid location found for checkpoint {self}: {self._uri}" ) - self._save_checkpoint_metadata_in_directory(path) + checkpoint_metadata_path = os.path.join(path, _CHECKPOINT_METADATA_FILE_NAME) + with open(checkpoint_metadata_path, "wb") as file: + pickle.dump(self._metadata, file) def to_directory(self, path: Optional[str] = None) -> str: """Write checkpoint data to directory. diff --git a/python/ray/train/_internal/backend_executor.py b/python/ray/train/_internal/backend_executor.py index e8d82388a920..855ce43d0f60 100644 --- a/python/ray/train/_internal/backend_executor.py +++ b/python/ray/train/_internal/backend_executor.py @@ -346,7 +346,7 @@ def initialize_session( train_func=train_func, dataset_shard=self.dataset_shards[index], checkpoint=checkpoint, - encode_data_fn=self._backend._encode_data, + encode_data_fn=self._backend.encode_data, ) ) diff --git a/python/ray/train/_internal/checkpoint.py b/python/ray/train/_internal/checkpoint.py index 6e18fedbff9c..55a11abdd400 100644 --- a/python/ray/train/_internal/checkpoint.py +++ b/python/ray/train/_internal/checkpoint.py @@ -101,15 +101,13 @@ def _process_checkpoint( """Ray Train entrypoint. Perform all processing for a checkpoint.""" # Get checkpoint from first worker. checkpoint_data = checkpoint_results[0].data - checkpoint_metadata = checkpoint_results[0].metadata or {} - # TODO(ml-team): Remove once we remove Backend.decode_data - checkpoint_data = decode_checkpoint_fn(checkpoint_data).to_dict() + # Decode checkpoint. 
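+        # ``decode_checkpoint_fn`` is the backend's ``decode_data`` hook and
+        # returns the checkpoint data as a plain dict.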
+ checkpoint_data = decode_checkpoint_fn(checkpoint_data) score_attr = self._checkpoint_strategy.checkpoint_score_attribute if ( self._checkpoint_strategy.num_to_keep != 0 - and score_attr not in checkpoint_metadata and score_attr not in checkpoint_data ): raise ValueError( @@ -124,11 +122,7 @@ def _process_checkpoint( dir_or_data=checkpoint_data, checkpoint_id=self._latest_checkpoint_id, storage_mode=CheckpointStorage.MEMORY, - metrics={ - score_attr: checkpoint_metadata.get( - score_attr, checkpoint_data.get(score_attr, 0.0) - ) - }, + metrics={score_attr: checkpoint_data.get(score_attr, 0.0)}, ) self.register_checkpoint(checkpoint=tracked_checkpoint) diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py index a8a86a594083..9ab1717615a3 100644 --- a/python/ray/train/_internal/session.py +++ b/python/ray/train/_internal/session.py @@ -2,7 +2,6 @@ import logging import platform import queue -import sys import threading import time from dataclasses import dataclass @@ -52,8 +51,7 @@ class TrialInfo: @dataclass class TrainingResult: type: TrainingResultType - data: Union[Dict, Checkpoint] - metadata: Optional[Dict] = None + data: Dict # TODO(xwjiang): This needs a better name. @@ -70,9 +68,8 @@ def __init__( trial_info: Optional[TrialInfo] = None, dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None, # TODO(xwjiang): Legacy Ray Train trainer clean up! - checkpoint: Optional[Checkpoint] = None, - # Deprecated - encode_data_fn: Optional[Callable] = None, + checkpoint: Optional[Union[Dict, Checkpoint]] = None, + encode_data_fn: Callable = None, detailed_autofilled_metrics: bool = False, ): @@ -83,7 +80,7 @@ def __init__( self.world_size = world_size self.trial_info = trial_info # TODO(xwjiang): Legacy Ray Train trainer clean up! - self.loaded_checkpoint = checkpoint + self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint # Function to encode checkpoint dict before sending to the driver. if not encode_data_fn: @@ -243,9 +240,9 @@ def _report_legacy(self, **kwargs): if self.ignore_report: return - kwargs = self._auto_fill_metrics(kwargs) + kwargs = self._encode_data_fn(self._auto_fill_metrics(kwargs)) - result = TrainingResult(type=TrainingResultType.REPORT, data=kwargs) + result = TrainingResult(TrainingResultType.REPORT, kwargs) # Add result to a thread-safe queue. self.result_queue.put(result, block=True) @@ -272,26 +269,22 @@ def _report_thread_runner_error(self, block=False): except queue.Empty: pass - def checkpoint(self, checkpoint: Checkpoint): + def checkpoint(self, **kwargs): """Adds kwargs to the queue to be consumed by main thread. Also stores the checkpoint in ``self.loaded_checkpoint``. """ # Update session checkpoint to latest checkpoint. - self.loaded_checkpoint = checkpoint + self.loaded_checkpoint = kwargs # Only store checkpoints on worker with rank 0. if self.world_rank != 0: - checkpoint = None - elif checkpoint: - checkpoint = self._encode_data_fn(checkpoint) - - result = TrainingResult( - type=TrainingResultType.CHECKPOINT, - data=checkpoint, - metadata=self._auto_fill_checkpoint_metrics({}), - ) + kwargs = {} + else: + kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs)) + + result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs) # Add result to a thread-safe queue. 
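+        # block=True: wait for space in the result queue instead of raising
+        # queue.Full if the main thread has not consumed the previous result.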
self.result_queue.put(result, block=True) @@ -301,23 +294,9 @@ def checkpoint(self, checkpoint: Checkpoint): def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None: # TODO(xwjiang): tons of optimizations. - - # Special case: early fail for Torch tensors - if "torch" in sys.modules: - from ray.air._internal.torch_utils import contains_tensor - - if contains_tensor(metrics): - raise ValueError( - "Passing objects containg Torch tensors as metrics " - "is not supported as it will throw an exception on " - "deserialization. You can either convert the tensors " - "to Python objects or use a `TorchCheckpoint` as the " - "`checkpoint` argument of `ray.air.session.report` to " - "store your Torch objects." - ) - if checkpoint: - self.checkpoint(checkpoint) + checkpoint_dict = checkpoint.to_dict() + self.checkpoint(**checkpoint_dict) self._report_legacy(**metrics) diff --git a/python/ray/train/backend.py b/python/ray/train/backend.py index 7d1ae6848d6f..78dde0c36836 100644 --- a/python/ray/train/backend.py +++ b/python/ray/train/backend.py @@ -1,43 +1,16 @@ import logging -import warnings -from typing import Type, TypeVar, Dict +from typing import TypeVar, Dict -from ray.air.checkpoint import Checkpoint from ray.train._internal.utils import Singleton from ray.train._internal.worker_group import WorkerGroup -from ray.util.annotations import Deprecated, DeveloperAPI +from ray.util.annotations import DeveloperAPI + from ray.widgets import make_table_html_repr EncodedData = TypeVar("EncodedData") logger = logging.getLogger(__name__) -# This is used in several places to print a warning. -_encode_decode_deprecation_message = ( - "``encode_data`` and ``decode_data`` are deprecated in favor of " - "framework-specific ``ray.air.Checkpoint`` subclasses (reported " - "using ``ray.air.session.report()``) which can implement " - "encoding and decoding logic. In the future, ``encode_data`` and " - "``decode_data`` will throw an exception if overriden." -) - - -def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]): - return - # Do not print warnings in 2.1 yet. - # TODO(ml-team): Change this once we have full API parity with framework - # checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning - # warnings.warn( - # f"You have reported a checkpoint with the `{Checkpoint}` " - # "type, but the intended checkpoint type for the Trainer " - # f"you are using is `{expected_checkpoint_cls}`. " - # "Not using the intended checkpoint type may cause " - # "exceptions or other issues, especially during " - # "serialization and deserialization. The checkpoint " - # "type will be changed automatically. " - # "This behavior may change in the future." - # ) - @DeveloperAPI class BackendConfig: @@ -73,37 +46,6 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig): """Logic for shutting down the backend.""" pass - @classmethod - def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint: - """Temporary method until ``encode_data`` is deprecated.""" - if cls.encode_data != Backend.encode_data: - warnings.warn( - _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2 - ) - # We wrap the return of encode_data in dict in case it is - # not a dict itself. 
- checkpoint = checkpoint.from_dict( - {"encoded_data": cls.encode_data(checkpoint.to_dict())} - ) - return checkpoint - - @classmethod - def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint: - """Temporary method until ``decode_data`` is deprecated.""" - if cls.decode_data != Backend.decode_data: - warnings.warn( - _encode_decode_deprecation_message, DeprecationWarning, stacklevel=2 - ) - checkpoint_dict = checkpoint.to_dict() - # If "encoded_data" is not in the dict, then the data was - # not encoded, but the user may want to just do decoding - # anyway. - checkpoint = checkpoint.from_dict( - cls.decode_data(checkpoint_dict.get("encoded_data", checkpoint_dict)) - ) - return checkpoint - - @Deprecated(message=_encode_decode_deprecation_message) @staticmethod def encode_data(data_dict: Dict) -> EncodedData: """Logic to encode a data dict before sending to the driver. @@ -114,7 +56,6 @@ def encode_data(data_dict: Dict) -> EncodedData: return data_dict - @Deprecated(message=_encode_decode_deprecation_message) @staticmethod def decode_data(encoded_data: EncodedData) -> Dict: """Logic to decode an encoded data dict. diff --git a/python/ray/train/data_parallel_trainer.py b/python/ray/train/data_parallel_trainer.py index 06a828605853..ca2c595e4190 100644 --- a/python/ray/train/data_parallel_trainer.py +++ b/python/ray/train/data_parallel_trainer.py @@ -8,7 +8,6 @@ from ray import tune from ray.air import session from ray.air.checkpoint import Checkpoint -from ray.air._internal.checkpointing import save_preprocessor_to_dir from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY from ray.air._internal.checkpoint_manager import _TrackedCheckpoint @@ -42,10 +41,7 @@ def __init__( ) def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): - if isinstance(checkpoint.dir_or_data, dict): - checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor - else: - save_preprocessor_to_dir(self.preprocessor, checkpoint.dir_or_data) + checkpoint.dir_or_data[PREPROCESSOR_KEY] = self.preprocessor super(_DataParallelCheckpointManager, self)._process_persistent_checkpoint( checkpoint=checkpoint ) diff --git a/python/ray/train/examples/horovod/horovod_pytorch_example.py b/python/ray/train/examples/horovod/horovod_pytorch_example.py index f4d15ae0515b..5197d900aba6 100644 --- a/python/ray/train/examples/horovod/horovod_pytorch_example.py +++ b/python/ray/train/examples/horovod/horovod_pytorch_example.py @@ -9,9 +9,9 @@ from torchvision import datasets, transforms from ray.air import session +from ray.air.checkpoint import Checkpoint from ray.air.config import ScalingConfig from ray.train.horovod import HorovodTrainer -from ray.train.torch.torch_checkpoint import TorchCheckpoint import ray.train.torch @@ -152,11 +152,12 @@ def train_func(config): model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda ) if save_model_as_dict: - checkpoint = TorchCheckpoint.from_state_dict(model.state_dict()) + checkpoint_dict = dict(model=model.state_dict()) else: - checkpoint = TorchCheckpoint.from_model(model) + checkpoint_dict = dict(model=model) + checkpoint_dict = Checkpoint.from_dict(checkpoint_dict) results.append(loss) - session.report(dict(loss=loss), checkpoint=checkpoint) + session.report(dict(loss=loss), checkpoint=checkpoint_dict) # Only used for testing. 
return results diff --git a/python/ray/train/examples/pytorch/torch_linear_example.py b/python/ray/train/examples/pytorch/torch_linear_example.py index 647b51a0db6b..03cedc5e0751 100644 --- a/python/ray/train/examples/pytorch/torch_linear_example.py +++ b/python/ray/train/examples/pytorch/torch_linear_example.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import ray.train as train -from ray.train.torch import TorchTrainer, TorchCheckpoint +from ray.train.torch import TorchTrainer from ray.air.config import ScalingConfig @@ -48,7 +48,8 @@ def validate_epoch(dataloader, model, loss_fn): import copy model_copy = copy.deepcopy(model) - return model_copy.cpu().state_dict(), loss + result = {"model": model_copy.cpu().state_dict(), "loss": loss} + return result def train_func(config): @@ -75,12 +76,12 @@ def train_func(config): optimizer = torch.optim.SGD(model.parameters(), lr=lr) results = [] + for _ in range(epochs): train_epoch(train_loader, model, loss_fn, optimizer) - state_dict, loss = validate_epoch(validation_loader, model, loss_fn) - result = dict(loss=loss) + result = validate_epoch(validation_loader, model, loss_fn) results.append(result) - session.report(result, checkpoint=TorchCheckpoint.from_state_dict(state_dict)) + session.report(result) return results diff --git a/python/ray/train/horovod/config.py b/python/ray/train/horovod/config.py index 5960137f45fa..989c8f06f6d6 100644 --- a/python/ray/train/horovod/config.py +++ b/python/ray/train/horovod/config.py @@ -1,20 +1,18 @@ import sys -from typing import Optional, Set +from typing import Optional, Set, Dict import os from dataclasses import dataclass import ray -from ray.air.checkpoint import Checkpoint -from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type +from ray.air._internal.torch_utils import contains_tensor +from ray.train.backend import BackendConfig, Backend, EncodedData from ray.train._internal.utils import update_env_vars from ray.train._internal.worker_group import WorkerGroup, Worker from horovod.ray.runner import Coordinator from horovod.ray.utils import detect_nics, nics_to_env_var from horovod.runner.common.util import secret, timeout -from ray.train.tensorflow.tensorflow_checkpoint import TensorflowCheckpoint -from ray.train.torch.torch_checkpoint import TorchCheckpoint from ray.util import PublicAPI @@ -133,26 +131,36 @@ def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig): worker_group.execute(update_env_vars, coordinator_envs) - @classmethod - def _encode_data(cls, checkpoint: Checkpoint): - checkpoint = super()._encode_data(checkpoint) - if type(checkpoint) is Checkpoint: - if checkpoint.get_internal_representation()[0] == "data_dict": - if "tensorflow" in sys.modules: - from ray.air._internal.tensorflow_utils import ( - contains_tensorflow_object, - ) - - if contains_tensorflow_object(checkpoint.to_dict()): - _warn_about_bad_checkpoint_type(TensorflowCheckpoint) - checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint) - if "torch" in sys.modules: - from ray.air._internal.torch_utils import contains_tensor - - if contains_tensor(checkpoint.to_dict()): - _warn_about_bad_checkpoint_type(TorchCheckpoint) - checkpoint = TorchCheckpoint.from_checkpoint(checkpoint) - return checkpoint + @staticmethod + def encode_data(data_dict: Dict) -> EncodedData: + """Logic to encode a data dict before sending to the driver. + + This function will be called on the workers for any data that is + sent to the driver via ``session.report()``. 
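+
+        Checkpoint dicts reported from the rank 0 worker go through this
+        hook as well, so Torch tensors in either metrics or checkpoints are
+        covered by the Torch-specific encoding below.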
+ """ + # If torch is imported, we can use it to serialize the data dict + # into bytes. This will prevent e.g. GPU deserialization errors. + if "torch" in sys.modules and contains_tensor(data_dict): + from ray.train.torch.config import _TorchBackend + + return _TorchBackend.encode_data(data_dict) + + return data_dict + + @staticmethod + def decode_data(encoded_data: EncodedData) -> Dict: + """Logic to decode an encoded data dict. + + This function will be called on the driver after receiving the + encoded data dict from the worker. + """ + # See encode_data + if "torch" in sys.modules and isinstance(encoded_data, bytes): + from ray.train.torch.config import _TorchBackend + + return _TorchBackend.decode_data(encoded_data) + + return encoded_data def _init_env_vars(world_rank: int, world_size: int, node_id: str): diff --git a/python/ray/train/horovod/horovod_trainer.py b/python/ray/train/horovod/horovod_trainer.py index c4a4c957cc6b..beded4a43bb8 100644 --- a/python/ray/train/horovod/horovod_trainer.py +++ b/python/ray/train/horovod/horovod_trainer.py @@ -87,9 +87,8 @@ def train_loop_per_worker(): import horovod.torch as hvd import torch import torch.nn as nn - from ray.air import session + from ray.air import session, Checkpoint from ray.train.horovod import HorovodTrainer - from ray.train.torch import TorchCheckpoint from ray.air.config import ScalingConfig input_size = 1 @@ -137,8 +136,8 @@ def train_loop_per_worker(): print(f"epoch: {epoch}, loss: {loss.item()}") session.report( {}, - checkpoint=TorchCheckpoint.from_state_dict( - model.state_dict() + checkpoint=Checkpoint.from_dict( + dict(model=model.state_dict()) ), ) train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)]) diff --git a/python/ray/train/huggingface/_huggingface_utils.py b/python/ray/train/huggingface/_huggingface_utils.py index 8b61590147a8..623b7ca00ef4 100644 --- a/python/ray/train/huggingface/_huggingface_utils.py +++ b/python/ray/train/huggingface/_huggingface_utils.py @@ -7,9 +7,9 @@ from transformers.trainer_utils import IntervalStrategy from ray.air import session +from ray.air.checkpoint import Checkpoint from ray.util import get_node_ip_address from ray.data.dataset import Dataset -from ray.train.huggingface.huggingface_checkpoint import HuggingFaceCheckpoint if TYPE_CHECKING: from torch.utils.data import IterableDataset @@ -152,8 +152,7 @@ def on_save(self, args, state, control, **kwargs): transformers.trainer.get_last_checkpoint(args.output_dir) ).absolute() if checkpoint_path: - # Use HuggingFaceCheckpoint here to avoid a warning in _TrainSession - self.delayed_report["checkpoint"] = HuggingFaceCheckpoint.from_dict( + self.delayed_report["checkpoint"] = Checkpoint.from_dict( { NODE_IP_KEY: get_node_ip_address(), CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), diff --git a/python/ray/train/huggingface/huggingface_trainer.py b/python/ray/train/huggingface/huggingface_trainer.py index 9ca24ce3ea9e..400505684a28 100644 --- a/python/ray/train/huggingface/huggingface_trainer.py +++ b/python/ray/train/huggingface/huggingface_trainer.py @@ -7,7 +7,6 @@ import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type -from ray.train.huggingface.huggingface_checkpoint import HuggingFaceCheckpoint try: from packaging.version import Version @@ -130,13 +129,6 @@ def commit(self, path: Optional[Path] = None) -> None: with open(path.joinpath(TUNE_CHECKPOINT_ID), "w") as f: f.write(str(self.id)) - # Add checkpoint class metadata - # A bit of a hack but 
this will be removed with the rest - # of this special case eventually - # TODO(ml-team): remove this when HF checkpointing is refactored - checkpoint = HuggingFaceCheckpoint.from_directory(path) - checkpoint._save_checkpoint_metadata_in_directory(path) - class _DataParallelSyncingCheckpointManager(_DataParallelCheckpointManager): def _process_persistent_checkpoint(self, checkpoint: _TrackedCheckpoint): diff --git a/python/ray/train/tensorflow/config.py b/python/ray/train/tensorflow/config.py index f16caf88b420..88ccafbb839b 100644 --- a/python/ray/train/tensorflow/config.py +++ b/python/ray/train/tensorflow/config.py @@ -5,11 +5,9 @@ from typing import List import ray -from ray.air.checkpoint import Checkpoint -from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type +from ray.train.backend import BackendConfig, Backend from ray.train._internal.utils import get_address_and_port from ray.train._internal.worker_group import WorkerGroup -from ray.train.tensorflow.tensorflow_checkpoint import TensorflowCheckpoint from ray.util import PublicAPI @@ -58,11 +56,3 @@ def get_url(): ) ) ray.get(setup_futures) - - @classmethod - def _encode_data(cls, checkpoint: Checkpoint): - checkpoint = super()._encode_data(checkpoint) - if type(checkpoint) is Checkpoint: - _warn_about_bad_checkpoint_type(TensorflowCheckpoint) - checkpoint = TensorflowCheckpoint.from_checkpoint(checkpoint) - return checkpoint diff --git a/python/ray/train/tests/test_gpu_amp.py b/python/ray/train/tests/test_gpu_amp.py index 2e8f9522b248..455f65db106b 100644 --- a/python/ray/train/tests/test_gpu_amp.py +++ b/python/ray/train/tests/test_gpu_amp.py @@ -65,7 +65,7 @@ def train_func(): model = torchvision.models.resnet101() model = train.torch.prepare_model(model) - session.report({}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=2, use_gpu=True) diff --git a/python/ray/train/tests/test_gpu_auto_transfer.py b/python/ray/train/tests/test_gpu_auto_transfer.py index 3cea67266a49..242dbe7644c1 100644 --- a/python/ray/train/tests/test_gpu_auto_transfer.py +++ b/python/ray/train/tests/test_gpu_auto_transfer.py @@ -8,7 +8,8 @@ from ray.air import session from ray.air.constants import MODEL_KEY from ray.air.config import ScalingConfig -from ray.train.torch import TorchTrainer, TorchCheckpoint +from ray.train.torch.torch_checkpoint import TorchCheckpoint +from ray.train.torch.torch_trainer import TorchTrainer import ray.train.torch.train_loop_utils @@ -105,7 +106,7 @@ def train_func(): assert next(model.parameters()).is_cuda - session.report({}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_func, scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=True) @@ -113,7 +114,9 @@ def train_func(): results = trainer.fit() model_checkpoint = results.checkpoint.get_model() + model_report = results.metrics["model"] assert not next(model_checkpoint.parameters()).is_cuda + assert not next(model_report.parameters()).is_cuda # Test the same thing for state dict. 
@@ -131,7 +134,7 @@ def train_func(): assert tensor.is_cuda session.report( - {}, + {"state_dict": state_dict}, checkpoint=TorchCheckpoint.from_state_dict(state_dict), ) @@ -141,6 +144,10 @@ def train_func(): results = trainer.fit() state_dict_checkpoint = results.checkpoint.to_dict()[MODEL_KEY] + state_dict_report = results.metrics["state_dict"] + + for tensor in state_dict_report.values(): + assert not tensor.is_cuda for tensor in state_dict_checkpoint.values(): assert not tensor.is_cuda diff --git a/python/ray/train/tests/test_huggingface_trainer.py b/python/ray/train/tests/test_huggingface_trainer.py index 4951f6444370..b741d3475574 100644 --- a/python/ray/train/tests/test_huggingface_trainer.py +++ b/python/ray/train/tests/test_huggingface_trainer.py @@ -11,11 +11,7 @@ import ray.data from ray.exceptions import RayTaskError from ray.train.batch_predictor import BatchPredictor -from ray.train.huggingface import ( - HuggingFacePredictor, - HuggingFaceTrainer, - HuggingFaceCheckpoint, -) +from ray.train.huggingface import HuggingFacePredictor, HuggingFaceTrainer from ray.air.config import ScalingConfig from ray.train.tests._huggingface_data import train_data, validation_data from ray import tune @@ -95,7 +91,6 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result.metrics["epoch"] == 4 assert result.metrics["training_iteration"] == 4 assert result.checkpoint - assert isinstance(result.checkpoint, HuggingFaceCheckpoint) assert "eval_loss" in result.metrics trainer2 = HuggingFaceTrainer( @@ -113,7 +108,6 @@ def test_e2e(ray_start_4_cpus, save_strategy): assert result2.metrics["epoch"] == 5 assert result2.metrics["training_iteration"] == 1 assert result2.checkpoint - assert isinstance(result2.checkpoint, HuggingFaceCheckpoint) assert "eval_loss" in result2.metrics predictor = BatchPredictor.from_checkpoint( diff --git a/python/ray/train/tests/test_session.py b/python/ray/train/tests/test_session.py index 545d611f4882..9eb5dd51985b 100644 --- a/python/ray/train/tests/test_session.py +++ b/python/ray/train/tests/test_session.py @@ -139,7 +139,7 @@ def validate_zero(expected): next = session.get_next() assert next is not None assert next.type == TrainingResultType.CHECKPOINT - assert next.data.to_dict()["epoch"] == expected + assert next.data["epoch"] == expected init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1) session = get_session() @@ -155,7 +155,7 @@ def validate_nonzero(): next = session.get_next() assert next is not None assert next.type == TrainingResultType.CHECKPOINT - assert not next.data + assert next.data == {} init_session(training_func=train_func, world_rank=1, local_rank=1, world_size=1) session = get_session() @@ -173,17 +173,13 @@ def train_func(): report(dict(epoch=0), checkpoint=Checkpoint.from_dict(dict(epoch=0))) def encode_checkpoint(checkpoint): - data = checkpoint.to_dict() - data["encoded"] = True - return checkpoint.from_dict(data) + checkpoint.update({"encoded": True}) + return checkpoint def validate_encoded(result_type: TrainingResultType): next = session.get_next() assert next.type is result_type - data = next.data - if isinstance(data, Checkpoint): - data = data.to_dict() - assert data["encoded"] is True + assert next.data["encoded"] is True init_session( training_func=train_func, @@ -197,7 +193,8 @@ def validate_encoded(result_type: TrainingResultType): session.start() # Validate checkpoint is encoded. validate_encoded(TrainingResultType.CHECKPOINT) - session.get_next() + # Validate report is encoded. 
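+    # Metrics now also pass through the backend's ``encode_data`` hook (applied
+    # in ``_report_legacy``), so the REPORT payload carries the "encoded" flag.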
+ validate_encoded(TrainingResultType.REPORT) session.finish() shutdown_session() @@ -214,7 +211,6 @@ def train_func(): session.start() for i in range(2): session.get_next() - session.get_next() session.finish() shutdown_session() diff --git a/python/ray/train/tests/test_tensorflow_trainer.py b/python/ray/train/tests/test_tensorflow_trainer.py index df87b665b5c6..fc7ee1f97a9f 100644 --- a/python/ray/train/tests/test_tensorflow_trainer.py +++ b/python/ray/train/tests/test_tensorflow_trainer.py @@ -1,4 +1,6 @@ import os + +import numpy as np import pytest import ray @@ -8,14 +10,12 @@ train_func as tensorflow_linear_train_func, ) from ray.air.config import ScalingConfig -from ray.train.batch_predictor import BatchPredictor from ray.train.constants import TRAIN_DATASET_KEY from ray.train.tensorflow import ( + TensorflowCheckpoint, TensorflowPredictor, TensorflowTrainer, - TensorflowCheckpoint, ) -from ray.train.tests.dummy_preprocessor import DummyPreprocessor @pytest.fixture @@ -74,19 +74,23 @@ def train_func(): scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( - train_loop_per_worker=train_func, - scaling_config=scaling_config, - preprocessor=DummyPreprocessor(), + train_loop_per_worker=train_func, scaling_config=scaling_config ) result = trainer.fit() - assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) - batch_predictor = BatchPredictor.from_checkpoint( - result.checkpoint, TensorflowPredictor, model_definition=build_model - ) + class TensorflowScorer: + def __init__(self): + self.pred = TensorflowPredictor.from_checkpoint( + result.checkpoint, build_model + ) + + def __call__(self, x): + return self.pred.predict(x, dtype=np.float) predict_dataset = ray.data.range(3) - predictions = batch_predictor.predict(predict_dataset) + predictions = predict_dataset.map_batches( + TensorflowScorer, batch_format="pandas", compute="actors" + ) assert predictions.count() == 3 @@ -108,23 +112,17 @@ def train_func(): scaling_config = ScalingConfig(num_workers=2) trainer = TensorflowTrainer( - train_loop_per_worker=train_func, - scaling_config=scaling_config, - preprocessor=DummyPreprocessor(), + train_loop_per_worker=train_func, scaling_config=scaling_config ) result = trainer.fit() - checkpoint = result.checkpoint - assert isinstance(checkpoint.get_preprocessor(), DummyPreprocessor) trainer2 = TensorflowTrainer( train_loop_per_worker=train_func, scaling_config=scaling_config, - resume_from_checkpoint=checkpoint, - preprocessor=DummyPreprocessor(), + resume_from_checkpoint=result.checkpoint, ) result = trainer2.fit() checkpoint = result.checkpoint - assert isinstance(checkpoint.get_preprocessor(), DummyPreprocessor) with checkpoint.as_directory() as ckpt_dir: assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb")) assert result.metrics["iter"] == 1 diff --git a/python/ray/train/tests/test_torch_trainer.py b/python/ray/train/tests/test_torch_trainer.py index 905717961fc7..2dbec40eb1d5 100644 --- a/python/ray/train/tests/test_torch_trainer.py +++ b/python/ray/train/tests/test_torch_trainer.py @@ -1,13 +1,14 @@ import contextlib import pytest +from ray.air import session +from ray.air.checkpoint import Checkpoint +from ray.train.torch.torch_checkpoint import TorchCheckpoint import torch -import os import ray from ray.train.examples.pytorch.torch_linear_example import ( train_func as linear_train_func, ) -from ray.train.batch_predictor import BatchPredictor from ray.train.torch import TorchPredictor, TorchTrainer from ray.tune import TuneError from 
ray.air.config import ScalingConfig @@ -15,9 +16,6 @@ import ray.train as train from unittest.mock import patch from ray.cluster_utils import Cluster -from ray.air import session -from ray.train.tests.dummy_preprocessor import DummyPreprocessor -from ray.train.torch.torch_checkpoint import TorchCheckpoint @pytest.fixture @@ -64,21 +62,25 @@ def train_func(config): def test_torch_e2e(ray_start_4_cpus): def train_func(): model = torch.nn.Linear(3, 1) - session.report({}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({}, checkpoint=Checkpoint.from_dict(dict(model=model))) scaling_config = ScalingConfig(num_workers=2) trainer = TorchTrainer( - train_loop_per_worker=train_func, - scaling_config=scaling_config, - preprocessor=DummyPreprocessor(), + train_loop_per_worker=train_func, scaling_config=scaling_config ) result = trainer.fit() - assert isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) predict_dataset = ray.data.range(9) - batch_predictor = BatchPredictor.from_checkpoint(result.checkpoint, TorchPredictor) - predictions = batch_predictor.predict( - predict_dataset, batch_size=3, dtype=torch.float + + class TorchScorer: + def __init__(self): + self.pred = TorchPredictor.from_checkpoint(result.checkpoint) + + def __call__(self, x): + return self.pred.predict(x, dtype=torch.float) + + predictions = predict_dataset.map_batches( + TorchScorer, batch_size=3, batch_format="pandas", compute="actors" ) assert predictions.count() == 3 @@ -86,55 +88,22 @@ def train_func(): def test_torch_e2e_state_dict(ray_start_4_cpus): def train_func(): model = torch.nn.Linear(3, 1).state_dict() - session.report({}, checkpoint=TorchCheckpoint.from_state_dict(model)) + session.report({}, checkpoint=Checkpoint.from_dict(dict(model=model))) scaling_config = ScalingConfig(num_workers=2) trainer = TorchTrainer( - train_loop_per_worker=train_func, - scaling_config=scaling_config, - preprocessor=DummyPreprocessor(), + train_loop_per_worker=train_func, scaling_config=scaling_config ) result = trainer.fit() - isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) # If loading from a state dict, a model definition must be passed in. 
with pytest.raises(ValueError): TorchPredictor.from_checkpoint(result.checkpoint) - predict_dataset = ray.data.range(9) - batch_predictor = BatchPredictor.from_checkpoint( - result.checkpoint, TorchPredictor, model=torch.nn.Linear(3, 1) - ) - predictions = batch_predictor.predict( - predict_dataset, batch_size=3, dtype=torch.float - ) - assert predictions.count() == 3 - - -def test_torch_e2e_dir(ray_start_4_cpus, tmpdir): - def train_func(): - model = torch.nn.Linear(3, 1) - torch.save(model, os.path.join(tmpdir, "model")) - session.report({}, checkpoint=TorchCheckpoint.from_directory(tmpdir)) - - scaling_config = ScalingConfig(num_workers=2) - trainer = TorchTrainer( - train_loop_per_worker=train_func, - scaling_config=scaling_config, - preprocessor=DummyPreprocessor(), - ) - result = trainer.fit() - isinstance(result.checkpoint.get_preprocessor(), DummyPreprocessor) - - # TODO(ml-team): Add a way for TorchCheckpoint to natively support - # models from files class TorchScorer: def __init__(self): - with result.checkpoint.as_directory() as checkpoint_path: - model = torch.load(os.path.join(checkpoint_path, "model")) - preprocessor = result.checkpoint.get_preprocessor() self.pred = TorchPredictor.from_checkpoint( - TorchCheckpoint.from_model(model, preprocessor=preprocessor) + result.checkpoint, model=torch.nn.Linear(3, 1) ) def __call__(self, x): @@ -162,58 +131,6 @@ def test_checkpoint_freq(ray_start_4_cpus): trainer.fit() -def test_torch_session_errors(ray_start_4_cpus): - """Test fail-fast behavior when reporting dicts with Torch tensors""" - - def train_func(): - model = torch.nn.Linear(1, 1).state_dict() - with pytest.raises(ValueError): - session.report(model) - - scaling_config = ScalingConfig(num_workers=2) - trainer = TorchTrainer( - train_loop_per_worker=train_func, - scaling_config=scaling_config, - ) - trainer.fit() - - -# See comment in backend.py::_warn_about_bad_checkpoint_type -# for why test_torch_bad_checkpoint_warning is commented out - -# def test_torch_bad_checkpoint_warning(ray_start_4_cpus): -# """Test that a warning is printed if bad checkpoint type is used.""" - -# def train_func(): -# model = torch.nn.Linear(1, 1).state_dict() -# session.report({}, checkpoint=TorchCheckpoint.from_dict({"model": model})) - -# scaling_config = ScalingConfig(num_workers=2) -# trainer = TorchTrainer( -# train_loop_per_worker=train_func, -# scaling_config=scaling_config, -# ) -# output = io.StringIO() -# with redirect_stdout(output), redirect_stderr(output): -# trainer.fit() -# output = output.getvalue() -# assert "You have reported a checkpoint" not in output - -# def train_func(): -# model = torch.nn.Linear(1, 1).state_dict() -# session.report({}, checkpoint=Checkpoint.from_dict({"model": model})) - -# trainer = TorchTrainer( -# train_loop_per_worker=train_func, -# scaling_config=scaling_config, -# ) -# output = io.StringIO() -# with redirect_stdout(output), redirect_stderr(output): -# trainer.fit() -# output = output.getvalue() -# assert "You have reported a checkpoint" in output - - @pytest.mark.parametrize( "num_gpus_per_worker,expected_devices", [(0.5, [0]), (1, [0]), (2, [0, 1])] ) @@ -287,7 +204,7 @@ def train_fn(): model = train.torch.prepare_model(model) # Save DDP wrapped model. 
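+        # Reporting the DDP-wrapped model as a metric exercises
+        # _TorchBackend.encode_data, which unwraps DistributedDataParallel
+        # before serializing the dict with torch.save.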
- session.report({}, checkpoint=TorchCheckpoint.from_model(model)) + session.report({"model": model}, checkpoint=TorchCheckpoint.from_model(model)) trainer = TorchTrainer( train_loop_per_worker=train_fn, @@ -301,6 +218,11 @@ def train_fn(): model, torch.nn.parallel.DistributedDataParallel ) + model_report = results.metrics["model"] + assert isinstance(model_report, torch.nn.Module) and not isinstance( + model_report, torch.nn.parallel.DistributedDataParallel + ) + def test_torch_amp(ray_start_4_cpus): def train_fn(): diff --git a/python/ray/train/torch/config.py b/python/ray/train/torch/config.py index fca7ec745169..346de18d5815 100644 --- a/python/ray/train/torch/config.py +++ b/python/ray/train/torch/config.py @@ -1,20 +1,21 @@ from dataclasses import dataclass +import io import logging import os from datetime import timedelta -from typing import Optional +from typing import Dict, Optional import ray -from ray.air.checkpoint import Checkpoint -from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type +import ray.cloudpickle +from ray.train.backend import BackendConfig, Backend, EncodedData from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME from ray.train._internal.worker_group import WorkerGroup from ray.train._internal.utils import get_address_and_port -from ray.train.torch.torch_checkpoint import TorchCheckpoint from ray.util import PublicAPI import torch import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel try: from torch.profiler import profile @@ -177,10 +178,32 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchConfig): _shutdown_torch, destroy_process_group=len(worker_group) > 1 ) - @classmethod - def _encode_data(cls, checkpoint: Checkpoint): - checkpoint = super()._encode_data(checkpoint) - if type(checkpoint) is Checkpoint: - _warn_about_bad_checkpoint_type(TorchCheckpoint) - checkpoint = TorchCheckpoint.from_checkpoint(checkpoint) - return checkpoint + @staticmethod + def encode_data(data_dict: Dict) -> EncodedData: + """Special handling for moving model from worker to driver.""" + + # If model is being checkpointed and is wrapped in DDP, then extract + # out the underlying module. If not, then deserialization will fail + # since the torch process group is not initialized on the driver. + + for k, v in data_dict.items(): + if isinstance(v, DistributedDataParallel) and hasattr(v, "module"): + data_dict[k] = v.module + + # Convert the checkpoint dict to bytes, so that any GPU tensors that + # are in the checkpoint dict can be properly deserialized on the + # driver side, even if the driver does not have access to a GPU device. + _buffer = io.BytesIO() + # If a custom torch model contains a function that cannot be pickled normally, + # we need to use ray.cloudpickle. This is also consistent with how Ray + # serialization works in general and has no downsides + # (this can still be unpickled without ray using normal pickle). + torch.save(data_dict, _buffer, pickle_module=ray.cloudpickle) + return _buffer.getvalue() + + @staticmethod + def decode_data(encoded_data: EncodedData) -> Dict: + # When decoding the bytes on the driver side, always map to CPU. 
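+        # This avoids CUDA deserialization errors when the driver process has
+        # no GPU device available.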
+ _buffer = io.BytesIO(encoded_data) + checkpoint_dict = torch.load(_buffer, map_location="cpu") + return checkpoint_dict diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py index 9d040d99060c..774256965c8b 100644 --- a/python/ray/train/torch/torch_checkpoint.py +++ b/python/ray/train/torch/torch_checkpoint.py @@ -1,10 +1,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional -import io import torch import warnings -import ray.cloudpickle from ray.air.checkpoint import Checkpoint from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY from ray.train.data_parallel_trainer import _load_checkpoint_dict @@ -14,8 +12,6 @@ if TYPE_CHECKING: from ray.data.preprocessor import Preprocessor -ENCODED_DATA_KEY = "torch_encoded_data" - @PublicAPI(stability="beta") class TorchCheckpoint(Checkpoint): @@ -25,49 +21,6 @@ class TorchCheckpoint(Checkpoint): ``TorchCheckpoint.from_checkpoint(ckpt)``. """ - # Special encoding logic to avoid serialization errors with torch. - def _encode_data_dict(self, data_dict: dict) -> dict: - """Encode data_dict using torch.save.""" - from torch.nn.parallel import DistributedDataParallel - - for k, v in data_dict.items(): - if isinstance(v, DistributedDataParallel) and hasattr(v, "module"): - data_dict[k] = v.module - - # Convert the checkpoint dict to bytes, so that any GPU tensors that - # are in the checkpoint dict can be properly deserialized on the - # driver side, even if the driver does not have access to a GPU device. - _buffer = io.BytesIO() - torch.save(data_dict, _buffer, pickle_module=ray.cloudpickle) - return {ENCODED_DATA_KEY: _buffer.getvalue()} - - def _decode_data_dict(self, data_dict: dict) -> dict: - """Decode data_dict using torch_load if needed.""" - if ENCODED_DATA_KEY not in data_dict: - return data_dict - encoded_data = data_dict[ENCODED_DATA_KEY] - _buffer = io.BytesIO(encoded_data) - data_dict = torch.load( - _buffer, - map_location="cpu" - # Not using ray.cloudpickle here as it doesn't - # define an Unpickler (as it is not necessary). 
- ) - return data_dict - - def __getstate__(self) -> dict: - if self._data_dict: - state = self.__dict__.copy() - state["_data_dict"] = self._encode_data_dict(self._data_dict) - return state - return super().__getstate__() - - def __setstate__(self, state: dict): - if "_data_dict" in state: - state = state.copy() - state["_data_dict"] = self._decode_data_dict(state["_data_dict"]) - super().__setstate__(state) - @classmethod def from_state_dict( cls, diff --git a/python/ray/train/torch/torch_trainer.py b/python/ray/train/torch/torch_trainer.py index 270d130badb9..bbd5848c79b9 100644 --- a/python/ray/train/torch/torch_trainer.py +++ b/python/ray/train/torch/torch_trainer.py @@ -142,7 +142,7 @@ def train_loop_per_worker(): session.report( {}, checkpoint=Checkpoint.from_dict( - dict(epoch=epoch, model=model.state_dict() + dict(epoch=epoch, model=model.state_dict()) ), ) diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 9aadf2930f09..5ede3472a331 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -260,11 +260,11 @@ def _fetch_next_result(self) -> Optional[List[Dict]]: first_result = results[0] result_type = first_result.type if result_type is TrainingResultType.REPORT: - result_data = [r.data for r in results] + result_data = [self._backend.decode_data(r.data) for r in results] return result_data elif result_type is TrainingResultType.CHECKPOINT: self._checkpoint_manager._process_checkpoint( - results, decode_checkpoint_fn=self._backend._decode_data + results, decode_checkpoint_fn=self._backend.decode_data ) # Iterate until next REPORT call or training has finished. else: @@ -284,7 +284,7 @@ def _finish_checkpointing(self): # Process checkpoints and ignore other result types. if result_type is TrainingResultType.CHECKPOINT: self._checkpoint_manager._process_checkpoint( - results, decode_checkpoint_fn=self._backend._decode_data + results, decode_checkpoint_fn=self._backend.decode_data ) def _finish_training(self):