diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst
index b3fadeccb0685..a532553d82fed 100644
--- a/doc/source/tune/api_docs/trainable.rst
+++ b/doc/source/tune/api_docs/trainable.rst
@@ -7,109 +7,74 @@ Training (tune.Trainable, session.report)
 ==========================================
 
-Training can be done with either a **Function API** (``session.report``) or **Class API** (``tune.Trainable``).
+Training can be done with either a **Function API** (:ref:`session.report `) or **Class API** (:ref:`tune.Trainable `).
 
 For the sake of example, let's maximize this objective function:
 
-.. code-block:: python
-
-    def objective(x, a, b):
-        return a * (x ** 0.5) + b
+.. literalinclude:: /tune/doc_code/trainable.py
+    :language: python
+    :start-after: __example_objective_start__
+    :end-before: __example_objective_end__
 
 .. _tune-function-api:
 
 Function API
 ------------
 
-With the Function API, you can report intermediate metrics by simply calling ``session.report`` within the provided function.
+The Function API allows you to define a custom training function that Tune will run in parallel Ray actor processes,
+one for each Tune trial.
+With the Function API, you can report intermediate metrics by simply calling ``session.report`` within the function.
 
-.. code-block:: python
+.. literalinclude:: /tune/doc_code/trainable.py
+    :language: python
+    :start-after: __function_api_report_intermediate_metrics_start__
+    :end-before: __function_api_report_intermediate_metrics_end__
 
-    from ray import tune
-    from ray.air import session
+.. tip:: Do not use ``session.report`` within a ``Trainable`` class.
 
-    def trainable(config):
-        # config (dict): A dict of hyperparameters.
+In the previous example, we reported on every step, but this metric reporting frequency
+is configurable. For example, we could also report only a single time at the end with the final score:
 
-        for x in range(20):
-            intermediate_score = objective(x, config["a"], config["b"])
+.. literalinclude:: /tune/doc_code/trainable.py
+    :language: python
+    :start-after: __function_api_report_final_metrics_start__
+    :end-before: __function_api_report_final_metrics_end__
 
-            session.report({"score": intermediate_score})  # This sends the score to Tune.
+It's also possible to return a final set of metrics to Tune by returning them from your function:
 
-    tuner = tune.Tuner(
-        trainable,
-        param_space={"a": 2, "b": 4}
-    )
-    results = tuner.fit()
-
-    print("best config: ", results.get_best_result(metric="score", mode="max").config)
-
-.. tip:: Do not use ``session.report`` within a ``Trainable`` class.
-
-Tune will run this function on a separate thread in a Ray actor process.
+.. literalinclude:: /tune/doc_code/trainable.py
+    :language: python
+    :start-after: __function_api_return_final_metrics_start__
+    :end-before: __function_api_return_final_metrics_end__
 
 You'll notice that Ray Tune will output extra values in addition to the user reported metrics,
 such as ``iterations_since_restore``. See :ref:`tune-autofilled-metrics` for an explanation/glossary of these values.
 
-.. code-block:: python
-
-    def trainable(config):
-        # config (dict): A dict of hyperparameters.
-
-        final_score = 0
-        for x in range(20):
-            final_score = objective(x, config["a"], config["b"])
-
-        return {"score": final_score}  # This sends the score to Tune.
-
-    tuner = tune.Tuner(
-        trainable,
-        param_space={"a": 2, "b": 4}
-    )
-    results = tuner.fit()
-
-    print("best config: ", results.get_best_result(metric="score", mode="max").config)
-
-
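+For example, a minimal sketch of reading both your own metric and one of the auto-filled values
+back from the results returned by ``Tuner.fit()`` (assuming the ``results`` object from the snippets above):
+
+.. code-block:: python
+
+    best_result = results.get_best_result(metric="score", mode="max")
+
+    # "score" was reported by the trainable above;
+    # "training_iteration" is filled in automatically by Tune.
+    print(best_result.metrics["score"])
+    print(best_result.metrics["training_iteration"])
+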
 .. _tune-function-checkpointing:
 
 Function API Checkpointing
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
-You can save and load checkpoint in Ray Tune in the following manner:
-
-.. code-block:: python
-
-    import time
-    from ray import tune
-    from ray.air import session
-    from ray.air.checkpoint import Checkpoint
+You can save and load checkpoints in Ray Tune in the following manner:
 
-    def train_func(config):
-        step = 0
-        loaded_checkpoint = session.get_checkpoint()
-        if loaded_checkpoint:
-            last_step = loaded_checkpoint.to_dict()["step"]
-            step = last_step + 1
-
-        for iter in range(step, 100):
-            time.sleep(1)
-
-            checkpoint = Checkpoint.from_dict({"step": step})
-            session.report({"message": "Hello world Ray Tune!"}, checkpoint=checkpoint)
-
-    tuner = tune.Tuner(train_func)
-    results = tuner.fit()
+.. literalinclude:: /tune/doc_code/trainable.py
+    :language: python
+    :start-after: __function_api_checkpointing_start__
+    :end-before: __function_api_checkpointing_end__
 
 .. note:: ``checkpoint_frequency`` and ``checkpoint_at_end`` will not work with Function API checkpointing.
 
-In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_``.
+In this example, checkpoints will be saved by training iteration to ``//trial_name/checkpoint_``.
 Tune also may copy or move checkpoints during the course of tuning. For this purpose,
 it is important not to depend on absolute paths in the implementation of ``save``.
+See :ref:`here for more information on creating checkpoints `.
+If using framework-specific trainers from Ray AIR, see :ref:`here ` for
+references to framework-specific checkpoints such as `TensorflowCheckpoint`.
+
 .. _tune-class-api:
 
 Trainable Class API
 -------------------
 
@@ -119,32 +84,10 @@ Trainable Class API
 The Trainable **class API** will require users to subclass ``ray.tune.Trainable``. Here's a naive example of this API:
 
-.. code-block:: python
-
-    from ray import tune
-
-    class Trainable(tune.Trainable):
-        def setup(self, config):
-            # config (dict): A dict of hyperparameters
-            self.x = 0
-            self.a = config["a"]
-            self.b = config["b"]
-
-        def step(self):  # This is called iteratively.
-            score = objective(self.x, self.a, self.b)
-            self.x += 1
-            return {"score": score}
-
-    tuner = tune.Tuner(
-        Trainable,
-        tune_config=air.RunConfig(stop={"training_iteration": 20}),
-        param_space={
-            "a": 2,
-            "b": 4
-        })
-    results = tuner.fit()
-
-    print('best config: ', results.get_best_result(metric="score", mode="max").config)
+.. literalinclude:: /tune/doc_code/trainable.py
+    :language: python
+    :start-after: __class_api_example_start__
+    :end-before: __class_api_example_end__
 
 As a subclass of ``tune.Trainable``, Tune will create a ``Trainable`` object on a
 separate process (using the :ref:`Ray Actor API `).
 
@@ -169,20 +112,10 @@ Class API Checkpointing
 
 You can also implement checkpoint/restore using the Trainable Class API:
 
-.. code-block:: python
-
-    class MyTrainableClass(Trainable):
-        def save_checkpoint(self, tmp_checkpoint_dir):
-            checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
-            torch.save(self.model.state_dict(), checkpoint_path)
-            return tmp_checkpoint_dir
-
-        def load_checkpoint(self, tmp_checkpoint_dir):
-            checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
-            self.model.load_state_dict(torch.load(checkpoint_path))
-
-    tuner = tune.Tuner(MyTrainableClass, run_config=air.RunConfig(checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2)))
-    results = tuner.fit()
+.. literalinclude:: /tune/doc_code/trainable.py
+    :language: python
+    :start-after: __class_api_checkpointing_start__
+    :end-before: __class_api_checkpointing_end__
 
 You can checkpoint with three different mechanisms: manually, periodically, and at termination.
 
@@ -278,6 +211,22 @@ It is up to the user to correctly update the hyperparameters of your trainable.
         return True
 
+Comparing the Function API and Class API
+----------------------------------------
+
+Here are a few key concepts and what they look like for the Function and Class APIs.
+
+======================= =============================================== ==============================================
+Concept                 Function API                                    Class API
+======================= =============================================== ==============================================
+Training Iteration      Increments on each `session.report` call       Increments on each `Trainable.step` call
+Report metrics          `session.report(metrics)`                       Return metrics from `Trainable.step`
+Saving a checkpoint     `session.report(..., checkpoint=checkpoint)`    `Trainable.save_checkpoint`
+Loading a checkpoint    `session.get_checkpoint()`                      `Trainable.load_checkpoint`
+Accessing config        Passed as an argument `def train_func(config):` Passed through `Trainable.setup`
+======================= =============================================== ==============================================
+
+
 Advanced Resource Allocation
 ----------------------------
 
@@ -330,10 +279,11 @@ session (Function API)
 .. autofunction:: ray.air.session.get_trial_dir
     :noindex:
 
+.. _tune-trainable-docstring:
+
 tune.Trainable (Class API)
 --------------------------
-
 .. autoclass:: ray.tune.Trainable
     :member-order: groupwise
     :private-members:
diff --git a/doc/source/tune/doc_code/trainable.py b/doc/source/tune/doc_code/trainable.py
new file mode 100644
index 0000000000000..856d33c916631
--- /dev/null
+++ b/doc/source/tune/doc_code/trainable.py
@@ -0,0 +1,150 @@
+# flake8: noqa
+
+# __class_api_checkpointing_start__
+import os
+import torch
+from torch import nn
+
+from ray import air, tune
+
+
+class MyTrainableClass(tune.Trainable):
+    def setup(self, config):
+        self.model = nn.Sequential(
+            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
+        )
+
+    def step(self):
+        return {}
+
+    def save_checkpoint(self, tmp_checkpoint_dir):
+        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
+        torch.save(self.model.state_dict(), checkpoint_path)
+        return tmp_checkpoint_dir
+
+    def load_checkpoint(self, tmp_checkpoint_dir):
+        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
+        self.model.load_state_dict(torch.load(checkpoint_path))
+
+
+tuner = tune.Tuner(
+    MyTrainableClass,
+    param_space={"input_size": 64},
+    run_config=air.RunConfig(
+        stop={"training_iteration": 2},
+        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
+    ),
+)
+tuner.fit()
+# __class_api_checkpointing_end__
+
+# __function_api_checkpointing_start__
+from ray import tune
+from ray.air import session
+from ray.air.checkpoint import Checkpoint
+
+
+def train_func(config):
+    epochs = config.get("epochs", 2)
+    start = 0
+    loaded_checkpoint = session.get_checkpoint()
+    if loaded_checkpoint:
+        last_step = loaded_checkpoint.to_dict()["step"]
+        start = last_step + 1
+
+    for step in range(start, epochs):
+        # Model training here
+        # ...
+
+        # Report metrics and save a checkpoint
+        metrics = {"metric": "my_metric"}
+        checkpoint = Checkpoint.from_dict({"step": step})
+        session.report(metrics, checkpoint=checkpoint)
+
+
+tuner = tune.Tuner(train_func)
+results = tuner.fit()
+# __function_api_checkpointing_end__
+
+# fmt: off
+# __example_objective_start__
+def objective(x, a, b):
+    return a * (x ** 0.5) + b
+# __example_objective_end__
+# fmt: on
+
+# __function_api_report_intermediate_metrics_start__
+from ray import tune
+from ray.air import session
+
+
+def trainable(config: dict):
+    intermediate_score = 0
+    for x in range(20):
+        intermediate_score = objective(x, config["a"], config["b"])
+        session.report({"score": intermediate_score})  # This sends the score to Tune.
+
+
+tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
+results = tuner.fit()
+# __function_api_report_intermediate_metrics_end__
+
+# __function_api_report_final_metrics_start__
+from ray import tune
+from ray.air import session
+
+
+def trainable(config: dict):
+    final_score = 0
+    for x in range(20):
+        final_score = objective(x, config["a"], config["b"])
+
+    session.report({"score": final_score})  # This sends the score to Tune.
+
+
+tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
+results = tuner.fit()
+# __function_api_report_final_metrics_end__
+
+# fmt: off
+# __function_api_return_final_metrics_start__
+def trainable(config: dict):
+    final_score = 0
+    for x in range(20):
+        final_score = objective(x, config["a"], config["b"])
+
+    return {"score": final_score}  # This sends the score to Tune.
+# __function_api_return_final_metrics_end__
+# fmt: on
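+
+# NOTE: The following is an illustrative sketch, not referenced by the docs above.
+# The snippets above use a fixed ``param_space`` to keep them fast; in a real tuning
+# run you would usually pass a search space and let Tune sample it, for example:
+tuner = tune.Tuner(
+    trainable,
+    param_space={"a": tune.uniform(0, 5), "b": tune.uniform(0, 5)},
+    tune_config=tune.TuneConfig(metric="score", mode="max", num_samples=4),
+)
+results = tuner.fit()
+print(results.get_best_result().config)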
+
+# __class_api_example_start__
+from ray import air, tune
+
+
+class Trainable(tune.Trainable):
+    def setup(self, config: dict):
+        # config (dict): A dict of hyperparameters
+        self.x = 0
+        self.a = config["a"]
+        self.b = config["b"]
+
+    def step(self):  # This is called iteratively.
+        score = objective(self.x, self.a, self.b)
+        self.x += 1
+        return {"score": score}
+
+
+tuner = tune.Tuner(
+    Trainable,
+    run_config=air.RunConfig(
+        # Train for 20 steps
+        stop={"training_iteration": 20},
+        checkpoint_config=air.CheckpointConfig(
+            # We haven't implemented checkpointing yet. See below!
+            checkpoint_at_end=False
+        ),
+    ),
+    param_space={"a": 2, "b": 4},
+)
+results = tuner.fit()
+# __class_api_example_end__
diff --git a/doc/source/tune/tutorials/tune-checkpoints.rst b/doc/source/tune/tutorials/tune-checkpoints.rst
index 305da90637fe2..7b6a75760742b 100644
--- a/doc/source/tune/tutorials/tune-checkpoints.rst
+++ b/doc/source/tune/tutorials/tune-checkpoints.rst
@@ -31,15 +31,20 @@ Commonly, this includes the model and optimizer states. This is useful mostly fo
   the meantime. This only makes sense if the trials can then continue training from the latest state.
 - The checkpoint can be later used for other downstream tasks like batch inference.
 
-Everything that is reported by ``session.report()`` is a trial-level checkpoint.
-See :ref:`here for more information on saving checkpoints `.
+Everything that is saved by ``session.report()`` (if using the Function API) or
+``Trainable.save_checkpoint`` (if using the Class API) is a **trial-level checkpoint.**
+See :ref:`checkpointing with the Function API ` and
+:ref:`checkpointing with the Class API `
+for examples of saving and loading trial-level checkpoints.
 
 .. _tune-checkpoint-syncing:
 
 Checkpointing and synchronization
 ---------------------------------
 
-This topic is mostly relevant to Trial checkpoint.
+.. note::
+
+    This topic is relevant to trial checkpoints.
 
 Tune stores checkpoints on the node where the trials are executed. If you are training on more than one node,
 this means that some trial checkpoints may be on the head node and others are not.
 
@@ -108,7 +113,9 @@ This will automatically store both the experiment state and the trial checkpoint
             name="experiment_name",
             sync_config=tune.SyncConfig(
                 upload_dir="s3://bucket-name/sub-path/"
-            )))
+            )
+        )
+    )
     tuner.fit()
 
 We don't have to provide a ``syncer`` here as it will be automatically detected. However, you can provide
@@ -126,7 +133,8 @@ a string if you want to use a custom command:
             sync_config=tune.SyncConfig(
                 upload_dir="s3://bucket-name/sub-path/",
                 syncer="aws s3 sync {source} {target}",  # Custom sync command
-            )),
+            )
+        )
     )
     tuner.fit()
 
@@ -191,7 +199,8 @@ Alternatively, a function can be provided with the following signature:
                 syncer=custom_sync_func,
                 sync_period=60  # Synchronize more often
             )
-        ))
+        )
+    )
     results = tuner.fit()
 
 When syncing results back to the driver, the source would be a path similar to
 
@@ -230,11 +239,13 @@ Your ``my_trainable`` is either a:
 
 2. **Custom training function**
 
-   * All this means is that your function needs to take care of saving and loading from checkpoint.
-     For saving, this is done through ``session.report()`` API, which can take in a ``Checkpoint`` object.
-     For loading, your function can access existing checkpoint through ``Session.get_checkpoint()`` API.
-     See :doc:`this example `,
-     it's quite simple to do.
+   All this means is that your function needs to take care of saving and loading from checkpoint.
+
+   * For saving, this is done through :meth:`session.report() ` API, which can take in a ``Checkpoint`` object.
+
+   * For loading, your function can access an existing checkpoint through the :meth:`session.get_checkpoint() ` API.
+
+   * See :doc:`this example ` for reference.
 
 Let's assume for this example you're running this script from your laptop, and connecting to your remote Ray cluster
 via ``ray.init()``, making your script on your laptop the "driver".
 
 .. code-block:: python
 
     ray.init(address=":")  # set `address=None` to train on laptop
 
-    # configure how checkpoints are sync'd to the scheduler/sampler
-    # we recommend cloud storage checkpointing as it survives the cluster when
-    # instances are terminated, and has better performance
+    # Configure how checkpoints are sync'd to the scheduler/sampler
+    # We recommend cloud storage checkpointing as it survives the cluster when
+    # instances are terminated and has better performance
    sync_config = tune.SyncConfig(
        upload_dir="s3://my-checkpoints-bucket/path/",  # requires AWS credentials
    )
 
-    # this starts the run!
+    # This starts the run!
     tuner = tune.Tuner(
         my_trainable,
         run_config=air.RunConfig(
-            # name of your experiment
-            # if this experiment exists, we will resume from the last run
-            # as specified by
+            # Name of your experiment
             name="my-tune-exp",
-            # a directory where results are stored before being
+            # Directory where each node's results are stored before being
             # sync'd to head node/cloud storage
             local_dir="/tmp/mypath",
-            # see above! we will sync our checkpoints to S3 directory
+            # See above! we will sync our checkpoints to S3 directory
             sync_config=sync_config,
             checkpoint_config=air.CheckpointConfig(
-                # we'll keep the best five checkpoints at all times
+                # We'll keep the best five checkpoints at all times
                 # checkpoints (by AUC score, reported by the trainable, descending)
-                checkpoint_score_attr="max-auc",
-                keep_checkpoints_num=5,
+                checkpoint_score_attribute="auc",
+                checkpoint_score_order="max",
+                num_to_keep=5,
             ),
         ),
     )
 
@@ -281,13 +290,25 @@ In this example, checkpoints will be saved:
 
 * **Locally**: not saved! Nothing will be sync'd to the driver (your laptop) automatically (because cloud syncing is enabled)
 * **S3**: ``s3://my-checkpoints-bucket/path/my-tune-exp//checkpoint_``
-* **On head node**: ``~/ray-results/my-tune-exp//checkpoint_`` (but only for trials done on that node)
-* **On workers nodes**: ``~/ray-results/my-tune-exp//checkpoint_`` (but only for trials done on that node)
+* **On head node**: ``/tmp/mypath/my-tune-exp//checkpoint_`` (but only for trials done on that node)
+* **On worker nodes**: ``/tmp/mypath/my-tune-exp//checkpoint_`` (but only for trials done on that node)
+
+If this run stopped for any reason (finished, errored, user CTRL+C), you can restart it at any time using the experiment checkpoints saved in the cloud:
+
+.. code-block:: python
+
+    from ray import tune
+    tuner = tune.Tuner.restore(
+        "s3://my-checkpoints-bucket/path/my-tune-exp",
+        resume_errored=True
+    )
+    tuner.fit()
+
+
-If your run stopped for any reason (finished, errored, user CTRL+C), you can restart it any time by
-``tuner=Tuner.restore(experiment_checkpoint_dir).fit()``.
 There are a few options for restoring an experiment:
-"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details.
+``resume_unfinished``, ``resume_errored`` and ``restart_errored``.
+Please see the documentation of
+:meth:`Tuner.restore() ` for more details.
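+
+As a small sketch (the exact bucket layout below is illustrative), an individual trial checkpoint
+that has been synced to cloud storage can also be downloaded and inspected directly via
+``Checkpoint.from_uri``:
+
+.. code-block:: python
+
+    from ray.air.checkpoint import Checkpoint
+
+    # Replace the URI with one of the actual checkpoint directories synced to your bucket.
+    checkpoint = Checkpoint.from_uri(
+        "s3://my-checkpoints-bucket/path/my-tune-exp/<trial_name>/checkpoint_000000"
+    )
+    local_path = checkpoint.to_directory()  # Downloads the checkpoint data locally.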
 
 .. _rsync-checkpointing:
 
@@ -298,7 +319,7 @@ Local or rsync checkpointing can be a good option if:
 
 1. You want to tune on a single laptop Ray cluster
 2. You aren't using Ray on Kubernetes (rsync doesn't work with Ray on Kubernetes)
-3. You don't want to use S3
+3. You don't want to use cloud storage (e.g. S3)
 
 Let's take a look at an example:
 
 .. code-block:: python
 
     ray.init(address=":")  # set `address=None` to train on laptop
 
-    # configure how checkpoints are sync'd to the scheduler/sampler
-    sync_config = tune.syncConfig()  # the default mode is to use use rsync
+    # Configure how checkpoints are sync'd to the scheduler/sampler
+    sync_config = tune.SyncConfig()  # the default mode is to use rsync
 
-    # this starts the run!
+    # This starts the run!
     tuner = tune.Tuner(
         my_trainable,
         run_config=air.RunConfig(
-            # name of your experiment
-            # If the experiment with the same name is already run,
-            # Tuner willl resume from the last run specified by sync_config(if one exists).
-            # Otherwise, will start a new run.
             name="my-tune-exp",
-            # a directory where results are stored before being
-            # sync'd to head node/cloud storage
             local_dir="/tmp/mypath",
-            # sync our checkpoints via rsync
-            # you don't have to pass an empty sync config - but we
+            # Sync our checkpoints via rsync
+            # You don't have to pass an empty sync config - but we
             # do it here for clarity and comparison
             sync_config=sync_config,
             checkpoint_config=air.CheckpointConfig(
-                # we'll keep the best five checkpoints at all times
-                # checkpoints (by AUC score, reported by the trainable, descending)
-                checkpoint_score_attr="max-auc",
-                keep_checkpoints_num=5,
-            )
+                checkpoint_score_attribute="auc",
+                checkpoint_score_order="max",
+                num_to_keep=5,
+            ),
         )
     )
diff --git a/doc/source/tune/tutorials/tune-distributed.rst b/doc/source/tune/tutorials/tune-distributed.rst
index 7d1412f563cda..716f63de42dab 100644
--- a/doc/source/tune/tutorials/tune-distributed.rst
+++ b/doc/source/tune/tutorials/tune-distributed.rst
@@ -225,9 +225,7 @@ If the trial/actor is placed on a different node, Tune will automatically push t
 
 Recovering From Failures
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
-Tune automatically persists the progress of your entire experiment (a ``Tuner.fit()`` session), so if an experiment crashes or is otherwise cancelled, it can be resumed through ``Tuner.restore()``.
-There are a few options for restoring an experiment:
-"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details.
+Tune automatically persists the progress of your entire experiment (a ``Tuner.fit()`` session), so if an experiment crashes or is otherwise cancelled, it can be resumed through :meth:`Tuner.restore() `.
 
 .. _tune-distributed-common:
 
diff --git a/doc/source/tune/tutorials/tune-stopping.rst b/doc/source/tune/tutorials/tune-stopping.rst
index e23e3801501e8..8e1ee1cdd4cde 100644
--- a/doc/source/tune/tutorials/tune-stopping.rst
+++ b/doc/source/tune/tutorials/tune-stopping.rst
@@ -20,15 +20,16 @@ If you've stopped a run and and want to resume from where you left off, you can
 then call ``Tuner.restore()`` like this:
 
 .. code-block:: python
-    :emphasize-lines: 4
 
     tuner = Tuner.restore(
         path="~/ray_results/my_experiment"
     )
     tuner.fit()
 
-There are a few options for resuming an experiment:
-"resume_unfinished", "resume_errored" and "restart_errored". See ``Tuner.restore()`` for more details.
+There are a few options for restoring an experiment:
+``resume_unfinished``, ``resume_errored`` and ``restart_errored``.
+Please see the documentation of
+:meth:`Tuner.restore() ` for more details.
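+
+For example, a minimal sketch (reusing the ``path`` from the snippet above) that also retries
+trials that previously errored could look like this:
+
+.. code-block:: python
+
+    tuner = Tuner.restore(
+        path="~/ray_results/my_experiment",
+        resume_errored=True,
+    )
+    tuner.fit()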
 
 ``path`` here is determined by the ``air.RunConfig.name`` you supplied to your ``Tuner()``.
 If you didn't supply name to ``Tuner``, it is likely that your ``path`` looks something like:
 
@@ -48,18 +49,18 @@ of your original tuning run:
 
     Number of trials: 1/1 (1 RUNNING)
 
 What's happening under the hood?
---------------------------------
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+:ref:`Here `, we describe the two types of Tune checkpoints:
+experiment-level and trial-level checkpoints.
+
+Upon resuming an interrupted/errored Tune run:
 
-:ref:`Here ` we talked about two types of Tune checkpoints.
-Both checkpoints come into play when resuming a Tune run.
+#. Tune first looks at the experiment-level checkpoint to find the list of trials at the time of the interruption.
 
-When resuming an interrupted/errored Tune run, Tune first looks at the experiment-level checkpoint
-to find the list of trials at the time of the interruption. Ray Tune then locates the trial-level
-checkpoint of each trial.
+#. Tune then locates and restores from the trial-level checkpoint of each trial.
 
-Depending on the specified resume option
-("resume_unfinished", "resume_errored", "restart_errored"), Ray Tune then decides whether to
-restore a given non-finished trial from its latest available checkpoint or start from scratch.
+#. Depending on the specified resume option (``resume_unfinished``, ``resume_errored``, ``restart_errored``), Tune decides whether to restore a given unfinished trial from its latest available checkpoint or to start from scratch.
 
 .. _tune-stopping-ref:
 
diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py
index cc7ef8593b3c2..2d38808d35fcc 100644
--- a/python/ray/tune/trainable/trainable.py
+++ b/python/ray/tune/trainable/trainable.py
@@ -1121,9 +1121,10 @@ def save_checkpoint(self, checkpoint_dir: str) -> Optional[Union[str, Dict]]:
 
         Returns:
             A dict or string. If string, the return value is expected to be
-            prefixed by `tmp_checkpoint_dir`. If dict, the return value will
-            be automatically serialized by Tune and
-            passed to ``Trainable.load_checkpoint()``.
+            prefixed by `checkpoint_dir`. If dict, the return value will
+            be automatically serialized by Tune. In both cases, the return value
+            is exactly what will be passed to ``Trainable.load_checkpoint()``
+            upon restore.
 
         Example:
             >>> trainable, trainable1, trainable2 = ...  # doctest: +SKIP
@@ -1152,23 +1153,35 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]):
 
         The directory structure under the checkpoint_dir provided to
         ``Trainable.save_checkpoint`` is preserved.
-        See the example below.
+        See the examples below.
 
         Example:
+            >>> import os
             >>> from ray.tune.trainable import Trainable
             >>> class Example(Trainable):
             ...    def save_checkpoint(self, checkpoint_path):
-            ...        print(checkpoint_path)
-            ...        return os.path.join(checkpoint_path, "my/check/point")
-            ...    def load_checkpoint(self, checkpoint):
-            ...        print(checkpoint)
+            ...        my_checkpoint_path = os.path.join(checkpoint_path, "my/path")
+            ...        return my_checkpoint_path
+            ...    def load_checkpoint(self, my_checkpoint_path):
+            ...        print(my_checkpoint_path)
             >>> trainer = Example()
             >>> # This is used when PAUSED.
             >>> obj = trainer.save_to_object()  # doctest: +SKIP
-            /tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
+            /tmpc8k_c_6hsave_to_object/checkpoint_0/my/path
             >>> # Note the different prefix.
             >>> trainer.restore_from_object(obj)  # doctest: +SKIP
-            /tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
+            /tmpb87b5axfrestore_from_object/checkpoint_0/my/path
+
+        If `Trainable.save_checkpoint` returned a dict, then Tune will directly pass
+        the dict data as the argument to this method.
+
+        Example:
+            >>> from ray.tune.trainable import Trainable
+            >>> class Example(Trainable):
+            ...    def save_checkpoint(self, checkpoint_path):
+            ...        return {"my_data": 1}
+            ...    def load_checkpoint(self, checkpoint_dict):
+            ...        print(checkpoint_dict["my_data"])
 
         .. versionadded:: 0.8.7
 
         Args:
             checkpoint: If dict, the return value is as
                 returned by `save_checkpoint`. If a string, then it is
                 a checkpoint path that may have a different prefix than that
                 returned by `save_checkpoint`. The directory structure
-                underneath the `checkpoint_dir` `save_checkpoint` is preserved.
+                underneath the `checkpoint_dir` from `save_checkpoint` is preserved.
 
         """
         raise NotImplementedError