From ef6a03721d89fb8d3887a64125f68e61368c34f0 Mon Sep 17 00:00:00 2001
From: Julian P Samaroo
Date: Tue, 25 Apr 2023 10:44:35 -0700
Subject: [PATCH] signal handling: User-defined interrupt handlers

Interrupt handling is a tricky problem, not just in terms of
implementation, but in terms of desired behavior: when an interrupt is
received, which code should handle it? Julia's current answer to this is
effectively to throw an `InterruptException` to the first task to hit a
safepoint. While this seems sensible (the code that's running gets
interrupted), it only really works for very basic numerical code. When
multiple tasks are running concurrently, or when try-catch handlers are
registered, this system breaks down and results in unpredictable
behavior.

This unpredictable behavior includes:
- Interrupting background/runtime tasks which don't want to be
  interrupted, as they do small but important pieces of work (and are
  critical to library runtime functionality)
- Interrupting only one task, when multiple coordinating tasks would want
  to receive the interrupt to safely terminate a computation
- Interrupting only one library's task, when multiple libraries would
  each want to be notified about the interrupt

The above behavior makes it nearly impossible to provide reliable Ctrl-C
behavior, and results in very confused users who get stuck hitting
Ctrl-C continuously: sometimes getting caught in a hang, sometimes
triggering unrelated exception handling code they didn't mean to,
sometimes getting a segfault, and only very rarely getting the behavior
they desire (with no guarantee that the active session can safely
continue to be used as intended).

This commit provides an alternative behavior for interrupts which is
more predictable: user code may now register tasks as "interrupt
handlers", which will be guaranteed to receive an `InterruptException`
whenever the session receives an interrupt signal.
Additionally, when any interrupt handlers are registered, no other tasks
will receive `InterruptException`s; only the handlers may receive them.

This behavior allows one or more libraries to register handler tasks
which will all be concurrently awoken to handle each interrupt and do
whatever is necessary to safely interrupt any running code; the extent
to which other tasks are interrupted is arbitrary and library-defined.
For example, GPU libraries like AMDGPU.jl can register a handler to
safely interrupt GPU kernels running on all GPU queues and do resource
cleanup. Concurrently, a complex runtime like the scheduler in Dagger.jl
can register a handler to interrupt running tasks on other workers when
possible, and otherwise notify the user that tasks are being shut down.

This change is intended to be non-breaking for simple code: the previous
behavior is maintained when no interrupt handlers are registered.
However, once some libraries start adding interrupt handlers, other
libraries will need to follow suit to ensure that users can interrupt
their computations.
--- base/task.jl | 14 +++++++++++ src/gc.c | 1 + src/jl_exported_data.inc | 1 + src/jl_exported_funcs.inc | 1 + src/julia_threads.h | 2 ++ src/signal-handling.c | 49 +++++++++++++++++++++++++++++++++++++++ src/signals-unix.c | 11 +++++++-- src/signals-win.c | 3 ++- src/task.c | 7 ++++-- src/threading.h | 2 ++ 10 files changed, 86 insertions(+), 5 deletions(-) diff --git a/base/task.jl b/base/task.jl index e407cbd62bbd68..21ee86665c78c1 100644 --- a/base/task.jl +++ b/base/task.jl @@ -992,3 +992,17 @@ if Sys.iswindows() else pause() = ccall(:pause, Cvoid, ()) end + +interrupt_handlers() = ccall(:jl_get_interrupt_handlers, Any, ())::Vector{Task} +function register_interrupt_handler(t::Task) + handlers = interrupt_handlers() + if findfirst(==(t), handlers) === nothing + push!(handlers, t) + end + return +end +function unregister_interrupt_handler(t::Task) + handlers = interrupt_handlers() + deleteat!(handlers, findall(==(t), handlers)) + return +end diff --git a/src/gc.c b/src/gc.c index 3c116b4cd352f0..653f5623385f04 100644 --- a/src/gc.c +++ b/src/gc.c @@ -2891,6 +2891,7 @@ static void gc_mark_roots(jl_gc_markqueue_t *mq) gc_try_claim_and_push(mq, jl_emptytuple_type, NULL); gc_try_claim_and_push(mq, cmpswap_names, NULL); gc_try_claim_and_push(mq, jl_global_roots_table, NULL); + gc_try_claim_and_push(mq, jl_interrupt_handlers, NULL); } // find unmarked objects that need to be finalized from the finalizer list "list". 
diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 52f6cb11d8c0f6..d388f5473410ae 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -54,6 +54,7 @@ XX(jl_int8_type) \ XX(jl_interconditional_type) \ XX(jl_interrupt_exception) \ + XX(jl_interrupt_handlers) \ XX(jl_intrinsic_type) \ XX(jl_kwcall_func) \ XX(jl_lineinfonode_type) \ diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 02355d70036050..edf77ebea2bed0 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -220,6 +220,7 @@ XX(jl_get_field) \ XX(jl_get_global) \ XX(jl_get_image_file) \ + XX(jl_get_interrupt_handlers) \ XX(jl_get_JIT) \ XX(jl_get_julia_bin) \ XX(jl_get_julia_bindir) \ diff --git a/src/julia_threads.h b/src/julia_threads.h index 6439caa0aa2eed..bb14d503d270b9 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -369,6 +369,8 @@ JL_DLLEXPORT int8_t jl_gc_is_in_finalizer(void); JL_DLLEXPORT void jl_wakeup_thread(int16_t tid); +JL_DLLEXPORT void jl_schedule_task(struct _jl_task_t *task); + #ifdef __cplusplus } #endif diff --git a/src/signal-handling.c b/src/signal-handling.c index e241fd22ecb186..25a5648229d805 100644 --- a/src/signal-handling.c +++ b/src/signal-handling.c @@ -304,6 +304,55 @@ static void jl_check_profile_autostop(void) } } +JL_DLLEXPORT _Atomic(jl_array_t *) jl_interrupt_handlers = NULL; +JL_DLLEXPORT jl_array_t *jl_get_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) { + static jl_datatype_t *jl_array_task_type; + if (!jl_array_task_type) + jl_array_task_type = (jl_datatype_t *)jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_task_type, jl_box_long(1)); + jl_array_t *new_handlers = jl_alloc_array_1d((jl_value_t *)jl_array_task_type, 0); + if (jl_atomic_cmpswap(&jl_interrupt_handlers, &handlers, new_handlers)) { + handlers = new_handlers; + } else { + handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + } + 
} + assert(handlers); + return handlers; +} +static _Atomic(int) handle_interrupt = 0; +JL_DLLEXPORT void jl_schedule_interrupt_handlers(void) +{ + if (jl_atomic_exchange_relaxed(&handle_interrupt, 0) != 1) + return; + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (!handlers) + return; + for (int i = 0; i < jl_array_len(handlers); i++) { + jl_task_t *handler = ((jl_task_t **)jl_array_data(handlers))[i]; + assert(jl_is_task(handler)); + if (handler->ptls) + continue; + if (jl_atomic_load_relaxed(&handler->_state) != JL_TASK_STATE_RUNNABLE) + continue; + handler->result = jl_interrupt_exception; + handler->_isexception = 1; + jl_schedule_task(handler); + } +} +static int want_interrupt_handlers(void) +{ + jl_array_t *handlers = jl_atomic_load_relaxed(&jl_interrupt_handlers); + if (handlers && (jl_array_len(handlers) > 0)) { + // Set flag to trigger user handlers on next task switch + jl_atomic_store_relaxed(&handle_interrupt, 1); + return 1; + } + return 0; +} + #if defined(_WIN32) #include "signals-win.c" #else diff --git a/src/signals-unix.c b/src/signals-unix.c index 2858538372722d..e2c68c53064bed 100644 --- a/src/signals-unix.c +++ b/src/signals-unix.c @@ -525,11 +525,14 @@ void usr2_handler(int sig, siginfo_t *info, void *ctx) jl_atomic_exchange(&ptls->signal_request, 0); // returns -1 if (request == 2) { int force = jl_check_force_sigint(); + if (!force && want_interrupt_handlers()) { + return; + } if (force || (!ptls->defer_signal && ptls->io_wait)) { jl_safepoint_consume_sigint(); + // Force a throw if (force) jl_safe_printf("WARNING: Force throwing a SIGINT\n"); - // Force a throw jl_clear_force_sigint(); jl_throw_in_ctx(ct, jl_interrupt_exception, sig, ctx); } @@ -767,7 +770,7 @@ static void *signal_listener(void *arg) profile = (sig == SIGUSR1); #if defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L if (profile && !(info.si_code == SI_TIMER && - info.si_value.sival_ptr == &timerprof)) + info.si_value.sival_ptr == 
&timerprof)) profile = 0; #endif #endif @@ -780,6 +783,10 @@ static void *signal_listener(void *arg) else if (exit_on_sigint) { critical = 1; } + // FIXME: Skip this if force + else if (want_interrupt_handlers()) { + continue; + } else { jl_try_deliver_sigint(); continue; diff --git a/src/signals-win.c b/src/signals-win.c index 5dd6b34558ca6d..5429c62f8f9aaa 100644 --- a/src/signals-win.c +++ b/src/signals-win.c @@ -221,7 +221,8 @@ static BOOL WINAPI sigint_handler(DWORD wsig) //This needs winapi types to guara if (!jl_ignore_sigint()) { if (exit_on_sigint) jl_exit(128 + sig); // 128 + SIGINT - jl_try_deliver_sigint(); + if (!want_interrupt_handlers()) + jl_try_deliver_sigint(); } return 1; } diff --git a/src/task.c b/src/task.c index 123cfaac001630..eb1461c745ae5d 100644 --- a/src/task.c +++ b/src/task.c @@ -621,8 +621,12 @@ JL_NO_ASAN static void ctx_switch(jl_task_t *lastt) sanitizer_finish_switch_fiber(ptls->previous_task, jl_atomic_load_relaxed(&ptls->current_task)); } +JL_DLLIMPORT void jl_schedule_interrupt_handlers(void); + JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER { + jl_schedule_interrupt_handlers(); + jl_task_t *ct = jl_current_task; jl_ptls_t ptls = ct->ptls; jl_task_t *t = ptls->next_task; @@ -1164,7 +1168,7 @@ JL_DLLEXPORT void jl_task_wait() jl_apply(&wait_func, 1); ct->world_age = last_age; } - +#endif JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) { static jl_function_t *sched_func = NULL; @@ -1178,7 +1182,6 @@ JL_DLLEXPORT void jl_schedule_task(jl_task_t *task) jl_apply(args, 2); ct->world_age = last_age; } -#endif // Do one-time initializations for task system void jl_init_tasks(void) JL_GC_DISABLED diff --git a/src/threading.h b/src/threading.h index 4df6815124eb9c..efbc25f750cb08 100644 --- a/src/threading.h +++ b/src/threading.h @@ -27,6 +27,8 @@ jl_ptls_t jl_init_threadtls(int16_t tid) JL_NOTSAFEPOINT; void jl_init_threadinginfra(void); void jl_threadfun(void *arg); +extern _Atomic(jl_array_t *) 
jl_interrupt_handlers; + #ifdef __cplusplus } #endif