Skip to content

Commit

Permalink
Clean up a few checkpoint related things. (ray-project#35321)
Browse files Browse the repository at this point in the history
1. Get rid of the distributed metadata field that is not useful anymore.
2. Print a message when we download the checkpoint from cloud storage.
3. As promised, tune.run() should just take checkpoint_config.

Signed-off-by: Jun Gong <jungong@anyscale.com>
Signed-off-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
2 people authored and arvind-chandra committed Aug 31, 2023
1 parent 13cfa58 commit 890ec37
Show file tree
Hide file tree
Showing 22 changed files with 234 additions and 123 deletions.
15 changes: 13 additions & 2 deletions doc/source/ray-air/examples/rl_offline_example.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "57fe8246",
"metadata": {},
Expand All @@ -12,6 +13,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "edc8d8ac",
"metadata": {},
Expand All @@ -30,6 +32,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "503b1b55",
"metadata": {},
Expand Down Expand Up @@ -271,7 +274,7 @@
"import numpy as np\n",
"import ray\n",
"from ray.air import Checkpoint\n",
"from ray.air.config import RunConfig\n",
"from ray.air.config import CheckpointConfig, RunConfig\n",
"from ray.train.rl.rl_predictor import RLPredictor\n",
"from ray.train.rl.rl_trainer import RLTrainer\n",
"from ray.air.config import ScalingConfig\n",
Expand All @@ -281,6 +284,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "184fe936",
"metadata": {},
Expand Down Expand Up @@ -318,6 +322,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8bca906c",
"metadata": {},
Expand Down Expand Up @@ -357,13 +362,16 @@
" # result = trainer.fit()\n",
" tuner = Tuner(\n",
" trainer,\n",
" _tuner_kwargs={\"checkpoint_at_end\": True},\n",
" run_config=RunConfig(\n",
" checkpoint_config=CheckpointConfig(checkpoint_at_end=True)\n",
" ),\n",
" )\n",
" result = tuner.fit()[0]\n",
" return result"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d935cdee",
"metadata": {},
Expand Down Expand Up @@ -398,6 +406,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "84f4bebe",
"metadata": {},
Expand Down Expand Up @@ -938,6 +947,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c7534d5c",
"metadata": {},
Expand Down Expand Up @@ -1395,6 +1405,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "71d7f318",
"metadata": {},
Expand Down
13 changes: 11 additions & 2 deletions doc/source/ray-air/examples/rl_online_example.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "3471e19a",
"metadata": {},
Expand All @@ -12,6 +13,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f5083f08",
"metadata": {},
Expand All @@ -30,6 +32,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "980cea70",
"metadata": {},
Expand Down Expand Up @@ -60,7 +63,7 @@
"import numpy as np\n",
"import ray\n",
"from ray.air import Checkpoint\n",
"from ray.air.config import RunConfig\n",
"from ray.air.config import CheckpointConfig, RunConfig\n",
"from ray.train.rl.rl_predictor import RLPredictor\n",
"from ray.train.rl.rl_trainer import RLTrainer\n",
"from ray.air.config import ScalingConfig\n",
Expand All @@ -70,6 +73,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "a13db7e4",
"metadata": {},
Expand Down Expand Up @@ -99,13 +103,16 @@
" # result = trainer.fit()\n",
" tuner = Tuner(\n",
" trainer,\n",
" _tuner_kwargs={\"checkpoint_at_end\": True},\n",
" run_config=RunConfig(\n",
" checkpoint_config=CheckpointConfig(checkpoint_at_end=True)\n",
" ),\n",
" )\n",
" result = tuner.fit()[0]\n",
" return result"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f7a5d5c2",
"metadata": {},
Expand Down Expand Up @@ -140,6 +147,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d226d6aa",
"metadata": {},
Expand Down Expand Up @@ -1294,6 +1302,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6714a3d6",
"metadata": {},
Expand Down
16 changes: 14 additions & 2 deletions doc/source/ray-air/examples/rl_serving_example.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "c7f3b22d",
"metadata": {},
Expand All @@ -14,6 +15,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "95f0f98f",
"metadata": {},
Expand All @@ -32,6 +34,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b1be42d5",
"metadata": {},
Expand All @@ -51,7 +54,7 @@
"import requests\n",
"\n",
"from ray.air.checkpoint import Checkpoint\n",
"from ray.air.config import RunConfig\n",
"from ray.air.config import CheckpointConfig, RunConfig\n",
"from ray.train.rl.rl_trainer import RLTrainer\n",
"from ray.air.config import ScalingConfig\n",
"from ray.train.rl.rl_predictor import RLPredictor\n",
Expand All @@ -62,6 +65,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "144e26c8",
"metadata": {},
Expand Down Expand Up @@ -91,13 +95,16 @@
" # result = trainer.fit()\n",
" tuner = Tuner(\n",
" trainer,\n",
" _tuner_kwargs={\"checkpoint_at_end\": True},\n",
" run_config=RunConfig(\n",
" checkpoint_config=CheckpointConfig(checkpoint_at_end=True)\n",
" ),\n",
" )\n",
" result = tuner.fit()[0]\n",
" return result"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "528b8175",
"metadata": {},
Expand Down Expand Up @@ -127,6 +134,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "145a9bef",
"metadata": {},
Expand Down Expand Up @@ -174,6 +182,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b45c6f71",
"metadata": {},
Expand Down Expand Up @@ -1356,6 +1365,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "edda9c2b",
"metadata": {},
Expand Down Expand Up @@ -1395,6 +1405,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "53af304d",
"metadata": {},
Expand Down Expand Up @@ -3808,6 +3819,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "55ec8c8e",
"metadata": {},
Expand Down
3 changes: 3 additions & 0 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,9 @@ def _to_directory(self, path: str, move_instead_of_copy: bool = False) -> None:
else:
_copy_dir_ignore_conflicts(local_path_pathlib, path_pathlib)
elif external_path:
logger.info(
f"Downloading checkpoint from {external_path} to {path} ..."
)
# If this exists on external storage (e.g. cloud), download
download_from_uri(uri=external_path, local_path=path, filelock=False)
else:
Expand Down
8 changes: 4 additions & 4 deletions python/ray/air/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,11 @@ class CheckpointConfig:

num_to_keep: Optional[int] = None
checkpoint_score_attribute: Optional[str] = None
checkpoint_score_order: str = MAX
checkpoint_frequency: int = 0
checkpoint_score_order: Optional[str] = MAX
checkpoint_frequency: Optional[int] = 0
checkpoint_at_end: Optional[bool] = None
_checkpoint_keep_all_ranks: bool = False
_checkpoint_upload_from_workers: bool = False
_checkpoint_keep_all_ranks: Optional[bool] = False
_checkpoint_upload_from_workers: Optional[bool] = False

def __post_init__(self):
if self.num_to_keep is not None and self.num_to_keep <= 0:
Expand Down
30 changes: 30 additions & 0 deletions python/ray/air/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,36 @@ def test_checkpointing_config():
assert checkpointing._tune_legacy_checkpoint_score_attr == "min-metric"


def test_checkpointing_config_deprecated():
def resolve(checkpoint_score_attr):
# Copied from tune.tun()
checkpoint_config = CheckpointConfig()

if checkpoint_score_attr.startswith("min-"):
checkpoint_config.checkpoint_score_attribute = checkpoint_score_attr[4:]
checkpoint_config.checkpoint_score_order = "min"
else:
checkpoint_config.checkpoint_score_attribute = checkpoint_score_attr
checkpoint_config.checkpoint_score_order = "max"

return checkpoint_config

cc = resolve("loss")
assert cc._tune_legacy_checkpoint_score_attr == "loss"
assert cc.checkpoint_score_attribute == "loss"
assert cc.checkpoint_score_order == "max"

cc = resolve("min-loss")
assert cc._tune_legacy_checkpoint_score_attr == "min-loss"
assert cc.checkpoint_score_attribute == "loss"
assert cc.checkpoint_score_order == "min"

cc = resolve("min-min-loss")
assert cc._tune_legacy_checkpoint_score_attr == "min-min-loss"
assert cc.checkpoint_score_attribute == "min-loss"
assert cc.checkpoint_score_order == "min"


def test_scaling_config():
with pytest.raises(ValueError):
DummyTrainer(scaling_config="invalid")
Expand Down
8 changes: 1 addition & 7 deletions python/ray/train/_internal/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ray.data import Dataset, DatasetPipeline
from ray.train._internal.accelerator import Accelerator
from ray.train.constants import (
CHECKPOINT_DISTRIBUTED_KEY,
CHECKPOINT_METADATA_KEY,
CHECKPOINT_RANK_KEY,
DETAILED_AUTOFILLED_KEYS,
Expand Down Expand Up @@ -396,12 +395,7 @@ def checkpoint(self, checkpoint: Checkpoint):
checkpoint = str(checkpoint._local_path)

# Save the rank of the worker that created this checkpoint.
metadata.update(
{
CHECKPOINT_RANK_KEY: self.world_rank,
CHECKPOINT_DISTRIBUTED_KEY: upload_from_workers,
}
)
metadata.update({CHECKPOINT_RANK_KEY: self.world_rank})

result = TrainingResult(
type=TrainingResultType.CHECKPOINT,
Expand Down
4 changes: 0 additions & 4 deletions python/ray/train/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,3 @@

# Key for AIR Checkpoint world rank in TrainingResult metadata
CHECKPOINT_RANK_KEY = "checkpoint_rank"


# Key for AIR Checkpoint that gets uploaded from distributed workers.
CHECKPOINT_DISTRIBUTED_KEY = "distributed"
7 changes: 5 additions & 2 deletions python/ray/tune/examples/pbt_transformers/pbt_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os

from ray import tune
from ray.air.config import CheckpointConfig
from ray.tune import CLIReporter
from ray.tune.examples.pbt_transformers.utils import (
download_data,
Expand Down Expand Up @@ -135,8 +136,10 @@ def get_model():
n_trials=num_samples,
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
scheduler=scheduler,
keep_checkpoints_num=1,
checkpoint_score_attr="training_iteration",
checkpoint_config=CheckpointConfig(
num_to_keep=1,
checkpoint_score_attribute="training_iteration",
),
stop={"training_iteration": 1} if smoke_test else None,
progress_reporter=reporter,
local_dir="~/ray_results/",
Expand Down
Loading

0 comments on commit 890ec37

Please sign in to comment.