[rfc] [air/tune/train] Improve trial/training failure error printing #27946

Status: Merged (18 commits, Aug 26, 2022)
13 changes: 13 additions & 0 deletions doc/source/ray-air/user-guides.rst
@@ -90,3 +90,16 @@ AIR User Guides
:type: ref
:text: How to Deploy AIR
:classes: btn-link btn-block stretched-link

.. _air-env-vars:

Environment variables
---------------------

Some behavior of Ray AIR can be controlled using environment variables.

Please also see the :ref:`Ray Tune environment variables <tune-env-vars>`.

- **RAY_AIR_FULL_TRACEBACKS**: If set to 1, will print full tracebacks for training functions,
including internal code paths. Otherwise, abbreviated tracebacks that only show user code
are printed. Defaults to 0 (disabled).
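
As a quick illustration of the new flag (the failing training function and Tuner setup below are assumed purely for demonstration and are not part of this diff), setting the variable before running AIR code switches the printed error from the abbreviated form to the full traceback. Whether setting it only on the driver is sufficient depends on where the traceback is shortened; treating it as sufficient here is an assumption.

import os

# Opt into full tracebacks, including Ray AIR internal frames.
# Leaving this unset (or "0") keeps the abbreviated, user-code-only output.
os.environ["RAY_AIR_FULL_TRACEBACKS"] = "1"

from ray import tune

def failing_train_fn(config):
    raise RuntimeError("boom")  # hypothetical user error

# The error printed for the failed trial now includes internal code paths.
tune.Tuner(failing_train_fn).fit()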
2 changes: 2 additions & 0 deletions doc/source/tune/api_docs/env.rst
@@ -57,6 +57,8 @@ These are the environment variables Ray Tune currently considers:
In normal circumstances these shouldn't differ anyway, but reconciliation makes sure to capture cases when
placement groups are manually destroyed. Reconciliation doesn't take much time, but it can add up when
running a large number of short trials. Defaults to every ``5`` (seconds).
* **TUNE_PRINT_ALL_TRIAL_ERRORS**: If ``1``, will print all trial errors as they come up. Otherwise, errors
will only be saved as text files to the trial directory and not printed. Defaults to ``1``.
* **TUNE_RESULT_DIR**: Directory where Ray Tune trial results are stored. If this
is not set, ``~/ray_results`` will be used.
* **TUNE_RESULT_BUFFER_LENGTH**: Ray Tune can buffer results from trainables before they are passed
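
A minimal sketch of the TUNE_PRINT_ALL_TRIAL_ERRORS toggle documented above; the failing trainable and the way errors are inspected afterwards are illustrative assumptions, not part of this change:

import os

# Silence per-trial error printing. Errors are still written to text
# files in each trial's directory; the default ("1") also prints them.
os.environ["TUNE_PRINT_ALL_TRIAL_ERRORS"] = "0"

from ray import tune

def failing_trainable(config):
    raise ValueError("failing trial")  # hypothetical failing trainable

results = tune.Tuner(failing_trainable).fit()

# Inspect stored errors programmatically instead of reading console output.
for result in results:
    if result.error:
        print(type(result.error).__name__, result.error)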
8 changes: 8 additions & 0 deletions python/ray/air/BUILD
@@ -251,6 +251,14 @@ py_test(
deps = [":ml_lib"]
)

py_test(
name = "test_tracebacks",
size = "small",
srcs = ["tests/test_tracebacks.py"],
tags = ["team:ml", "exclusive"],
deps = [":ml_lib"]
)

# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
27 changes: 27 additions & 0 deletions python/ray/air/_internal/util.py
@@ -1,5 +1,7 @@
import os
import socket
from contextlib import closing
from typing import Optional

import numpy as np

@@ -17,3 +19,28 @@ def is_nan(value):

def is_nan_or_inf(value):
return is_nan(value) or np.isinf(value)


class StartTraceback(Exception):
"""These exceptions (and their tracebacks) can be skipped with `skip_exceptions`"""

pass


def skip_exceptions(exc: Optional[Exception]) -> Exception:
"""Skip all contained `StartTracebacks` to reduce traceback output"""
should_not_shorten = bool(int(os.environ.get("RAY_AIR_FULL_TRACEBACKS", "0")))

if should_not_shorten:
return exc

if isinstance(exc, StartTraceback):
# If this is a StartTraceback, unwrap it and continue with its cause
return skip_exceptions(exc.__cause__)

# Else, make sure nested exceptions are properly skipped
cause = getattr(exc, "__cause__", None)
if cause:
exc.__cause__ = skip_exceptions(cause)

return exc
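
For reviewers, a condensed sketch of the intended usage pattern; the function names here are made up for illustration, but the wrap/unwrap mechanics match the helper above: internal entry points re-raise user failures as StartTraceback, and error reporting later calls skip_exceptions to strip the internal frames from the chain.

from ray.air._internal.util import StartTraceback, skip_exceptions

def user_train_fn():
    raise RuntimeError("user bug")  # stands in for a failing training function

def library_entrypoint():
    # Frames above this point are considered internal noise.
    try:
        user_train_fn()
    except Exception as e:
        raise StartTraceback from e

try:
    library_entrypoint()
except Exception as e:
    short = skip_exceptions(e)
    # The StartTraceback wrapper is stripped; `short` is the original
    # RuntimeError with only the user-relevant part of the chain.
    assert isinstance(short, RuntimeError)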
72 changes: 72 additions & 0 deletions python/ray/air/tests/test_tracebacks.py
@@ -0,0 +1,72 @@
import pytest

import ray
from ray.air import ScalingConfig
from ray.air._internal.util import StartTraceback, skip_exceptions
from ray.train.data_parallel_trainer import DataParallelTrainer

from ray.tune import Tuner


@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()


def _failing_recursive(levels: int = 0, start_traceback: int = -1):
if levels > 0:
if start_traceback == 0:
try:
_failing_recursive(
levels=levels - 1, start_traceback=start_traceback - 1
)
except Exception as e:
raise StartTraceback from e
else:
_failing_recursive(levels=levels - 1, start_traceback=start_traceback - 1)
else:
raise RuntimeError("Failing")


@pytest.mark.parametrize("levels", [4, 5, 6, 7, 8, 9, 10])
def test_short_traceback(levels):
start_traceback = 3
with pytest.raises(StartTraceback) as exc_info:
_failing_recursive(levels=levels, start_traceback=start_traceback)

exc = skip_exceptions(exc_info.value)
tb = exc.__traceback__
i = 0
while tb:
i += 1
tb = tb.tb_next

assert i == levels - start_traceback + 1


def test_traceback_tuner(ray_start_2_cpus):
def failing(config):
raise RuntimeError("Error")

tuner = Tuner(failing)
results = tuner.fit()
assert len(str(results[0].error).split("\n")) <= 10


def test_traceback_trainer(ray_start_2_cpus):
def failing(config):
raise RuntimeError("Error")

trainer = DataParallelTrainer(failing, scaling_config=ScalingConfig(num_workers=1))
with pytest.raises(RuntimeError) as exc_info:
trainer.fit()
assert len(str(exc_info.value).split("\n")) <= 13


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", "-x", __file__]))
11 changes: 7 additions & 4 deletions python/ray/train/_internal/backend_executor.py
@@ -81,6 +81,7 @@ def __init__(
if self._max_failures < 0:
self._max_failures = float("inf")
self._num_failures = 0
self._last_failure = None
self._initialization_hook = None
self._placement_group = None

@@ -474,10 +475,11 @@ def get_with_failure_handling(self, remote_values):
Returns:
The resolved objects represented by the passed in ObjectRefs.
"""
success = check_for_failure(remote_values)
success, exception = check_for_failure(remote_values)
if success:
return ray.get(remote_values)
else:
self._last_failure = exception
self._increment_failures()
logger.warning(
"Failure identified during training. Restarting all workers and "
@@ -521,13 +523,14 @@ def _restart(self):
def _increment_failures(self):
self._num_failures += 1
if self._num_failures >= self._max_failures:
raise RuntimeError(
"Training has failed even after "
exc = RuntimeError(
"Training has failed after "
f"{self._num_failures} "
"attempts. You can change the number of max "
"failure attempts by setting the "
"`max_retries` arg in your `Trainer`."
) from None
)
raise exc.with_traceback(None) from self._last_failure

def get_worker_group(self):
return self.worker_group
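
To make the intent of the _increment_failures change above concrete, here is a standalone sketch (plain Python, no Ray specifics; the helper name is invented) of the chaining behavior it relies on: with_traceback(None) clears any previously attached traceback, while raising from the last worker failure keeps that failure visible as the cause.

def _raise_training_failed(num_failures, last_failure):
    exc = RuntimeError(
        f"Training has failed after {num_failures} attempts. "
        "You can change the number of max failure attempts by setting "
        "the `max_retries` arg in your `Trainer`."
    )
    # Chain the worker's original error instead of suppressing it with
    # `from None`, so users can see what actually went wrong.
    raise exc.with_traceback(None) from last_failure

try:
    _raise_training_failed(3, RuntimeError("worker died"))
except RuntimeError as e:
    assert isinstance(e.__cause__, RuntimeError)
    assert str(e.__cause__) == "worker died"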
39 changes: 31 additions & 8 deletions python/ray/train/_internal/utils.py
@@ -18,7 +18,7 @@
)

import ray
from ray.air._internal.util import find_free_port
from ray.air._internal.util import find_free_port, StartTraceback, skip_exceptions
from ray.actor import ActorHandle
from ray.exceptions import RayActorError
from ray.types import ObjectRef
@@ -29,13 +29,16 @@
logger = logging.getLogger(__name__)


def check_for_failure(remote_values: List[ObjectRef]) -> bool:
def check_for_failure(
remote_values: List[ObjectRef],
) -> Tuple[bool, Optional[Exception]]:
"""Check for actor failure when retrieving the remote values.

Args:
remote_values: List of object references from Ray actor methods.

Returns:
A tuple of (success, exception). success is True if evaluating all object
references succeeded, False otherwise. On failure, exception holds the
RayActorError from the first failed worker; it is None on success.
"""
unfinished = remote_values.copy()
@@ -51,12 +54,14 @@ def check_for_failure(remote_values: List[ObjectRef]) -> bool:
try:
ray.get(object_ref)
except RayActorError as exc:
logger.exception(str(exc))
failed_actor_rank = remote_values.index(object_ref)
logger.info(f"Worker {failed_actor_rank} has failed.")
return False
return False, exc
except Exception as exc:
# Other (e.g. training) errors should be directly raised
raise StartTraceback from skip_exceptions(exc)

return True
return True, None


def get_address_and_port() -> Tuple[str, int]:
Expand Down Expand Up @@ -138,7 +143,10 @@ def construct_train_func(
# Those returns are inaccessible with AIR anyway.
@functools.wraps(train_func)
def discard_return_wrapper(*args, **kwargs):
train_func(*args, **kwargs)
try:
train_func(*args, **kwargs)
except Exception as e:
raise StartTraceback from e

wrapped_train_func = discard_return_wrapper
else:
@@ -152,9 +160,24 @@ def discard_return_wrapper(*args, **kwargs):
raise ValueError(err_msg)
elif num_params == 1:
config = {} if config is None else config
return lambda: wrapped_train_func(config)

@functools.wraps(wrapped_train_func)
def train_fn():
try:
return wrapped_train_func(config)
except Exception as e:
raise StartTraceback from e

else: # num_params == 0
return wrapped_train_func

@functools.wraps(wrapped_train_func)
def train_fn():
try:
return wrapped_train_func()
except Exception as e:
raise StartTraceback from e

return train_fn


class Singleton(abc.ABCMeta):
6 changes: 5 additions & 1 deletion python/ray/train/_internal/worker_group.py
@@ -5,6 +5,7 @@

import ray
from ray.actor import ActorHandle
from ray.air._internal.util import skip_exceptions
from ray.types import ObjectRef
from ray.util.placement_group import PlacementGroup

@@ -23,7 +24,10 @@ def __execute(self, func: Callable[..., T], *args, **kwargs) -> T:
func: The function to execute.
args, kwargs: The arguments to pass into func.
"""
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except Exception as e:
raise skip_exceptions(e) from None


@dataclass
13 changes: 9 additions & 4 deletions python/ray/train/tests/test_backend.py
@@ -6,6 +6,7 @@

import ray
import ray.train as train
from ray.air._internal.util import StartTraceback
from ray.cluster_utils import Cluster

# Trigger pytest hook to automatically zip test cluster logs to archive dir on failure
@@ -171,19 +172,23 @@ def test_train_failure(ray_start_2_cpus):
e = BackendExecutor(config, num_workers=2)
e.start()

with pytest.raises(TrainBackendError):
with pytest.raises(StartTraceback) as exc:
e.get_next_results()
assert isinstance(exc.value.__cause__, TrainBackendError)

with pytest.raises(TrainBackendError):
with pytest.raises(StartTraceback) as exc:
e.pause_reporting()
assert isinstance(exc.value.__cause__, TrainBackendError)

with pytest.raises(TrainBackendError):
with pytest.raises(StartTraceback) as exc:
e.finish_training()
assert isinstance(exc.value.__cause__, TrainBackendError)

e.start_training(lambda: 1, dataset_spec=EMPTY_RAY_DATASET_SPEC)

with pytest.raises(TrainBackendError):
with pytest.raises(StartTraceback) as exc:
e.start_training(lambda: 2, dataset_spec=EMPTY_RAY_DATASET_SPEC)
assert isinstance(exc.value.__cause__, TrainBackendError)

assert e.finish_training() == [1, 1]
