Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Immediately fail if application errors on any worker #28314

Merged
merged 9 commits into from
Sep 21, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
immediate error
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
amogkam committed Sep 6, 2022
commit 7b01da060ec1e4b14f2e45ff3775d95aac2e8a4f
45 changes: 45 additions & 0 deletions python/ray/air/_internal/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import os
import socket
from contextlib import closing
import logging
import queue
import threading
from typing import Optional

import numpy as np

from ray.air.constants import _ERROR_REPORT_TIMEOUT

logger = logging.getLogger(__name__)


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
@@ -44,3 +51,41 @@ def skip_exceptions(exc: Optional[Exception]) -> Exception:
exc.__cause__ = skip_exceptions(cause)

return exc


class RunnerThread(threading.Thread):
    """Supervisor thread that runs your script.

    The target's return value is captured and handed back from ``join``.
    Application-level exceptions raised by the target are forwarded to
    ``error_queue`` so the main thread can re-raise them.
    """

    def __init__(self, *args, error_queue, **kwargs):
        # Queue shared with the main thread. It is expected to be bounded so
        # that reporting blocks until any previous error has been consumed.
        self._error_queue = error_queue
        self._ret = None
        threading.Thread.__init__(self, *args, **kwargs)

    def run(self):
        try:
            self._ret = self._target(*self._args, **self._kwargs)
        except StopIteration:
            # StopIteration is the conventional "stop quietly" signal.
            logger.debug(
                "Thread runner raised StopIteration. Interpreting it as a "
                "signal to terminate the thread without error."
            )
        except Exception as e:
            self._report(e)

    def _report(self, err):
        """Push ``err`` onto the error queue, logging if it can't be delivered."""
        try:
            # Report the error but avoid indefinite blocking, which would
            # prevent the exception from being propagated in the unlikely
            # case that something went terribly wrong.
            self._error_queue.put(err, block=True, timeout=_ERROR_REPORT_TIMEOUT)
        except queue.Full:
            logger.critical(
                "Runner Thread was unable to report error to main "
                "function runner thread. This means a previous error "
                "was not processed. This should never happen."
            )

    def join(self, timeout=None):
        # Synchronize via the base class, then surface the captured result.
        threading.Thread.join(self, timeout)
        return self._ret
3 changes: 3 additions & 0 deletions python/ray/air/constants.py
Original file line number Diff line number Diff line change
@@ -22,3 +22,6 @@
# The maximum length of strings returned by `__repr__` for AIR objects constructed with
# default values.
MAX_REPR_LENGTH = int(80 * 1.5)

# Timeout used when putting exceptions raised by runner thread into the queue.
_ERROR_REPORT_TIMEOUT = 10
39 changes: 36 additions & 3 deletions python/ray/train/_internal/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import logging
import platform
import queue
import threading
@@ -9,17 +10,18 @@
from typing import Callable, Dict, Optional, Type, Union

import ray
from ray.air._internal.util import StartTraceback, skip_exceptions, RunnerThread
from ray.air.checkpoint import Checkpoint
from ray.data import Dataset, DatasetPipeline
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.utils import PropagatingThread
from ray.train.constants import (
DATE,
DETAILED_AUTOFILLED_KEYS,
HOSTNAME,
NODE_IP,
PID,
RESULT_FETCH_TIMEOUT,
ERROR_FETCH_TIMEOUT,
TIME_THIS_ITER_S,
TIME_TOTAL_S,
TIMESTAMP,
@@ -34,6 +36,9 @@ class TrainingResultType(Enum):
CHECKPOINT = auto()


logger = logging.getLogger(__name__)


@dataclass
class TrialInfo:
"""The trial information to propagate to TrainSession."""
@@ -71,8 +76,6 @@ def __init__(

self.dataset_shard = dataset_shard

# The Thread object that is running the training function.
self.training_thread = PropagatingThread(target=training_func, daemon=True)
self.world_rank = world_rank
self.local_rank = local_rank
self.world_size = world_size
@@ -102,6 +105,16 @@ def noop(x):
# Queue for sending results across threads.
self.result_queue = queue.Queue(1)

# Queue for raising exceptions from runner thread to main thread.
# The error queue has a max size of one to prevent stacking error and force
# error reporting to block until finished.
self.error_queue = queue.Queue(1)

# The Thread object that is running the training function.
self.training_thread = RunnerThread(
target=training_func, daemon=True, error_queue=self.error_queue
)

# Autofilled metrics attributes.
self.detailed_autofilled_metrics = detailed_autofilled_metrics
self.last_report_time = time.time()
@@ -168,6 +181,19 @@ def get_next(self) -> Optional[TrainingResult]:
except queue.Empty:
pass

# check if error occurred inside the thread runner.
if result is None:
# only raise an error from the runner if all results are consumed
self._report_thread_runner_error(block=True)
else:
if not self.error_queue.empty():
logger.debug(
(
"Runner error waiting to be raised in main thread. "
"Logging all available results first."
)
)

# Release the lock to trigger training to continue.
self.continue_lock.release()

@@ -235,6 +261,13 @@ def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
result.update(auto_filled_metrics)
return result

def _report_thread_runner_error(self, block=False):
    """Re-raise any exception reported by the runner thread.

    Args:
        block: Whether to wait (up to ``ERROR_FETCH_TIMEOUT`` seconds) for
            an error to arrive instead of returning immediately when the
            error queue is empty.
    """
    try:
        thread_exc = self.error_queue.get(block=block, timeout=ERROR_FETCH_TIMEOUT)
    except queue.Empty:
        # No pending error from the runner thread; nothing to do.
        return
    # Wrap the error so internal frames are skipped when it is surfaced.
    raise StartTraceback from skip_exceptions(thread_exc)

def checkpoint(self, **kwargs):
"""Adds kwargs to the queue to be consumed by main thread.
18 changes: 0 additions & 18 deletions python/ray/train/_internal/utils.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
import os
import logging
from pathlib import Path
from threading import Thread

from typing import (
Tuple,
@@ -87,23 +86,6 @@ def construct_path(path: Path, parent_path: Path) -> Path:
return parent_path.joinpath(path).expanduser().resolve()


class PropagatingThread(Thread):
    """A Thread subclass that stores exceptions and results.

    ``join`` re-raises any exception the target raised; otherwise it
    returns the target's return value.
    """

    def run(self):
        self.exc = None  # populated only when the target fails
        try:
            self.ret = self._target(*self._args, **self._kwargs)
        except BaseException as captured:
            self.exc = captured

    def join(self, timeout=None):
        Thread.join(self, timeout)
        if not self.exc:
            return self.ret
        raise self.exc


def update_env_vars(env_vars: Dict[str, Any]):
"""Updates the environment variables on this worker process.
4 changes: 4 additions & 0 deletions python/ray/train/constants.py
Original file line number Diff line number Diff line change
@@ -35,6 +35,10 @@
# new results after signaling the training function to continue.
RESULT_FETCH_TIMEOUT = 0.2

# Time between Session.get_next checks for fetching exceptions raised by the training
# function.
ERROR_FETCH_TIMEOUT = 1

# Default filename for JSON logger
RESULT_FILE_JSON = "results.json"

21 changes: 21 additions & 0 deletions python/ray/train/tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -3,10 +3,12 @@
from unittest.mock import patch

import pytest
import time

import ray
import ray.train as train
from ray.air._internal.util import StartTraceback
from ray.air import session
from ray.cluster_utils import Cluster

# Trigger pytest hook to automatically zip test cluster logs to archive dir on failure
@@ -193,6 +195,25 @@ def test_train_failure(ray_start_2_cpus):
assert e.finish_training() == [1, 1]


def test_train_single_worker_failure(ray_start_2_cpus):
    """Tests if training fails immediately if only one worker raises an Exception."""
    backend_config = TestConfig()
    executor = BackendExecutor(backend_config, num_workers=2)
    executor.start()

    def single_worker_fail():
        # Rank 0 fails right away; the other worker would run "forever",
        # so the test only passes if the failure is surfaced immediately.
        if session.get_world_rank() == 0:
            raise ValueError
        else:
            time.sleep(1000000)

    executor.start_training(single_worker_fail, dataset_spec=EMPTY_RAY_DATASET_SPEC)

    with pytest.raises(StartTraceback) as exc:
        executor.get_next_results()
    # NOTE(review): this assert must sit outside the `pytest.raises` block,
    # otherwise it would never execute. The original application error is
    # preserved as the traceback's cause.
    assert isinstance(exc.value.__cause__, ValueError)


def test_worker_failure(ray_start_2_cpus):
config = TestConfig()
e = BackendExecutor(config, num_workers=2)
18 changes: 15 additions & 3 deletions python/ray/train/tests/test_session.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import pytest

import ray
from ray.air._internal.util import StartTraceback
from ray.train._internal.accelerator import Accelerator
from ray.train.constants import SESSION_MISUSE_LOG_ONCE_KEY
from ray.train._internal.session import (
@@ -107,9 +108,8 @@ def train_func():
init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
session = get_session()
session.start()
assert session.get_next() is None
with pytest.raises(TypeError):
session.finish()
with pytest.raises(StartTraceback):
session.get_next()
shutdown_session()


@@ -310,6 +310,18 @@ def test_set_accelerator_raises_error_outside_session():
set_accelerator(accelerator)


def test_application_error_raised():
    """An exception in the training function surfaces via ``get_next``."""

    def failing_train_func():
        raise ValueError

    init_session(
        training_func=failing_train_func, world_rank=0, local_rank=0, world_size=1
    )
    session = get_session()
    session.start()
    # The application error is wrapped in StartTraceback by the session.
    with pytest.raises(StartTraceback):
        session.get_next()
    shutdown_session()


if __name__ == "__main__":
import pytest
import sys
43 changes: 4 additions & 39 deletions python/ray/tune/trainable/function_trainable.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from numbers import Number
from typing import Any, Callable, Dict, Optional, Type, Union

from ray.air._internal.util import StartTraceback, skip_exceptions
from ray.air._internal.util import StartTraceback, skip_exceptions, RunnerThread
from ray.tune.resources import Resources
from six.moves import queue

@@ -40,7 +40,6 @@
# new results after signaling the reporter to continue
RESULT_FETCH_TIMEOUT = 0.2

ERROR_REPORT_TIMEOUT = 10
ERROR_FETCH_TIMEOUT = 1

NULL_MARKER = ".null_marker"
@@ -276,42 +275,6 @@ def trial_resources(self):
return self._trial_resources


class _RunnerThread(threading.Thread):
"""Supervisor thread that runs your script."""

def __init__(self, entrypoint, error_queue):
threading.Thread.__init__(self)
self._entrypoint = entrypoint
self._error_queue = error_queue
self.daemon = True

def run(self):
try:
self._entrypoint()
except StopIteration:
logger.debug(
(
"Thread runner raised StopIteration. Interpreting it as a "
"signal to terminate the thread without error."
)
)
except Exception as e:
logger.error("Runner Thread raised error")
try:
# report the error but avoid indefinite blocking which would
# prevent the exception from being propagated in the unlikely
# case that something went terribly wrong
self._error_queue.put(e, block=True, timeout=ERROR_REPORT_TIMEOUT)
except queue.Full:
logger.critical(
(
"Runner Thread was unable to report error to main "
"function runner thread. This means a previous error "
"was not processed. This should never happen."
)
)


@DeveloperAPI
class FunctionTrainable(Trainable):
"""Trainable that runs a user function reporting results.
@@ -370,7 +333,9 @@ def entrypoint():
raise StartTraceback from e

# the runner thread is not started until the first call to _train
self._runner = _RunnerThread(entrypoint, self._error_queue)
self._runner = RunnerThread(
target=entrypoint, error_queue=self._error_queue, daemon=True
)
# if not alive, try to start
self._status_reporter._start()
try: