[tune] Fix trial checkpoint syncing after recovery from other node #28470

Merged (4 commits) on Sep 14, 2022

11 changes: 11 additions & 0 deletions python/ray/tune/experiment/trial.py
@@ -451,6 +451,17 @@ def last_result(self) -> dict:
def last_result(self, val: dict):
self._last_result = val

def get_runner_ip(self) -> Optional[str]:
if self.location.hostname:
return self.location.hostname

if not self.runner:
return None

hostname, pid = ray.get(self.runner.get_current_ip_pid.remote())
self.location = _Location(hostname, pid)
return self.location.hostname

@property
def logdir(self):
if not self.relative_logdir:
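
For context, the new `Trial.get_runner_ip` helper resolves the runner's node lazily: it prefers the cached location, returns `None` when no runner is attached, and otherwise queries the remote trainable once and caches the answer. Below is a minimal standalone sketch of that lookup-and-cache pattern; `Location`, `FakeRunner`, and `TrialStub` are illustrative stand-ins rather than Ray Tune's actual classes, and no Ray cluster is needed to run it.

```python
from typing import Optional, Tuple


class Location:
    """Stand-in for Ray Tune's internal location record (hostname, pid)."""

    def __init__(self, hostname: Optional[str] = None, pid: Optional[int] = None):
        self.hostname = hostname
        self.pid = pid


class FakeRunner:
    """Stand-in for the remote trainable actor; reports (hostname, pid)."""

    def get_current_ip_pid(self) -> Tuple[str, int]:
        return "10.0.0.2", 4242


class TrialStub:
    def __init__(self, runner: Optional[FakeRunner] = None):
        self.location = Location()
        self.runner = runner

    def get_runner_ip(self) -> Optional[str]:
        # 1) Use the cached location if the hostname is already known.
        if self.location.hostname:
            return self.location.hostname
        # 2) Without a live runner there is nothing to ask.
        if not self.runner:
            return None
        # 3) Ask the runner once and cache the answer for later calls.
        hostname, pid = self.runner.get_current_ip_pid()
        self.location = Location(hostname, pid)
        return self.location.hostname


print(TrialStub().get_runner_ip())              # None: no runner attached
print(TrialStub(FakeRunner()).get_runner_ip())  # "10.0.0.2": fetched and cached
```
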
34 changes: 29 additions & 5 deletions python/ray/tune/syncer.py
@@ -27,7 +26,6 @@
)
from ray.tune import TuneError
from ray.tune.callback import Callback
-from ray.tune.result import NODE_IP
from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.widgets import Template
@@ -500,6 +499,7 @@ def __init__(self, enabled: bool = True, sync_period: float = DEFAULT_SYNC_PERIO
self._sync_processes: Dict[str, _BackgroundProcess] = {}
self._sync_times: Dict[str, float] = {}
self._sync_period = sync_period
self._trial_ips = {}

def _get_trial_sync_process(self, trial: "Trial"):
return self._sync_processes.setdefault(
@@ -537,10 +537,16 @@ def _sync_trial_dir(
if not force and (not self._should_sync(trial) or sync_process.is_running):
return False

-        if NODE_IP in trial.last_result:
-            source_ip = trial.last_result[NODE_IP]
-        else:
-            source_ip = ray.get(trial.runner.get_current_ip.remote())
source_ip = self._trial_ips.get(trial.trial_id, None)

if not source_ip:
source_ip = trial.get_runner_ip()

# If it still does not exist, the runner is terminated.
if not source_ip:
return False

self._trial_ips[trial.trial_id] = source_ip

try:
sync_process.wait()
@@ -571,6 +577,11 @@
)
return True

def on_trial_start(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
self._trial_ips.pop(trial.trial_id, None)

def on_trial_result(
self,
iteration: int,
@@ -586,6 +597,13 @@ def on_trial_complete(
):
self._sync_trial_dir(trial, force=True, wait=True)
self._remove_trial_sync_process(trial)
self._trial_ips.pop(trial.trial_id, None)

def on_trial_error(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
self._remove_trial_sync_process(trial)
self._trial_ips.pop(trial.trial_id, None)

def on_checkpoint(
self,
@@ -622,3 +640,9 @@ def wait_for_all(self):
f"At least one trial failed to sync down when waiting for all "
f"trials to sync: \n{sync_str}"
)

def __getstate__(self):
state = self.__dict__.copy()
for remove in ["_sync_times", "_sync_processes", "_trial_ips"]:
state.pop(remove, None)
return state
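
The new `__getstate__` keeps transient runtime state (background sync processes, sync timestamps, and the cached trial IPs) out of anything that pickles the callback, so a restored experiment cannot start from stale node addresses. Below is a minimal sketch of that pattern under stated assumptions: the class is a toy, and the `__setstate__` that re-creates the dropped fields as empty dicts is an illustrative choice, not necessarily how Ray Tune restores the callback.

```python
import pickle


class CheckpointableCallback:
    """Toy callback showing the __getstate__ pattern: transient,
    possibly non-picklable state is dropped before serialization."""

    # Fields that only make sense within a single running session.
    _TRANSIENT = ("_sync_times", "_sync_processes", "_trial_ips")

    def __init__(self, sync_period: float = 300.0):
        self._sync_period = sync_period   # persistent configuration
        self._sync_processes = {}         # e.g. live background processes
        self._sync_times = {}             # last-sync timestamps
        self._trial_ips = {}              # cached trial_id -> node IP

    def __getstate__(self):
        state = self.__dict__.copy()
        for key in self._TRANSIENT:
            state.pop(key, None)          # never persist stale runtime state
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        for key in self._TRANSIENT:
            self.__dict__.setdefault(key, {})  # start fresh after restore


cb = CheckpointableCallback(sync_period=60.0)
cb._trial_ips["trial_a"] = "10.0.0.2"          # would be stale after a node failure

restored = pickle.loads(pickle.dumps(cb))
print(restored._sync_period)                   # 60.0 (kept)
print(restored._trial_ips)                     # {}   (dropped, re-resolved lazily)
```
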
29 changes: 27 additions & 2 deletions python/ray/tune/tests/test_syncer_callback.py
@@ -11,7 +11,6 @@
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.tune import TuneError
from ray.tune.logger import NoopLogger
-from ray.tune.result import NODE_IP
from ray.tune.syncer import (
DEFAULT_SYNC_PERIOD,
SyncConfig,
@@ -72,11 +71,14 @@ def assert_file(exists: bool, root: str, path: str):
class MockTrial:
def __init__(self, trial_id: str, logdir: str):
self.trial_id = trial_id
-        self.last_result = {NODE_IP: ray.util.get_node_ip_address()}
self.uses_cloud_checkpointing = False
self.sync_on_checkpoint = True

self.logdir = logdir
self._local_ip = ray.util.get_node_ip_address()

def get_runner_ip(self):
return self._local_ip


class TestSyncerCallback(SyncerCallback):
@@ -211,6 +213,29 @@ def test_syncer_callback_sync(ray_start_2_cpus, temp_data_dirs):
assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")


def test_syncer_callback_sync_with_invalid_ip(ray_start_2_cpus, temp_data_dirs):
"""Check that the sync client updates the IP correctly"""
tmp_source, tmp_target = temp_data_dirs

syncer_callback = TestSyncerCallback(local_logdir_override=tmp_target)

trial1 = MockTrial(trial_id="a", logdir=tmp_source)

syncer_callback._trial_ips[trial1.trial_id] = "invalid"
syncer_callback.on_trial_start(iteration=0, trials=[], trial=trial1)

syncer_callback.on_trial_result(iteration=1, trials=[], trial=trial1, result={})
Contributor: I am a bit confused as to why we would encounter this. If a result is reported, shouldn't there always be a NODE_IP field, which is automatically filled in by us?

Contributor (Author): The problem is that last_result can be populated from a restored checkpoint and therefore point to an IP address that no longer exists. (A toy reconstruction of this scenario follows this file's diff.)

syncer_callback.wait_for_all()

assert_file(True, tmp_target, "level0.txt")
assert_file(True, tmp_target, "level0_exclude.txt")
assert_file(True, tmp_target, "subdir/level1.txt")
assert_file(True, tmp_target, "subdir/level1_exclude.txt")
assert_file(True, tmp_target, "subdir/nested/level2.txt")
assert_file(True, tmp_target, "subdir_nested_level2_exclude.txt")
assert_file(True, tmp_target, "subdir_exclude/something/somewhere.txt")


def test_syncer_callback_no_size_limit(temp_data_dirs):
"""Check if max_size_bytes is set to None for sync function"""
tmp_source, _ = temp_data_dirs
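
To illustrate the review exchange embedded above: a trial restored from a checkpoint can carry a node IP in its persisted last_result that belongs to a node that no longer exists, so the callback now resolves the sync source from its own cache or the live runner instead of the stored result. The toy reconstruction below is only a sketch; `RestoredTrial` and `ToySyncer` are hypothetical stand-ins, and the `"node_ip"` key merely mimics the persisted result.

```python
class RestoredTrial:
    """Toy trial whose last_result was restored from an old checkpoint."""

    def __init__(self, trial_id: str, stale_ip: str, live_ip: str):
        self.trial_id = trial_id
        self.last_result = {"node_ip": stale_ip}   # persisted on the old node
        self._live_ip = live_ip                    # where the runner actually is now

    def get_runner_ip(self) -> str:
        return self._live_ip


class ToySyncer:
    """Resolves the sync source the way the fixed callback does:
    cache first, then the live runner, never the restored last_result."""

    def __init__(self):
        self._trial_ips = {}

    def resolve_source_ip(self, trial: RestoredTrial) -> str:
        source_ip = self._trial_ips.get(trial.trial_id)
        if not source_ip:
            source_ip = trial.get_runner_ip()
        self._trial_ips[trial.trial_id] = source_ip
        return source_ip


trial = RestoredTrial("a", stale_ip="10.0.0.99", live_ip="10.0.0.2")
syncer = ToySyncer()
print(trial.last_result["node_ip"])      # 10.0.0.99, the dead node from the checkpoint
print(syncer.resolve_source_ip(trial))   # 10.0.0.2, where the trial actually runs now
```
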
9 changes: 4 additions & 5 deletions python/ray/tune/trainable/trainable.py
@@ -156,7 +156,7 @@ def __init__(
self._stderr_file = stderr_file

start_time = time.time()
-        self._local_ip = self.get_current_ip()
self._local_ip = ray.util.get_node_ip_address()
self.setup(copy.deepcopy(self.config))
setup_time = time.time() - start_time
if setup_time > SETUP_TIME_THRESHOLD:
@@ -219,9 +219,8 @@ def resource_help(cls, config: Dict):
"""
return ""

-    def get_current_ip(self):
-        self._local_ip = ray.util.get_node_ip_address()
-        return self._local_ip
def get_current_ip_pid(self):
return self._local_ip, os.getpid()

def get_auto_filled_metrics(
self,
@@ -689,7 +688,7 @@ def restore(
self._restored = True

logger.info(
"Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_dir
"Restored on %s from checkpoint: %s", self._local_ip, checkpoint_dir
)
state = {
"_iteration": self._iteration,
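
Finally, the renamed `get_current_ip_pid` is what `Trial.get_runner_ip` calls over Ray to locate the runner. The snippet below sketches that driver-side call pattern with a toy actor; it assumes a local `ray.init()` and is not Tune's actual `Trainable`.

```python
import os

import ray


@ray.remote
class ToyTrainable:
    """Toy actor mirroring the renamed Trainable.get_current_ip_pid API."""

    def __init__(self):
        # Cache the node IP once at construction, as the diff does in __init__.
        self._local_ip = ray.util.get_node_ip_address()

    def get_current_ip_pid(self):
        return self._local_ip, os.getpid()


if __name__ == "__main__":
    ray.init()
    runner = ToyTrainable.remote()
    # This is the call pattern Trial.get_runner_ip uses on the driver side.
    hostname, pid = ray.get(runner.get_current_ip_pid.remote())
    print(f"runner lives on {hostname} (pid {pid})")
    ray.shutdown()
```
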