Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune/internal] Move signal handling into separate method #31004

Merged
merged 4 commits into from
Dec 12, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 72 additions & 64 deletions python/ray/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,39 @@ def _report_progress(
reporter.report(trials, done, sched_debug_str, executor_debug_str)


def _setup_signal_catching() -> threading.Event:
original_handler = signal.getsignal(signal.SIGINT)
stop_event = threading.Event()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more specific name? tune_loop_interrupted_event or something alike?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're ok with it I'd like to keep this, as we set up general signal catching here and not just the event. Let me know if you disagree!
Meanwhile I'll kick of tune cloud release tests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I mean the event name can be more specific? The method name _setup_signal_catching sounds fine to me!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see! Yes, thanks, will updated!


def signal_interrupt_tune_run(sig: int, frame):
logger.warning(
"Stop signal received (e.g. via SIGINT/Ctrl+C), ending Ray Tune run. "
"This will try to checkpoint the experiment state one last time. "
"Press CTRL+C (or send SIGINT/SIGKILL/SIGTERM) "
"to skip. "
)
stop_event.set()
# Restore original signal handler to react to future SIGINT signals
signal.signal(signal.SIGINT, original_handler)

# We should only install the handler when it is safe to do so.
# When tune.run() is called from worker thread, signal.signal will
# fail.
allow_signal_catching = True
if threading.current_thread() != threading.main_thread():
allow_signal_catching = False

if allow_signal_catching:
if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")):
signal.signal(signal.SIGINT, signal_interrupt_tune_run)

# Always register SIGUSR1 if available (not available e.g. on Windows)
if hasattr(signal, "SIGUSR1"):
signal.signal(signal.SIGUSR1, signal_interrupt_tune_run)

return stop_event


@PublicAPI
def run(
run_or_experiment: Union[str, Callable, Type],
Expand Down Expand Up @@ -507,11 +540,6 @@ class and registered trainables.
"well as implementing `reset_config` for Trainable."
)

trial_executor = trial_executor or RayTrialExecutor(
reuse_actors=reuse_actors,
result_buffer_length=result_buffer_length,
chdir_to_trial_dir=chdir_to_trial_dir,
)
if isinstance(run_or_experiment, list):
experiments = run_or_experiment
else:
Expand Down Expand Up @@ -627,6 +655,36 @@ class and registered trainables.
callbacks, sync_config, metric=metric, progress_metrics=progress_metrics
)

# User Warning for GPUs
if ray.cluster_resources().get("GPU", 0):
if _check_gpus_in_resources(resources=resources_per_trial):
# "gpu" is manually set.
pass
elif _check_default_resources_override(experiments[0].run_identifier):
# "default_resources" is manually overridden.
pass
else:
logger.warning(
"Tune detects GPUs, but no trials are using GPUs. "
"To enable trials to use GPUs, wrap `train_func` with "
"`tune.with_resources(train_func, resources_per_trial={'gpu': 1})` "
"which allows Tune to expose 1 GPU to each trial. "
"For Ray AIR Trainers, you can specify GPU resources "
"through `ScalingConfig(use_gpu=True)`. "
"You can also override "
"`Trainable.default_resource_request` if using the "
"Trainable API."
)

stop_event = _setup_signal_catching()

progress_reporter = progress_reporter or _detect_reporter()

trial_executor = trial_executor or RayTrialExecutor(
reuse_actors=reuse_actors,
result_buffer_length=result_buffer_length,
chdir_to_trial_dir=chdir_to_trial_dir,
)
runner = TrialRunner(
search_alg=search_alg,
scheduler=scheduler,
Expand Down Expand Up @@ -662,58 +720,6 @@ class and registered trainables.
experiments=experiments, total_num_samples=search_alg.total_samples
)

# User Warning for GPUs
if trial_executor.has_gpus():
if _check_gpus_in_resources(resources=resources_per_trial):
# "gpu" is manually set.
pass
elif _check_default_resources_override(experiments[0].run_identifier):
# "default_resources" is manually overridden.
pass
else:
logger.warning(
"Tune detects GPUs, but no trials are using GPUs. "
"To enable trials to use GPUs, wrap `train_func` with "
"`tune.with_resources(train_func, resources_per_trial={'gpu': 1})` "
"which allows Tune to expose 1 GPU to each trial. "
"For Ray AIR Trainers, you can specify GPU resources "
"through `ScalingConfig(use_gpu=True)`. "
"You can also override "
"`Trainable.default_resource_request` if using the "
"Trainable API."
)

original_handler = signal.getsignal(signal.SIGINT)
state = {"signal": None}

def signal_interrupt_tune_run(sig: int, frame):
logger.warning(
"Stop signal received (e.g. via SIGINT/Ctrl+C), ending Ray Tune run. "
"This will try to checkpoint the experiment state one last time. "
"Press CTRL+C (or send SIGINT/SIGKILL/SIGTERM) "
"to skip. "
)
state["signal"] = sig
# Restore original signal handler to react to future SIGINT signals
signal.signal(signal.SIGINT, original_handler)

# We should only install the handler when it is safe to do so.
# When tune.run() is called from worker thread, signal.signal will
# fail.
allow_signal_catching = True
if threading.current_thread() != threading.main_thread():
allow_signal_catching = False

if allow_signal_catching:
if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")):
signal.signal(signal.SIGINT, signal_interrupt_tune_run)

# Always register SIGUSR1 if available (not available e.g. on Windows)
if hasattr(signal, "SIGUSR1"):
signal.signal(signal.SIGUSR1, signal_interrupt_tune_run)

progress_reporter = progress_reporter or _detect_reporter()

tune_start = time.time()

progress_reporter.setup(
Expand All @@ -722,7 +728,7 @@ def signal_interrupt_tune_run(sig: int, frame):
metric=metric,
mode=mode,
)
while not runner.is_finished() and not state["signal"]:
while not runner.is_finished() and not stop_event.is_set():
runner.step()
if has_verbosity(Verbosity.V1_EXPERIMENT):
_report_progress(runner, progress_reporter)
Expand All @@ -736,6 +742,9 @@ def signal_interrupt_tune_run(sig: int, frame):
if has_verbosity(Verbosity.V1_EXPERIMENT):
_report_progress(runner, progress_reporter, done=True)

all_trials = runner.get_trials()
experiment_checkpoint = runner.checkpoint_file

# Wait for syncing to finish
for callback in callbacks:
if isinstance(callback, SyncerCallback):
Expand All @@ -747,12 +756,12 @@ def signal_interrupt_tune_run(sig: int, frame):
runner.cleanup()

incomplete_trials = []
for trial in runner.get_trials():
for trial in all_trials:
if trial.status != Trial.TERMINATED:
incomplete_trials += [trial]

if incomplete_trials:
if raise_on_failed_trial and not state["signal"]:
if raise_on_failed_trial and not stop_event.is_set():
raise TuneError("Trials did not complete", incomplete_trials)
else:
logger.error("Trials did not complete: %s", incomplete_trials)
Expand All @@ -764,17 +773,16 @@ def signal_interrupt_tune_run(sig: int, frame):
f"({tune_taken:.2f} seconds for the tuning loop)."
)

if state["signal"]:
if stop_event.is_set():
logger.warning(
"Experiment has been interrupted, but the most recent state was "
"saved. You can continue running this experiment by passing "
"`resume=True` to `tune.run()`"
)

trials = runner.get_trials()
return ExperimentAnalysis(
runner.checkpoint_file,
trials=trials,
experiment_checkpoint,
trials=all_trials,
default_metric=metric,
default_mode=mode,
sync_config=sync_config,
Expand Down