diff --git a/python/ray/tune/execution/ray_trial_executor.py b/python/ray/tune/execution/ray_trial_executor.py
index fe6f886e304ad..a0eec9aeee7a0 100644
--- a/python/ray/tune/execution/ray_trial_executor.py
+++ b/python/ray/tune/execution/ray_trial_executor.py
@@ -217,7 +217,7 @@ def __init__(
         self._has_cleaned_up_pgs = False
         self._reuse_actors = reuse_actors

-        # The maxlen will be updated when `set_max_pending_trials()` is called
+        # The maxlen will be updated when `setup(max_pending_trials)` is called
         self._cached_actor_pg = deque(maxlen=1)
         self._pg_manager = _PlacementGroupManager(prefix=_get_tune_pg_prefix())
         self._staged_trials = set()
@@ -235,16 +235,20 @@ def __init__(
         self._buffer_max_time_s = float(
             os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0)
         )
+        self._trainable_kwargs = {}

-    def set_max_pending_trials(self, max_pending: int) -> None:
+    def setup(
+        self, max_pending_trials: int, trainable_kwargs: Optional[Dict] = None
+    ) -> None:
         if len(self._cached_actor_pg) > 0:
             logger.warning(
                 "Cannot update maximum number of queued actors for reuse "
                 "during a run."
             )
         else:
-            self._cached_actor_pg = deque(maxlen=max_pending)
-            self._pg_manager.set_max_staging(max_pending)
+            self._cached_actor_pg = deque(maxlen=max_pending_trials)
+            self._pg_manager.set_max_staging(max_pending_trials)
+        self._trainable_kwargs = trainable_kwargs or {}

     def set_status(self, trial: Trial, status: str) -> None:
         """Sets status and checkpoints metadata if needed.
@@ -377,6 +381,9 @@ def _setup_remote_runner(self, trial):
         kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
         kwargs["custom_syncer"] = trial.custom_syncer

+        if self._trainable_kwargs:
+            kwargs.update(self._trainable_kwargs)
+
         # Throw a meaningful error if trainable does not use the
         # new API
         sig = inspect.signature(trial.get_trainable_cls())
diff --git a/python/ray/tune/execution/trial_runner.py b/python/ray/tune/execution/trial_runner.py
index d810729f747d8..a4fd37b00208d 100644
--- a/python/ray/tune/execution/trial_runner.py
+++ b/python/ray/tune/execution/trial_runner.py
@@ -198,6 +198,8 @@ def _serialize_and_write():
             exclude = ["*/checkpoint_*"]

         if self._syncer:
+            # Todo: Implement sync_timeout for experiment-level syncing
+            # (it is currently only used for trainable-to-cloud syncing)
             if force:
                 # Wait until previous sync command finished
                 self._syncer.wait()
@@ -341,7 +343,13 @@ def __init__(
         else:
             # Manual override
             self._max_pending_trials = int(max_pending_trials)
-        self.trial_executor.set_max_pending_trials(self._max_pending_trials)
+
+        sync_config = sync_config or SyncConfig()
+
+        self.trial_executor.setup(
+            max_pending_trials=self._max_pending_trials,
+            trainable_kwargs={"sync_timeout": sync_config.sync_timeout},
+        )

         self._metric = metric

@@ -385,7 +393,6 @@ def __init__(

         if self._local_checkpoint_dir:
             os.makedirs(self._local_checkpoint_dir, exist_ok=True)

-        sync_config = sync_config or SyncConfig()
         self._remote_checkpoint_dir = remote_checkpoint_dir
         self._syncer = get_node_to_storage_syncer(sync_config)
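How the new executor contract is meant to be used: `TrialRunner` now resolves the `SyncConfig` and calls `trial_executor.setup()` once, and the executor later merges the stored `trainable_kwargs` into the keyword arguments of every trainable actor it creates. The following is a minimal, self-contained sketch of that forwarding pattern; `_ExecutorSketch` and `build_actor_kwargs` are illustrative stand-ins rather than Ray classes, and only the `setup(max_pending_trials, trainable_kwargs)` signature mirrors the diff above.

```python
from collections import deque
from typing import Any, Dict, Optional


class _ExecutorSketch:
    """Illustrative stand-in for the executor's new setup() contract."""

    def __init__(self) -> None:
        self._cached_actor_pg: deque = deque(maxlen=1)
        self._trainable_kwargs: Dict[str, Any] = {}

    def setup(
        self, max_pending_trials: int, trainable_kwargs: Optional[Dict] = None
    ) -> None:
        # Resize the cached actor/placement group queue and remember extra
        # kwargs that every trainable actor should receive at construction.
        self._cached_actor_pg = deque(maxlen=max_pending_trials)
        self._trainable_kwargs = trainable_kwargs or {}

    def build_actor_kwargs(self, base_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Mirrors the `kwargs.update(self._trainable_kwargs)` call added to
        # `_setup_remote_runner` above.
        kwargs = dict(base_kwargs)
        if self._trainable_kwargs:
            kwargs.update(self._trainable_kwargs)
        return kwargs


executor = _ExecutorSketch()
executor.setup(max_pending_trials=4, trainable_kwargs={"sync_timeout": 1800})
print(executor.build_actor_kwargs({"remote_checkpoint_dir": "s3://bucket/exp"}))
# {'remote_checkpoint_dir': 's3://bucket/exp', 'sync_timeout': 1800}
```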
diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py
index a5ade393834bd..3a18bdd34b3e1 100644
--- a/python/ray/tune/syncer.py
+++ b/python/ray/tune/syncer.py
@@ -40,6 +40,9 @@
 # Syncing period for syncing checkpoints between nodes or to cloud.
 DEFAULT_SYNC_PERIOD = 300

+# Default sync timeout after which syncing processes are aborted
+DEFAULT_SYNC_TIMEOUT = 1800
+
 _EXCLUDE_FROM_SYNC = [
     "./checkpoint_-00001",
     "./checkpoint_tmp*",
@@ -85,6 +88,8 @@ class SyncConfig:
             is asynchronous and best-effort.
             This does not affect persistent storage syncing. Defaults to True.
         sync_period: Syncing period for syncing between nodes.
+        sync_timeout: Timeout after which running sync processes are aborted.
+            Currently only affects trial-to-cloud syncing.

     """

@@ -93,6 +98,7 @@ class SyncConfig:

     sync_on_checkpoint: bool = True
     sync_period: int = DEFAULT_SYNC_PERIOD
+    sync_timeout: int = DEFAULT_SYNC_TIMEOUT

     def _repr_html_(self) -> str:
         """Generate an HTML representation of the SyncConfig.
diff --git a/python/ray/tune/tests/test_ray_trial_executor.py b/python/ray/tune/tests/test_ray_trial_executor.py
index 4615f605b2568..8e18c5476fbe9 100644
--- a/python/ray/tune/tests/test_ray_trial_executor.py
+++ b/python/ray/tune/tests/test_ray_trial_executor.py
@@ -499,7 +499,7 @@ def testHasResourcesForTrialWithCaching(self):

         executor = RayTrialExecutor(reuse_actors=True)
         executor._pg_manager = pgm
-        executor.set_max_pending_trials(1)
+        executor.setup(max_pending_trials=1)

         def train(config):
             yield 1
diff --git a/python/ray/tune/tests/test_trainable.py b/python/ray/tune/tests/test_trainable.py
index 248152285f6a9..254edb5c2fdb7 100644
--- a/python/ray/tune/tests/test_trainable.py
+++ b/python/ray/tune/tests/test_trainable.py
@@ -1,14 +1,16 @@
 import json
 import os
 import tempfile
+import time
 from typing import Dict, Union
+from unittest.mock import patch

 import pytest

 import ray
 from ray import tune
 from ray.air import session, Checkpoint
-from ray.air._internal.remote_storage import download_from_uri
+from ray.air._internal.remote_storage import download_from_uri, upload_to_uri
 from ray.tune.trainable import wrap_function

@@ -188,6 +190,42 @@ def test_checkpoint_object_no_sync(tmpdir):
     trainable.restore_from_object(obj)


+@pytest.mark.parametrize("hanging", [True, False])
+def test_sync_timeout(tmpdir, hanging):
+    orig_upload_fn = upload_to_uri
+
+    def _hanging_upload(*args, **kwargs):
+        time.sleep(200 if hanging else 0)
+        orig_upload_fn(*args, **kwargs)
+
+    trainable = SavingTrainable(
+        "object",
+        remote_checkpoint_dir=f"memory:///test/location_hanging_{hanging}",
+        sync_timeout=0.5,
+    )
+
+    with patch("ray.air.checkpoint.upload_to_uri", _hanging_upload):
+        trainable.save()
+
+    check_dir = tmpdir / "check_save_obj"
+
+    try:
+        download_from_uri(
+            uri=f"memory:///test/location_hanging_{hanging}", local_path=str(check_dir)
+        )
+    except FileNotFoundError:
+        hung = True
+    else:
+        hung = False
+
+    assert hung == hanging
+
+    if hanging:
+        assert not check_dir.exists()
+    else:
+        assert check_dir.listdir()
+
+
 if __name__ == "__main__":
     import sys

diff --git a/python/ray/tune/tests/test_utils.py b/python/ray/tune/tests/test_utils.py
index bca055f4c8c74..21b99c6873216 100644
--- a/python/ray/tune/tests/test_utils.py
+++ b/python/ray/tune/tests/test_utils.py
@@ -1,45 +1,94 @@
-import unittest
+import time
+
+import pytest

 from ray.tune.search.variant_generator import format_vars
+from ray.tune.utils.util import retry_fn
+

+def test_format_vars():

-class TuneUtilsTest(unittest.TestCase):
-    def testFormatVars(self):
-        # Format brackets correctly
-        self.assertTrue(
-            format_vars(
-                {
-                    ("a", "b", "c"): 8.1234567,
-                    ("a", "b", "d"): [7, 8],
-                    ("a", "b", "e"): [[[3, 4]]],
-                }
-            ),
-            "c=8.12345,d=7_8,e=3_4",
+    # Format brackets correctly
+    assert (
+        format_vars(
+            {
+                ("a", "b", "c"): 8.1234567,
+                ("a", "b", "d"): [7, 8],
+                ("a", "b", "e"): [[[3, 4]]],
+            }
         )
-        # Sorted by full keys, but only last key is reported
-        self.assertTrue(
-            format_vars(
-                {
-                    ("a", "c", "x"): [7, 8],
-                    ("a", "b", "x"): 8.1234567,
-                }
-            ),
-            "x=8.12345,x=7_8",
+        == "c=8.1235,d=7_8,e=3_4"
+    )
+    # Sorted by full keys, but only last key is reported
+    assert (
+        format_vars(
+            {
+                ("a", "c", "x"): [7, 8],
+                ("a", "b", "x"): 8.1234567,
+            }
         )
-        # Filter out invalid chars. It's ok to have empty keys or values.
-        self.assertTrue(
-            format_vars(
-                {
-                    ("a c?x"): " <;%$ok ",
-                    ("some"): " ",
-                }
-            ),
-            "a_c_x=ok,some=",
+        == "x=8.1235,x=7_8"
+    )
+    # Filter out invalid chars. It's ok to have empty keys or values.
+    assert (
+        format_vars(
+            {
+                ("a c?x",): " <;%$ok ",
+                ("some",): " ",
+            }
         )
+        == "a_c_x=ok,some="
+    )
+
+
+def test_retry_fn_repeat(tmpdir):
+    success = tmpdir / "success"
+    marker = tmpdir / "marker"
+
+    def _fail_once():
+        if marker.exists():
+            success.write_text(".", encoding="utf-8")
+            return
+        marker.write_text(".", encoding="utf-8")
+        raise RuntimeError("Failing")
+
+    assert not success.exists()
+    assert not marker.exists()
+
+    assert retry_fn(
+        fn=_fail_once,
+        exception_type=RuntimeError,
+        sleep_time=0,
+    )
+
+    assert success.exists()
+    assert marker.exists()
+
+
+def test_retry_fn_timeout(tmpdir):
+    success = tmpdir / "success"
+    marker = tmpdir / "marker"
+
+    def _fail_once():
+        if not marker.exists():
+            marker.write_text(".", encoding="utf-8")
+            raise RuntimeError("Failing")
+        time.sleep(5)
+        success.write_text(".", encoding="utf-8")
+        return
+
+    assert not success.exists()
+    assert not marker.exists()
+
+    assert not retry_fn(
+        fn=_fail_once, exception_type=RuntimeError, sleep_time=0, timeout=0.1
+    )
+
+    assert not success.exists()
+    assert marker.exists()


 if __name__ == "__main__":
-    import pytest
     import sys

     sys.exit(pytest.main(["-v", __file__]))
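On the user side, the only new surface area is the `sync_timeout` field on `SyncConfig`. A hedged usage sketch follows; the experiment name, bucket URI, and 600-second value are placeholders, and the sketch assumes the AIR `Tuner` entry point (the same `SyncConfig` can also be passed to `tune.run`).

```python
from ray import air, tune
from ray.air import session
from ray.tune.syncer import SyncConfig


def train_fn(config):
    # Placeholder training loop; real checkpointing is omitted here.
    session.report({"score": 1.0})


tuner = tune.Tuner(
    train_fn,
    run_config=air.RunConfig(
        name="sync-timeout-demo",  # placeholder experiment name
        sync_config=SyncConfig(
            upload_dir="s3://my-bucket/tune-results",  # placeholder bucket
            # Abort a single trial-to-cloud sync attempt after 10 minutes
            # instead of letting a hung upload block the trial forever.
            sync_timeout=600,
        ),
    ),
)
# tuner.fit() would forward sync_timeout to every Trainable via the executor.
```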
"x=8.12345,x=7_8", + == "c=8.1235,d=7_8,e=3_4" + ) + # Sorted by full keys, but only last key is reported + assert ( + format_vars( + { + ("a", "c", "x"): [7, 8], + ("a", "b", "x"): 8.1234567, + } ) - # Filter out invalid chars. It's ok to have empty keys or values. - self.assertTrue( - format_vars( - { - ("a c?x"): " <;%$ok ", - ("some"): " ", - } - ), - "a_c_x=ok,some=", + == "x=8.1235,x=7_8" + ) + # Filter out invalid chars. It's ok to have empty keys or values. + assert ( + format_vars( + { + ("a c?x",): " <;%$ok ", + ("some",): " ", + } ) + == "a_c_x=ok,some=" + ) + + +def test_retry_fn_repeat(tmpdir): + success = tmpdir / "success" + marker = tmpdir / "marker" + + def _fail_once(): + if marker.exists(): + success.write_text(".", encoding="utf-8") + return + marker.write_text(".", encoding="utf-8") + raise RuntimeError("Failing") + + assert not success.exists() + assert not marker.exists() + + assert retry_fn( + fn=_fail_once, + exception_type=RuntimeError, + sleep_time=0, + ) + + assert success.exists() + assert marker.exists() + + +def test_retry_fn_timeout(tmpdir): + success = tmpdir / "success" + marker = tmpdir / "marker" + + def _fail_once(): + if not marker.exists(): + marker.write_text(".", encoding="utf-8") + raise RuntimeError("Failing") + time.sleep(5) + success.write_text(".", encoding="utf-8") + return + + assert not success.exists() + assert not marker.exists() + + assert not retry_fn( + fn=_fail_once, exception_type=RuntimeError, sleep_time=0, timeout=0.1 + ) + + assert not success.exists() + assert marker.exists() if __name__ == "__main__": - import pytest import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py index a63b767c585c0..aedf56ba4acb3 100644 --- a/python/ray/tune/trainable/trainable.py +++ b/python/ray/tune/trainable/trainable.py @@ -101,8 +101,9 @@ def __init__( logger_creator: Callable[[Dict[str, Any]], "Logger"] = None, remote_checkpoint_dir: Optional[str] = None, custom_syncer: Optional[Syncer] = None, + sync_timeout: Optional[int] = None, ): - """Initialize an Trainable. + """Initialize a Trainable. Sets up logging and points ``self.logdir`` to a directory in which training outputs should be placed. @@ -120,6 +121,7 @@ def __init__( which is different from **per checkpoint** directory. custom_syncer: Syncer used for synchronizing data from Ray nodes to external storage. + sync_timeout: Timeout after which sync processes are aborted. """ self._experiment_id = uuid.uuid4().hex @@ -171,6 +173,7 @@ def __init__( self.remote_checkpoint_dir = remote_checkpoint_dir self.custom_syncer = custom_syncer + self.sync_timeout = sync_timeout @property def uses_cloud_checkpointing(self): @@ -512,12 +515,22 @@ def _maybe_save_to_cloud(self, checkpoint_dir: str) -> bool: return True checkpoint = Checkpoint.from_directory(checkpoint_dir) - retry_fn( - lambda: checkpoint.to_uri(self._storage_path(checkpoint_dir)), + checkpoint_uri = self._storage_path(checkpoint_dir) + if not retry_fn( + lambda: checkpoint.to_uri(checkpoint_uri), subprocess.CalledProcessError, num_retries=3, sleep_time=1, - ) + timeout=self.sync_timeout, + ): + logger.error( + f"Could not upload checkpoint even after 3 retries." + f"Please check if the credentials expired and that the remote " + f"filesystem is supported.. For large checkpoints, consider " + f"increasing `SyncConfig(sync_timeout)` " + f"(current value: {self.sync_timeout} seconds). 
Checkpoint URI: " + f"{checkpoint_uri}" + ) return True def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool: @@ -546,12 +559,17 @@ def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool: return True checkpoint = Checkpoint.from_uri(external_uri) - retry_fn( + if not retry_fn( lambda: checkpoint.to_directory(local_dir), subprocess.CalledProcessError, num_retries=3, sleep_time=1, - ) + timeout=self.sync_timeout, + ): + logger.error( + f"Could not download checkpoint even after 3 retries: " + f"{external_uri}" + ) return True @@ -719,12 +737,17 @@ def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]): self.custom_syncer.wait_or_retry() else: checkpoint_uri = self._storage_path(checkpoint_dir) - retry_fn( + if not retry_fn( lambda: _delete_external_checkpoint(checkpoint_uri), subprocess.CalledProcessError, num_retries=3, sleep_time=1, - ) + timeout=self.sync_timeout, + ): + logger.error( + f"Could not delete checkpoint even after 3 retries: " + f"{checkpoint_uri}" + ) if os.path.exists(checkpoint_dir): shutil.rmtree(checkpoint_dir) diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 6deaa3a57f757..66b581b339fa8 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -7,6 +7,7 @@ import time from collections import defaultdict from datetime import datetime +from numbers import Number from threading import Thread from typing import Dict, List, Union, Type, Callable, Any, Optional @@ -124,18 +125,41 @@ def stop(self): @DeveloperAPI def retry_fn( fn: Callable[[], Any], - exception_type: Type[Exception], + exception_type: Type[Exception] = Exception, num_retries: int = 3, sleep_time: int = 1, -): - for i in range(num_retries): + timeout: Optional[Number] = None, +) -> bool: + errored = threading.Event() + + def _try_fn(): try: fn() except exception_type as e: logger.warning(e) - time.sleep(sleep_time) - else: - break + errored.set() + + for i in range(num_retries): + errored.clear() + + proc = threading.Thread(target=_try_fn) + proc.daemon = True + proc.start() + proc.join(timeout=timeout) + + if proc.is_alive(): + logger.debug( + f"Process timed out (try {i+1}/{num_retries}): " + f"{getattr(fn, '__name__', None)}" + ) + elif not errored.is_set(): + return True + + # Timed out, sleep and try again + time.sleep(sleep_time) + + # Timed out, so return False + return False @ray.remote