diff --git a/python/ray/air/BUILD b/python/ray/air/BUILD index 2bd1e9164156b..cc2886d267ded 100644 --- a/python/ray/air/BUILD +++ b/python/ray/air/BUILD @@ -74,6 +74,17 @@ py_test( deps = [":ml_lib"] ) +py_test( + name = "test_experiment_restore", + size = "medium", + srcs = [ + "tests/test_experiment_restore.py", + "tests/_test_experiment_restore_run.py" + ], + tags = ["team:ml", "exclusive"], + deps = [":ml_lib"] +) + py_test( name = "test_errors", size = "small", diff --git a/python/ray/air/tests/_test_experiment_restore_run.py b/python/ray/air/tests/_test_experiment_restore_run.py new file mode 100644 index 0000000000000..4e6daa6ac1ff3 --- /dev/null +++ b/python/ray/air/tests/_test_experiment_restore_run.py @@ -0,0 +1,183 @@ +import collections +import json +import os +from pathlib import Path +import random +import time +from typing import Dict, List, Optional + +import ray +from ray import air, tune +from ray.air import Checkpoint, session +from ray.train.data_parallel_trainer import DataParallelTrainer +from ray.tune.experiment import Trial + + +RUNNER_TYPE = os.environ.get("RUNNER_TYPE", "trainer") +STORAGE_PATH = os.environ.get("STORAGE_PATH", "/tmp/ray_results") +EXP_NAME = os.environ.get("EXP_NAME", "restore_integration_test") +CALLBACK_DUMP_FILE = os.environ.get( + "CALLBACK_DUMP_FILE", "/tmp/callback_dump_file.json" +) +CSV_DATA_FILE = os.environ.get("CSV_DATA_FILE", "/tmp/dummy.csv") + +TIME_PER_ITER_S = float(os.environ.get("TIME_PER_ITER_S", "0.5")) +NUM_TRIALS = int(os.environ.get("NUM_TRIALS", "1")) +MAX_CONCURRENT_TRIALS = int(os.environ.get("MAX_CONCURRENT_TRIALS", "2")) +ITERATIONS_PER_TRIAL = int(os.environ.get("ITERATIONS_PER_TRIAL", "64")) + + +class StatefulCallback(tune.Callback): + def __init__(self): + self._trial_iterations = collections.defaultdict(list) + + def on_trial_result( + self, + iteration: int, + trials: List["Trial"], + trial: "Trial", + result: Dict, + **info, + ): + self._trial_iterations[trial.trial_id].append(result["training_iteration"]) + + def on_experiment_end(self, trials: List["Trial"], **info): + # Save callback contents to file + with open(CALLBACK_DUMP_FILE, "w") as f: + json.dump(self.get_state(), f, indent=2) + + def get_state(self) -> Optional[Dict]: + return {"trial_iters": self._trial_iterations.copy()} + + def set_state(self, state: Dict): + self._trial_iterations = state["trial_iters"] + + +class StatefulSearcher(tune.search.Searcher): + def __init__( + self, + metric: Optional[str] = None, + mode: Optional[str] = None, + ): + super().__init__(metric=metric, mode=mode) + self._trial_count = 0 + + def suggest(self, trial_id: str) -> Optional[Dict]: + self._trial_count += 1 + return {"id": self._trial_count} + + def on_trial_complete( + self, trial_id: str, result: Optional[Dict] = None, error: bool = False + ) -> None: + pass + + def save(self, checkpoint_path: str): + with open(checkpoint_path, "w") as f: + json.dump({"trial_count": self._trial_count}, f) + + def restore(self, checkpoint_path: str): + with open(checkpoint_path, "r") as f: + state = json.load(f) + self._trial_count = state["trial_count"] + + +def train_fn(config: dict, data: Optional[dict] = None): + checkpoint = session.get_checkpoint() + start = checkpoint.to_dict()["iteration"] + 1 if checkpoint else 1 + + training_started_marker = Path( + os.environ.get("RUN_STARTED_MARKER", "/tmp/does-not-exist") + ) + if training_started_marker.exists(): + # Multiple workers may be trying to delete the same marker + try: + training_started_marker.unlink() + except 
FileNotFoundError: + pass + + for iteration in range(start, ITERATIONS_PER_TRIAL + 1): + time.sleep(TIME_PER_ITER_S) + + session.report( + {"score": random.random()}, + checkpoint=Checkpoint.from_dict({"iteration": iteration}), + ) + + +def tuner(experiment_path: str, run_config: air.RunConfig) -> tune.ResultGrid: + trainable = tune.with_resources(train_fn, resources={"CPU": 1}) + trainable = tune.with_parameters(trainable, data={"dummy_data": [1, 2, 3]}) + + if tune.Tuner.can_restore(experiment_path): + tuner = tune.Tuner.restore( + experiment_path, trainable=trainable, resume_errored=True + ) + else: + tuner = tune.Tuner( + trainable, + run_config=run_config, + tune_config=tune.TuneConfig( + num_samples=8, + max_concurrent_trials=2, + search_alg=StatefulSearcher(), + ), + ) + + result_grid = tuner.fit() + return result_grid + + +def trainer(experiment_path: str, run_config: air.RunConfig) -> air.Result: + dataset_size = 128 + num_workers = 4 + + def train_loop_per_worker(config): + # Wrap the other train_fn with a check for the dataset. + assert session.get_dataset_shard("train") + train_fn(config) + + datasets = { + "train": ray.data.range(dataset_size), + "valid": ray.data.read_csv(CSV_DATA_FILE), + } + + if DataParallelTrainer.can_restore(experiment_path): + trainer = DataParallelTrainer.restore( + experiment_path, + datasets=datasets, + train_loop_per_worker=train_loop_per_worker, + ) + else: + trainer = DataParallelTrainer( + train_loop_per_worker, + datasets=datasets, + scaling_config=air.ScalingConfig( + num_workers=num_workers, trainer_resources={"CPU": 0} + ), + run_config=run_config, + ) + + result = trainer.fit() + return result + + +if __name__ == "__main__": + experiment_path = os.path.join(STORAGE_PATH, EXP_NAME) + + ray.init() + + run_config = air.RunConfig( + storage_path=STORAGE_PATH, + name=EXP_NAME, + checkpoint_config=air.CheckpointConfig(num_to_keep=1), + callbacks=[StatefulCallback()], + ) + + if RUNNER_TYPE == "tuner": + tuner(experiment_path, run_config) + elif RUNNER_TYPE == "trainer": + trainer(experiment_path, run_config) + else: + raise NotImplementedError( + "`RUNNER_TYPE` environment var must be one of ['tuner', 'trainer']" + ) diff --git a/python/ray/air/tests/test_experiment_restore.py b/python/ray/air/tests/test_experiment_restore.py new file mode 100644 index 0000000000000..54ef313c6ee25 --- /dev/null +++ b/python/ray/air/tests/test_experiment_restore.py @@ -0,0 +1,249 @@ +import json +import numpy as np +import pandas as pd +from pathlib import Path +import pytest +import time +import shutil +import signal +import subprocess +import sys + +from ray.tune.result_grid import ResultGrid +from ray.tune.analysis import ExperimentAnalysis + + +_RUN_SCRIPT_FILENAME = "_test_experiment_restore_run.py" + + +def _kill_process_if_needed( + process: subprocess.Popen, timeout_s: float = 10, poll_interval_s: float = 1.0 +): + """Kills a process if it hasn't finished in `timeout_s` seconds. + Polls every `poll_interval_s` seconds to check if the process is still running.""" + kill_timeout = time.monotonic() + timeout_s + while process.poll() is None and time.monotonic() < kill_timeout: + time.sleep(poll_interval_s) + if process.poll() is None: + process.terminate() + + +def _print_message(message): + sep = "=" * 50 + print(f"\n{sep}\n{message}\n{sep}\n") + + +@pytest.mark.parametrize("runner_type", ["tuner", "trainer"]) +def test_experiment_restore(tmp_path, runner_type): + """ + This is an integration stress test for experiment restoration. 
+ + + Test setup: + + - For Tuner.restore: + - 8 trials, with a max of 2 running concurrently (--> 4 rounds of trials) + - Each iteration takes 0.5 seconds + - Each trial runs for 8 iterations --> 4 seconds + - Each round of 2 trials should take 4 seconds + - Without any interrupts/restoration: + - Minimum runtime: 4 rounds * 4 seconds / round = 16 seconds + - The test will stop the script with a SIGINT at a random time between + 4-8 iterations each restore. + + - For Trainer.restore: + - 1 trial with 4 workers + - Each iteration takes 0.5 seconds + - Runs for 32 iterations --> Minimum runtime = 16 seconds + - The test will stop the script with a SIGINT at a random time between + 4-8 iterations after each restore. + + + Requirements: + - Req 1: Reasonable runtime + - The experiment should finish within 2 * 16 = 32 seconds. + - 2x is the passing threshold. + - 16 seconds is the minimum runtime. + - Req 2: Training progress persisted + - The experiment should progress monotonically. + (The training iteration shouldn't go backward at any point) + - Trials shouldn't start from scratch. + - Req 3: Searcher state saved/restored correctly + - Req 4: Callback state saved/restored correctly + """ + + np.random.seed(2023) + + script_path = Path(__file__).parent / _RUN_SCRIPT_FILENAME + + # Args to pass into the script as environment variables + exp_name = f"{runner_type}_restore_integration_test" + callback_dump_file = tmp_path / f"{runner_type}-callback_dump_file.json" + storage_path = tmp_path / "ray_results" + if storage_path.exists(): + shutil.rmtree(storage_path) + + csv_file = str(tmp_path / "dummy_data.csv") + dummy_df = pd.DataFrame({"x": np.arange(128), "y": 2 * np.arange(128)}) + dummy_df.to_csv(csv_file) + + run_started_marker = tmp_path / "run_started_marker" + + time_per_iter_s = 0.5 + max_concurrent = 2 + + if runner_type == "tuner": + iters_per_trial = 8 + num_trials = 8 + elif runner_type == "trainer": + iters_per_trial = 32 + num_trials = 1 + + total_iters = iters_per_trial * num_trials + + env = { + "RUNNER_TYPE": runner_type, + "STORAGE_PATH": str(storage_path), + "EXP_NAME": exp_name, + "CALLBACK_DUMP_FILE": str(callback_dump_file), + "RUN_STARTED_MARKER": str(run_started_marker), + "TIME_PER_ITER_S": str(time_per_iter_s), + "ITERATIONS_PER_TRIAL": str(iters_per_trial), + "NUM_TRIALS": str(num_trials), + "MAX_CONCURRENT_TRIALS": str(max_concurrent), + "CSV_DATA_FILE": csv_file, + } + + # Pass criteria + no_interrupts_runtime = 16.0 + passing_factor = 2 + passing_runtime = no_interrupts_runtime * passing_factor + _print_message( + "Experiment should finish with a total runtime of\n" + f"<= {passing_runtime} seconds." + ) + + # Variables used in the loop + return_code = None + total_runtime = 0 + run_iter = 0 + progress_history = [] + + poll_interval_s = 0.1 + test_start_time = time.monotonic() + + while total_runtime < passing_runtime: + run_started_marker.write_text("", encoding="utf-8") + + run = subprocess.Popen([sys.executable, script_path], env=env) + run_iter += 1 + + _print_message(f"Started run #{run_iter} w/ PID = {run.pid}") + + # Start the timer after the first trial has entered its training loop. + while run.poll() is None and run_started_marker.exists(): + time.sleep(poll_interval_s) + + # If the run already finished, then exit immediately. 
+ if run.poll() is not None: + return_code = run.poll() + break + + timeout_s = min( + np.random.uniform(4 * time_per_iter_s, 8 * time_per_iter_s), + passing_runtime - total_runtime, + ) + + _print_message( + "Training has started...\n" + f"Interrupting after {timeout_s:.2f} seconds\n" + f"Currently at {total_runtime:.2f}/{passing_runtime} seconds" + ) + + # Sleep for a random amount of time, then stop the run. + start_time = time.monotonic() + stopping_time = start_time + timeout_s + while time.monotonic() < stopping_time: + time.sleep(poll_interval_s) + total_runtime += time.monotonic() - start_time + + return_code = run.poll() + if return_code is None: + # Send "SIGINT" to stop the run + _print_message(f"Sending SIGUSR1 to run #{run_iter} w/ PID = {run.pid}") + run.send_signal(signal.SIGUSR1) + + # Make sure the process is stopped forcefully after a timeout. + _kill_process_if_needed(run) + else: + _print_message("Run has already terminated!") + break + + # Check up on the results. + results = ResultGrid(ExperimentAnalysis(str(storage_path / exp_name))) + iters = [result.metrics.get("training_iteration", 0) for result in results] + progress = sum(iters) / total_iters + progress_history.append(progress) + _print_message( + f"Number of trials = {len(results)}\n" + f"% completion = {progress} ({sum(iters)} iters / {total_iters})\n" + f"Currently at {total_runtime:.2f}/{passing_runtime} seconds" + ) + + _print_message( + f"Total number of restorations = {run_iter}\n" + f"Total runtime = {total_runtime:.2f}\n" + f"Return code = {return_code}" + ) + test_end_time = time.monotonic() + + # The script shouldn't have errored. (It should have finished by this point.) + assert return_code == 0, ( + f"The script errored with return code: {return_code}.\n" + f"Check the `{_RUN_SCRIPT_FILENAME}` script for any issues." + ) + + # Req 1: runtime + assert ( + total_runtime <= passing_runtime + ), f"Expected runtime to be <= {passing_runtime}, but ran for: {total_runtime}" + + # Req 2: training progress persisted + # Check that progress increases monotonically (we never go backwards/start from 0) + assert np.all(np.diff(progress_history) >= 0), ( + "Expected progress to increase monotonically. Instead, got:\n" + "{progress_history}" + ) + + # Req 3: searcher state + results = ResultGrid(ExperimentAnalysis(str(storage_path / exp_name))) + # Check that all trials have unique ids assigned by the searcher (if applicable) + ids = [result.config.get("id", -1) for result in results] + ids = [id for id in ids if id >= 0] + if ids: + assert sorted(ids) == list(range(1, num_trials + 1)), ( + "Expected the searcher to assign increasing id for each trial, but got:" + f"{ids}" + ) + + # Req 4: callback state + with open(callback_dump_file, "r") as f: + callback_state = json.load(f) + + trial_iters = callback_state["trial_iters"] + for iters in trial_iters.values(): + # Check that the callback has data for each trial, for all iters + # NOTE: There may be some duplicate data, due to the fact that + # the callback will be updated on every `on_trial_result` hook, + # but the trial may crash before the corresponding checkpoint gets processed. + assert sorted(set(iters)) == list( + range(1, iters_per_trial + 1) + ), f"Expected data from all iterations, but got: {iters}" + + _print_message(f"Success! 
Test took {test_end_time - test_start_time:.2f} seconds.") + + +if __name__ == "__main__": + import pytest + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 2f925b0f18306..595dd4646c6c4 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -293,19 +293,16 @@ def training_loop(self): assert trainer_state_path.exists() with open(trainer_state_path, "rb") as fp: - original_trainer = pickle.load(fp) - if type(original_trainer) is not cls: + trainer_cls, param_dict = pickle.load(fp) + if trainer_cls is not cls: warnings.warn( f"Invalid trainer type. You are attempting to restore a trainer of type" - f" {type(original_trainer)} with `{cls.__name__}.restore`, " + f" {trainer_cls} with `{cls.__name__}.restore`, " "which will most likely fail. " - f"Use `{type(original_trainer).__name__}.restore` instead." + f"Use `{trainer_cls.__name__}.restore` instead." ) - # Get the param dict used to initialize the original trainer - param_dict = original_trainer._param_dict - - original_datasets = original_trainer.datasets or {} + original_datasets = param_dict.pop("datasets", {}) if original_datasets and not datasets: raise ValueError( "The following datasets need to be provided again on restore: " @@ -617,16 +614,37 @@ def fit(self) -> Result: return result def _save(self, experiment_path: Union[str, Path]): - """Saves the trainer to a directory. - - This is used to populate a newly constructed trainer on restore. - Unless a parameter is re-specified during restoration (only a limited - set of parameters can be passed in again), the argument will be loaded - from this saved one. + """Saves the current trainer's class along with the `param_dict` of + parameters passed to this trainer's constructor. + + This is used to recreate the trainer on restore. + Unless a parameter is re-specified during restoration (only a subset + of parameters can be passed in again), that parameter will be loaded + from the saved copy. + + Ray Datasets should not be saved as part of the state. Instead, we save the + keys and replace the dataset values with dummy functions that will + raise an error if invoked. The error only serves as a guardrail for + misuse (e.g., manually unpickling and constructing the Trainer again) + and is not typically surfaced, since datasets must be re-specified + upon restoration. """ + param_dict = self._param_dict.copy() + datasets = param_dict.pop("datasets", {}) + + def raise_fn(): + raise RuntimeError + + if datasets: + param_dict["datasets"] = { + dataset_name: raise_fn for dataset_name in datasets + } + + cls_and_param_dict = (self.__class__, param_dict) + experiment_path = Path(experiment_path) with open(experiment_path / _TRAINER_PKL, "wb") as fp: - pickle.dump(self, fp) + pickle.dump(cls_and_param_dict, fp) def _extract_fields_for_tuner_param_space(self) -> Dict: """Extracts fields to be included in `Tuner.param_space`. diff --git a/python/ray/tune/callback.py b/python/ray/tune/callback.py index da071f1c7db36..20d57bac93a07 100644 --- a/python/ray/tune/callback.py +++ b/python/ray/tune/callback.py @@ -393,7 +393,7 @@ def get_state(self) -> Optional[Dict]: def set_state(self, state: Dict): """Sets the state for all callbacks contained within this list. 
- Skipps setting state for all stateless callbacks where `get_state` + Skips setting state for all stateless callbacks where `get_state` returned None.""" for i, callback in enumerate(self._callbacks): callback_state = state.get(i, None)
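
A minimal, standalone sketch of the stress-test pattern used by `test_experiment_restore`: launch the run script as a subprocess, interrupt it with a signal at a random point, relaunch it, and assert that persisted progress never moves backwards (Req 2). Everything below is an illustrative stand-in, not part of the PR: the inline child program takes the place of `_test_experiment_restore_run.py`, the file names and timings are arbitrary, and plain SIGINT is used only to keep the toy child simple.

import random
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path

# Inline stand-in for the run script: it "trains" for 32 iterations,
# persists the last finished iteration, and resumes from it on relaunch.
CHILD = """
import sys, time
from pathlib import Path
progress = Path(sys.argv[1])
start = int(progress.read_text()) if progress.exists() else 0
try:
    for i in range(start + 1, 33):
        time.sleep(0.1)
        progress.write_text(str(i))
except KeyboardInterrupt:
    pass
"""


def _iterations_done(progress_file: Path) -> int:
    return int(progress_file.read_text()) if progress_file.exists() else 0


def run_with_interrupts(progress_file: Path, total_iters: int = 32) -> int:
    """Repeatedly launch the child and interrupt it at random points,
    asserting that progress never decreases across restarts (POSIX-only,
    like the test itself)."""
    history = []
    while _iterations_done(progress_file) < total_iters:
        proc = subprocess.Popen([sys.executable, "-c", CHILD, str(progress_file)])
        time.sleep(random.uniform(0.3, 0.8))  # let a few iterations finish
        if proc.poll() is None:
            proc.send_signal(signal.SIGINT)  # stand-in for the test's interrupt
            proc.wait(timeout=10)
        done = _iterations_done(progress_file)
        assert not history or done >= history[-1], history  # monotonic progress
        history.append(done)
    return _iterations_done(progress_file)


if __name__ == "__main__":
    progress = Path(tempfile.mkdtemp()) / "progress.txt"
    assert run_with_interrupts(progress) == 32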
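
A minimal sketch of the new trainer persistence format introduced in `base_trainer.py`: `_save` now pickles a `(trainer_class, param_dict)` tuple instead of the trainer instance, with dataset values swapped for placeholder callables that raise if invoked, and the restore path unpacks that tuple and requires datasets to be passed in again. `DummyTrainer`, `save_trainer_state`, `restore_trainer_state`, and `trainer.pkl` are illustrative names for this sketch, not the Ray API.

import pickle
import tempfile
from pathlib import Path


def _placeholder_dataset():
    # Guardrail only: dataset values are never restored from the pickle.
    raise RuntimeError("Datasets must be re-specified on restore.")


def _train_loop(config):
    pass


class DummyTrainer:
    """Illustrative stand-in for a Trainer subclass."""

    def __init__(self, train_loop, datasets=None):
        self.train_loop = train_loop
        self.datasets = datasets or {}
        # Mirrors the constructor kwargs, like `BaseTrainer._param_dict`.
        self._param_dict = {"train_loop": train_loop, "datasets": self.datasets}


def save_trainer_state(trainer, experiment_path):
    """Mirror of the new `_save`: pickle (class, param_dict), not the instance."""
    param_dict = trainer._param_dict.copy()
    datasets = param_dict.pop("datasets", {})
    if datasets:
        param_dict["datasets"] = {name: _placeholder_dataset for name in datasets}
    with open(Path(experiment_path) / "trainer.pkl", "wb") as fp:
        pickle.dump((trainer.__class__, param_dict), fp)


def restore_trainer_state(experiment_path, datasets):
    """Mirror of the new restore path: unpack (class, param_dict) and replace
    the placeholder dataset entries with freshly provided ones."""
    with open(Path(experiment_path) / "trainer.pkl", "rb") as fp:
        trainer_cls, param_dict = pickle.load(fp)
    original_datasets = param_dict.pop("datasets", {})  # drop the placeholders
    if original_datasets and not datasets:
        raise ValueError("Datasets must be provided again on restore.")
    param_dict["datasets"] = datasets
    return trainer_cls(**param_dict)


if __name__ == "__main__":
    exp_dir = tempfile.mkdtemp()
    original = DummyTrainer(train_loop=_train_loop, datasets={"train": [1, 2, 3]})
    save_trainer_state(original, exp_dir)
    restored = restore_trainer_state(exp_dir, datasets={"train": [4, 5, 6]})
    assert isinstance(restored, DummyTrainer)
    assert restored.datasets == {"train": [4, 5, 6]}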