[tune] Add timeout to retry_fn to catch hanging syncs #28155

Merged
13 commits merged on Sep 2, 2022
110 changes: 78 additions & 32 deletions python/ray/tune/tests/test_utils.py
@@ -1,45 +1,91 @@
import pytest

from ray.tune.search.variant_generator import format_vars
from ray.tune.utils.util import retry_fn


def test_format_vars():
    # Format brackets correctly
    assert (
        format_vars(
            {
                ("a", "b", "c"): 8.1234567,
                ("a", "b", "d"): [7, 8],
                ("a", "b", "e"): [[[3, 4]]],
            }
        )
        == "c=8.12345,d=7_8,e=3_4"
    )
    # Sorted by full keys, but only last key is reported
    assert (
        format_vars(
            {
                ("a", "c", "x"): [7, 8],
                ("a", "b", "x"): 8.1234567,
            }
        )
        == "x=8.12345,x=7_8"
    )
    # Filter out invalid chars. It's ok to have empty keys or values.
    assert (
        format_vars(
            {
                ("a c?x"): " <;%$ok ",
                ("some"): " ",
            }
        )
        == "a_c_x=ok,some="
    )


def test_retry_fn_repeat(tmpdir):
    success = tmpdir / "success"
    marker = tmpdir / "marker"

    def _fail_once():
        if marker.exists():
            success.write_text(".", encoding="utf-8")
            return
        marker.write_text(".", encoding="utf-8")
        raise RuntimeError("Failing")

    assert not success.exists()
    assert not marker.exists()

    assert retry_fn(
        fn=_fail_once,
        exception_type=RuntimeError,
        sleep_time=0,
    )

    assert success.exists()
    assert marker.exists()


def test_retry_fn_timeout(tmpdir):
    success = tmpdir / "success"
    marker = tmpdir / "marker"

    def _fail_once():
        if marker.exists():
            success.write_text(".", encoding="utf-8")
            return
        marker.write_text(".", encoding="utf-8")
        raise RuntimeError("Failing")

    assert not success.exists()
    assert not marker.exists()

    assert not retry_fn(
        fn=_fail_once, exception_type=RuntimeError, sleep_time=5, timeout=0.1
    )

    assert not success.exists()
    assert marker.exists()


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))
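The two retry tests above pin down the new return-value contract of retry_fn. As a rough usage sketch built on the signature shown in the util.py diff below (sync_to_storage is a hypothetical stand-in, not part of this PR):

from ray.tune.utils.util import retry_fn


def sync_to_storage() -> None:
    # Hypothetical stand-in for a cloud sync call that may raise or hang.
    ...


# True if one attempt succeeded; False if every attempt failed or the
# call did not finish within the timeout.
success = retry_fn(
    fn=sync_to_storage,
    exception_type=RuntimeError,
    num_retries=3,
    sleep_time=1,
    timeout=600,
)
if not success:
    print("Sync did not complete; continuing without blocking training.")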
11 changes: 11 additions & 0 deletions python/ray/tune/trainable/trainable.py
@@ -61,6 +61,14 @@
SETUP_TIME_THRESHOLD = 10


def _sync_timeout() -> Optional[float]:
    sync_timeout = float(os.environ.get("TUNE_SYNC_TIMEOUT", "600"))
    if sync_timeout == 0:
        return None

    return sync_timeout


@PublicAPI
class Trainable:
"""Abstract class for trainable models, functions, etc.
@@ -517,6 +525,7 @@ def _maybe_save_to_cloud(self, checkpoint_dir: str) -> bool:
            subprocess.CalledProcessError,
            num_retries=3,
            sleep_time=1,
            timeout=_sync_timeout(),
        )
        return True

@@ -551,6 +560,7 @@ def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
            subprocess.CalledProcessError,
            num_retries=3,
            sleep_time=1,
            timeout=_sync_timeout(),
        )

        return True
@@ -724,6 +734,7 @@ def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
            subprocess.CalledProcessError,
            num_retries=3,
            sleep_time=1,
            timeout=_sync_timeout(),
        )

        if os.path.exists(checkpoint_dir):
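For reference, a minimal sketch of how the new timeout would be configured in practice, assuming TUNE_SYNC_TIMEOUT is read via _sync_timeout() as shown above (the default is 600 seconds and a value of 0 disables the timeout):

import os

# Allow each cloud sync up to 60 seconds before it is treated as hanging;
# "0" would disable the timeout entirely (the default is 600 seconds).
os.environ["TUNE_SYNC_TIMEOUT"] = "60"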
34 changes: 24 additions & 10 deletions python/ray/tune/utils/util.py
@@ -2,11 +2,13 @@
import glob
import inspect
import logging
import multiprocessing
import os
import threading
import time
from collections import defaultdict
from datetime import datetime
from numbers import Number
from threading import Thread
from typing import Dict, List, Union, Type, Callable, Any, Optional

@@ -124,18 +126,30 @@ def stop(self):
@DeveloperAPI
def retry_fn(
    fn: Callable[[], Any],
    exception_type: Type[Exception] = Exception,
    num_retries: int = 3,
    sleep_time: int = 1,
    timeout: Optional[Number] = None,
) -> bool:
    def _retry_fn():
        for i in range(num_retries):
            try:
                fn()
            except exception_type as e:
                logger.warning(e)
                time.sleep(sleep_time)
            else:
                return

    proc = multiprocessing.Process(target=_retry_fn)
    proc.start()
    proc.join(timeout=timeout)
Member:
now that you have a thread, imagine eventually we checkpoint on the side while the training just keeps going 🤯 😄

One nit: I also think the timeout should be per retry (i.e. timeout = num_retries * timeout here). Otherwise the actual timeout depends on how many retries are set here (although, admittedly, num_retries is not configurable anyway).

Contributor Author:
Yeah, I thought about this, but I kept a global timeout because it is a) simpler/cleaner to implement, and b) we basically want to define a maximum time we are willing to block training, so I think we should be fine with this. Let me know if you prefer a per-retry timeout.

Contributor Author:
Ah, actually, as discussed: let's have a timeout per retry. Otherwise, if the first sync hangs we will not try again. Updated the PR.

    if proc.exitcode is None:
        proc.terminate()
        return False

    return proc.exitcode == 0


@ray.remote
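For illustration only, a sketch of the per-retry timeout behavior the review thread converged on, not the exact code merged in the final revision. The retry_fn_per_attempt name is hypothetical; like the diff above, it assumes a fork-based multiprocessing start method, and for simplicity it retries on any failing attempt instead of filtering by exception type:

import multiprocessing
import time
from typing import Any, Callable, Optional


def retry_fn_per_attempt(
    fn: Callable[[], Any],
    num_retries: int = 3,
    sleep_time: int = 1,
    timeout: Optional[float] = None,
) -> bool:
    # Apply the timeout to each attempt rather than to the whole retry loop,
    # so a single hanging sync does not eat the budget of the remaining retries.
    for _ in range(num_retries):
        proc = multiprocessing.Process(target=fn)
        proc.start()
        proc.join(timeout=timeout)

        if proc.is_alive():
            # This attempt hung: kill it and move on to the next retry.
            proc.terminate()
            proc.join()
        elif proc.exitcode == 0:
            return True

        time.sleep(sleep_time)

    return False

Whether the timeout applies per attempt or globally changes how long a single hanging sync can block training; per the final review comment, the PR was updated to the per-retry behavior.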