[rfc] [air/tune/train] Improve trial/training failure error printing #27946

Merged (18 commits) on Aug 26, 2022
2 changes: 2 additions & 0 deletions doc/source/tune/api_docs/env.rst
@@ -57,6 +57,8 @@ These are the environment variables Ray Tune currently considers:
In normal circumstances these shouldn't differ anyway, but reconciliation makes sure to capture cases when
placement groups are manually destroyed. Reconciliation doesn't take much time, but it can add up when
running a large number of short trials. Defaults to every ``5`` (seconds).
* **TUNE_PRINT_ALL_TRIAL_ERRORS**: If ``1``, will print all trial errors as they come up. Otherwise, errors
will only be saved as text files to the trial directory and not printed. Defaults to ``1``.
* **TUNE_RESULT_DIR**: Directory where Ray Tune trial results are stored. If this
is not set, ``~/ray_results`` will be used.
* **TUNE_RESULT_BUFFER_LENGTH**: Ray Tune can buffer results from trainables before they are passed
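For illustration, a minimal sketch (not part of this diff) of how the new flag could be toggled from user code before starting a sweep; `my_trainable` and `num_samples=100` are placeholders:

    import os

    from ray import tune

    def my_trainable(config):
        tune.report(score=1)  # trivial placeholder trainable

    # Silence per-trial error printing for a large sweep; errors are still
    # written as text files to each trial's directory.
    os.environ["TUNE_PRINT_ALL_TRIAL_ERRORS"] = "0"

    tune.run(my_trainable, num_samples=100)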
20 changes: 20 additions & 0 deletions python/ray/air/_internal/util.py
@@ -1,3 +1,4 @@
import os
import socket
from contextlib import closing

@@ -17,3 +18,22 @@ def is_nan(value):

def is_nan_or_inf(value):
return is_nan(value) or np.isinf(value)


def shorten_tb(tb, attr: str):
should_not_shorten = bool(int(os.environ.get("RAY_AIR_FULL_TRACEBACKS", "0")))

if should_not_shorten:
return tb

orig_tb = tb
while tb:
if tb.tb_frame.f_locals.get(attr):
if tb.tb_next:
# If there is another `attr` later downstream, use that instead
return shorten_tb(tb.tb_next, attr=attr)
return tb

tb = tb.tb_next

return orig_tb
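For context, a quick sketch of how the helper above is meant to be used (the other files in this PR import it via `from ray.air._internal.util import shorten_tb`): frames above the last one that defines a local `_ray_start_tb` are dropped from the traceback, unless `RAY_AIR_FULL_TRACEBACKS=1` is set. The layer names here are made up for illustration:

    import traceback

    from ray.air._internal.util import shorten_tb

    def internal_layer(fn):
        # Framework-internal frame that users should not have to see.
        return entrypoint(fn)

    def entrypoint(fn):
        _ray_start_tb = True  # noqa: F841  # marker picked up by shorten_tb
        return fn()

    def user_fn():
        raise ValueError("boom")

    try:
        internal_layer(user_fn)
    except Exception as e:
        short = shorten_tb(e.__traceback__, attr="_ray_start_tb")
        # Prints only the user_fn frame, not internal_layer/entrypoint
        # or the outer call site.
        traceback.print_exception(type(e), e, short)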
9 changes: 6 additions & 3 deletions python/ray/train/_internal/backend_executor.py
@@ -81,6 +81,7 @@ def __init__(
if self._max_failures < 0:
self._max_failures = float("inf")
self._num_failures = 0
self._last_failure = None
self._initialization_hook = None
self._placement_group = None

@@ -474,10 +475,11 @@ def get_with_failure_handling(self, remote_values):
Returns:
The resolved objects represented by the passed in ObjectRefs.
"""
success = check_for_failure(remote_values)
success, exception = check_for_failure(remote_values)
if success:
return ray.get(remote_values)
else:
self._last_failure = exception
self._increment_failures()
logger.warning(
"Failure identified during training. Restarting all workers and "
@@ -521,13 +523,14 @@ def _restart(self):
def _increment_failures(self):
self._num_failures += 1
if self._num_failures >= self._max_failures:
_ray_start_tb = True # noqa: F841
raise RuntimeError(
"Training has failed even after "
"Training has failed after "
f"{self._num_failures} "
"attempts. You can change the number of max "
"failure attempts by setting the "
"`max_retries` arg in your `Trainer`."
) from None
) from self._last_failure

def get_worker_group(self):
return self.worker_group
33 changes: 26 additions & 7 deletions python/ray/train/_internal/utils.py
@@ -18,7 +18,7 @@
)

import ray
from ray.air._internal.util import find_free_port
from ray.air._internal.util import find_free_port, shorten_tb
from ray.actor import ActorHandle
from ray.exceptions import RayActorError
from ray.types import ObjectRef
@@ -29,13 +29,16 @@
logger = logging.getLogger(__name__)


def check_for_failure(remote_values: List[ObjectRef]) -> bool:
def check_for_failure(
remote_values: List[ObjectRef],
) -> Tuple[bool, Optional[Exception]]:
"""Check for actor failure when retrieving the remote values.

Args:
remote_values: List of object references from Ray actor methods.

Returns:
A tuple of (success, exception). ``success`` is True if evaluating all
object references succeeded, False otherwise. ``exception`` is the
``RayActorError`` that caused the failure, or None.
"""
unfinished = remote_values.copy()
@@ -51,12 +54,17 @@ def check_for_failure(remote_values: List[ObjectRef]) -> bool:
try:
ray.get(object_ref)
except RayActorError as exc:
logger.exception(str(exc))
failed_actor_rank = remote_values.index(object_ref)
logger.info(f"Worker {failed_actor_rank} has failed.")
return False
return False, exc
except Exception as exc:
# Other (e.g. training) errors should be directly raised
_ray_start_tb = True # noqa: F841
raise exc.with_traceback(
shorten_tb(exc.__traceback__, attr="_ray_start_tb")
)
Review comment from krfricke (author):

cc @amogkam if we just catch RayTaskErrors in addition to RayActorErrors this will lead to test failures as training is restarted and a different exception is raised.

What is the expected behavior here? IMO it looks like task errors should fail immediately (as it's likely a logic/syntax error) and only actor failures should be retried. If that's the case (as in the current implementation) maybe we can add better comments for this. Lmk


return True
return True, None


def get_address_and_port() -> Tuple[str, int]:
@@ -152,9 +160,20 @@ def discard_return_wrapper(*args, **kwargs):
raise ValueError(err_msg)
elif num_params == 1:
config = {} if config is None else config
return lambda: wrapped_train_func(config)

@functools.wraps(wrapped_train_func)
def train_fn():
_ray_start_tb = True # noqa: F841
return wrapped_train_func(config)

else: # num_params == 0
return wrapped_train_func

@functools.wraps(wrapped_train_func)
def train_fn():
_ray_start_tb = True # noqa: F841
return wrapped_train_func()

return train_fn


class Singleton(abc.ABCMeta):
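For reference, a hypothetical helper (paraphrasing the control flow of `get_with_failure_handling()` in backend_executor.py above) showing how callers are expected to consume the new `(success, exception)` tuple:

    import ray

    from ray.train._internal.utils import check_for_failure

    def resolve_or_last_failure(remote_values):
        # Resolve the results if everything succeeded; otherwise hand back the
        # RayActorError so the caller can store it and later re-raise it as the
        # __cause__ of the "Training has failed" RuntimeError.
        success, exception = check_for_failure(remote_values)
        if success:
            return ray.get(remote_values), None
        # Non-actor errors (e.g. exceptions raised by the training function)
        # never reach this branch: check_for_failure re-raises them directly,
        # with the traceback shortened at the `_ray_start_tb` marker.
        return None, exception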
6 changes: 5 additions & 1 deletion python/ray/train/_internal/worker_group.py
@@ -5,6 +5,7 @@

import ray
from ray.actor import ActorHandle
from ray.air._internal.util import shorten_tb
from ray.types import ObjectRef
from ray.util.placement_group import PlacementGroup

@@ -23,7 +24,10 @@ def __execute(self, func: Callable[..., T], *args, **kwargs) -> T:
func: The function to execute.
args, kwargs: The arguments to pass into func.
"""
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except Exception as e:
raise e.with_traceback(shorten_tb(e.__traceback__, attr="_ray_start_tb"))


@dataclass
8 changes: 6 additions & 2 deletions python/ray/tune/execution/trial_runner.py
@@ -364,6 +364,10 @@ def __init__(
"fail_fast must be one of {bool, RAISE}. " f"Got {self._fail_fast}."
)

self._print_trial_errors = bool(
int(os.environ.get("TUNE_PRINT_ALL_TRIAL_ERRORS", "1"))
)

self._server = None
self._server_port = server_port
if server_port is not None:
@@ -974,10 +978,10 @@ def _post_process_on_training_saving_result(self, trial):
def _on_executor_error(self, trial, e: Union[RayTaskError, TuneError]):
error_msg = f"Trial {trial}: Error processing event."
if self._fail_fast == TrialRunner.RAISE:
logger.error(error_msg, exc_info=e)
raise e
else:
logger.exception(error_msg, exc_info=e)
if self._print_trial_errors:
logger.error(error_msg, exc_info=e)
self._process_trial_failure(trial, exc=e)

def get_trial(self, tid):
8 changes: 5 additions & 3 deletions python/ray/tune/trainable/function_trainable.py
@@ -11,6 +11,7 @@
from numbers import Number
from typing import Any, Callable, Dict, Optional, Type, Union

from ray.air._internal.util import shorten_tb
from ray.tune.resources import Resources
from six.moves import queue

@@ -290,12 +291,12 @@ def run(self):
except StopIteration:
logger.debug(
(
"Thread runner raised StopIteration. Interperting it as a "
"Thread runner raised StopIteration. Interpreting it as a "
"signal to terminate the thread without error."
)
)
except Exception as e:
logger.exception("Runner Thread raised error.")
logger.error("Runner Thread raised error")
try:
# report the error but avoid indefinite blocking which would
# prevent the exception from being propagated in the unlikely
@@ -359,6 +360,7 @@ def _trainable_func(self, config, reporter, checkpoint_dir):

def _start(self):
def entrypoint():
_ray_start_tb = True # noqa: F841
return self._trainable_func(
self.config,
self._status_reporter,
@@ -586,7 +588,7 @@ def reset_config(self, new_config):
def _report_thread_runner_error(self, block=False):
try:
e = self._error_queue.get(block=block, timeout=ERROR_FETCH_TIMEOUT)
raise e
raise e.with_traceback(shorten_tb(e.__traceback__, attr="_ray_start_tb"))
except queue.Empty:
pass

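For context, a standalone sketch of the runner-thread handoff used above; the timeout value and queue size are assumptions standing in for the constants in function_trainable.py. The worker thread pushes any exception onto a queue, and the main thread re-raises it with the traceback trimmed at the `_ray_start_tb` marker:

    import queue
    import threading

    from ray.air._internal.util import shorten_tb

    ERROR_FETCH_TIMEOUT = 0.2  # assumption; stands in for the module constant
    error_queue = queue.Queue(maxsize=1)

    def runner():
        try:
            _ray_start_tb = True  # noqa: F841
            raise RuntimeError("training failed")
        except Exception as e:
            # Report the error but avoid blocking indefinitely.
            error_queue.put(e, block=True, timeout=ERROR_FETCH_TIMEOUT)

    thread = threading.Thread(target=runner)
    thread.start()
    thread.join()

    try:
        e = error_queue.get(block=False)
        # Re-raise in the main thread with the shortened traceback.
        raise e.with_traceback(shorten_tb(e.__traceback__, attr="_ray_start_tb"))
    except queue.Empty:
        pass  # the runner thread finished without reporting an error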