diff --git a/python/ray/tune/experiment/trial.py b/python/ray/tune/experiment/trial.py
index 468d41672824e..2787989527df7 100644
--- a/python/ray/tune/experiment/trial.py
+++ b/python/ray/tune/experiment/trial.py
@@ -451,6 +451,24 @@ def last_result(self) -> dict:
     def last_result(self, val: dict):
         self._last_result = val
 
+    def get_runner_ip(self) -> Optional[str]:
+        """Return the hostname/IP recorded for this trial's runner.
+
+        Returns ``self.location.hostname`` when it is already set. Otherwise
+        the runner is asked (blocking ``ray.get``) for its IP and PID, which
+        are stored as the trial's new location. Returns ``None`` when the
+        trial has no runner.
+        """
+        if self.location.hostname:
+            return self.location.hostname
+
+        if not self.runner:
+            return None
+
+        hostname, pid = ray.get(self.runner.get_current_ip_pid.remote())
+        self.location = _Location(hostname, pid)
+        return self.location.hostname
+
     @property
     def logdir(self):
         if not self.relative_logdir:
diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py
index 3a18bdd34b3e1..a8fb1026999e9 100644
--- a/python/ray/tune/syncer.py
+++ b/python/ray/tune/syncer.py
@@ -27,7 +27,6 @@
 )
 from ray.tune import TuneError
 from ray.tune.callback import Callback
-from ray.tune.result import NODE_IP
 from ray.tune.utils.file_transfer import sync_dir_between_nodes
 from ray.util.annotations import PublicAPI, DeveloperAPI
 from ray.widgets import Template
@@ -500,6 +499,7 @@ def __init__(self, enabled: bool = True, sync_period: float = DEFAULT_SYNC_PERIO
         self._sync_processes: Dict[str, _BackgroundProcess] = {}
         self._sync_times: Dict[str, float] = {}
         self._sync_period = sync_period
+        self._trial_ips: Dict[str, str] = {}
 
     def _get_trial_sync_process(self, trial: "Trial"):
         return self._sync_processes.setdefault(
@@ -537,10 +537,16 @@ def _sync_trial_dir(
         if not force and (not self._should_sync(trial) or sync_process.is_running):
             return False
 
-        if NODE_IP in trial.last_result:
-            source_ip = trial.last_result[NODE_IP]
-        else:
-            source_ip = ray.get(trial.runner.get_current_ip.remote())
+        source_ip = self._trial_ips.get(trial.trial_id, None)
+
+        if not source_ip:
+            source_ip = trial.get_runner_ip()
+
+            # If it still does not exist, the runner is terminated.
+            if not source_ip:
+                return False
+
+        self._trial_ips[trial.trial_id] = source_ip
 
         try:
             sync_process.wait()
@@ -571,6 +577,12 @@ def _sync_trial_dir(
                 )
         return True
 
+    def on_trial_start(
+        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
+    ):
+        # Drop any cached IP so the next sync looks up the runner's IP again.
+        self._trial_ips.pop(trial.trial_id, None)
+
     def on_trial_result(
         self,
         iteration: int,
@@ -586,6 +598,14 @@ def on_trial_complete(
     ):
         self._sync_trial_dir(trial, force=True, wait=True)
         self._remove_trial_sync_process(trial)
+        self._trial_ips.pop(trial.trial_id, None)
+
+    def on_trial_error(
+        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
+    ):
+        # No final sync is attempted here; only per-trial sync state is dropped.
+        self._remove_trial_sync_process(trial)
+        self._trial_ips.pop(trial.trial_id, None)
 
     def on_checkpoint(
         self,
@@ -622,3 +642,18 @@ def wait_for_all(self):
                 f"At least one trial failed to sync down when waiting for all "
                 f"trials to sync: \n{sync_str}"
             )
+
+    def __getstate__(self):
+        """Return the instance state without the sync and trial-IP maps."""
+        state = self.__dict__.copy()
+        for remove in ["_sync_times", "_sync_processes", "_trial_ips"]:
+            state.pop(remove, None)
+        return state
+
+    def __setstate__(self, state):
+        """Restore state and re-create the attributes ``__getstate__`` drops."""
+        self.__dict__.update(state)
+        # These maps are read by the sync methods; re-create them empty.
+        self._sync_times = {}
+        self._sync_processes = {}
+        self._trial_ips = {}
diff --git a/python/ray/tune/tests/test_syncer_callback.py b/python/ray/tune/tests/test_syncer_callback.py
index ffeb2c637f721..582171ef44f55 100644
--- a/python/ray/tune/tests/test_syncer_callback.py
+++ b/python/ray/tune/tests/test_syncer_callback.py
@@ -11,7 +11,6 @@
 from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
 from ray.tune import TuneError
 from ray.tune.logger import NoopLogger
-from ray.tune.result import NODE_IP
 from ray.tune.syncer import (
     DEFAULT_SYNC_PERIOD,
     SyncConfig,
@@ -72,11 +71,15 @@ def assert_file(exists: bool, root: str, path: str):
 class MockTrial:
     def __init__(self, trial_id: str, logdir: str):
         self.trial_id = trial_id
-        self.last_result = {NODE_IP: ray.util.get_node_ip_address()}
         self.uses_cloud_checkpointing = False
         self.sync_on_checkpoint = True
 
         self.logdir = logdir
+        self._local_ip = ray.util.get_node_ip_address()
+
+    def get_runner_ip(self):
+        """Stand-in for the trial's ``get_runner_ip``; returns the local IP."""
+        return self._local_ip
 
 
 class TestSyncerCallback(SyncerCallback):
@@ -211,6 +214,29 @@ def test_syncer_callback_sync(ray_start_2_cpus, temp_data_dirs):
     assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")
 
 
+def test_syncer_callback_sync_with_invalid_ip(ray_start_2_cpus, temp_data_dirs):
+    """Check that the sync client updates the IP correctly"""
+    tmp_source, tmp_target = temp_data_dirs
+
+    syncer_callback = TestSyncerCallback(local_logdir_override=tmp_target)
+
+    trial1 = MockTrial(trial_id="a", logdir=tmp_source)
+
+    syncer_callback._trial_ips[trial1.trial_id] = "invalid"
+    syncer_callback.on_trial_start(iteration=0, trials=[], trial=trial1)
+
+    syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
+    syncer_callback.wait_for_all()
+
+    assert_file(True, tmp_target, "level0.txt")
+    assert_file(True, tmp_target, "level0_exclude.txt")
+    assert_file(True, tmp_target, "subdir/level1.txt")
+    assert_file(True, tmp_target, "subdir/level1_exclude.txt")
+    assert_file(True, tmp_target, "subdir/nested/level2.txt")
+    assert_file(True, tmp_target, "subdir_nested_level2_exclude.txt")
+    assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")
+
+
 def test_syncer_callback_no_size_limit(temp_data_dirs):
     """Check if max_size_bytes is set to None for sync function"""
     tmp_source, _ = temp_data_dirs
diff --git a/python/ray/tune/trainable/trainable.py b/python/ray/tune/trainable/trainable.py
index aedf56ba4acb3..4f03dc0db6e97 100644
--- a/python/ray/tune/trainable/trainable.py
+++ b/python/ray/tune/trainable/trainable.py
@@ -156,7 +156,7 @@ def __init__(
         self._stderr_file = stderr_file
 
         start_time = time.time()
-        self._local_ip = self.get_current_ip()
+        self._local_ip = ray.util.get_node_ip_address()
         self.setup(copy.deepcopy(self.config))
         setup_time = time.time() - start_time
         if setup_time > SETUP_TIME_THRESHOLD:
@@ -219,9 +219,9 @@ def resource_help(cls, config: Dict):
         """
         return ""
 
-    def get_current_ip(self):
-        self._local_ip = ray.util.get_node_ip_address()
-        return self._local_ip
+    def get_current_ip_pid(self):
+        """Return the node IP cached at construction and this process's PID."""
+        return self._local_ip, os.getpid()
 
     def get_auto_filled_metrics(
         self,
@@ -689,7 +689,7 @@ def restore(
 
         self._restored = True
         logger.info(
-            "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_dir
+            "Restored on %s from checkpoint: %s", self._local_ip, checkpoint_dir
         )
         state = {
             "_iteration": self._iteration,