Skip to content

Commit

Permalink
AIP-72: Add escalation path for killing process in Supervisor (apache…
Browse files Browse the repository at this point in the history
…#44465)

Added logic to support escalation from SIGINT to SIGTERM and SIGKILL when killing a task process from the new Supervisor process.

part of apache#44356
  • Loading branch information
kaxil authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent 34dae01 commit bac8a14
Show file tree
Hide file tree
Showing 2 changed files with 276 additions and 17 deletions.
79 changes: 63 additions & 16 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@
)

if TYPE_CHECKING:
from selectors import SelectorKey

from structlog.typing import FilteringBoundLogger, WrappedLogger


Expand Down Expand Up @@ -403,12 +401,53 @@ def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str],
self.stdin.write(msg.model_dump_json().encode())
self.stdin.write(b"\n")

def kill(
    self,
    signal_to_send: signal.Signals = signal.SIGINT,
    escalation_delay: float = 5.0,
    force: bool = False,
) -> None:
    """
    Send a signal to the subprocess, optionally escalating until it exits.

    Nothing is sent if an exit code has already been recorded. Otherwise every signal in the
    selected sequence is sent in turn; after each one the subprocess is serviced (socket
    activity processed, exit status polled) for at most ``escalation_delay`` seconds.

    The sequence is selected as follows:

    - ``force=False`` (default): only ``signal_to_send`` is sent; there is no escalation.
    - ``force=True`` and ``signal_to_send`` is one of SIGINT, SIGTERM, SIGKILL: start at
      ``signal_to_send`` and continue along SIGINT -> SIGTERM -> SIGKILL until the process exits.
    - ``force=True`` with any other signal: that signal is sent once, without escalation.

    If signalling reports that the process no longer exists, the exit code is recorded as ``-1``.

    :param signal_to_send: The signal to send initially (default is SIGINT).
    :param escalation_delay: Maximum time in seconds to service the subprocess after each signal
        before moving on to the next one.
    :param force: If True, escalate through the remaining, stronger signals when the process
        does not exit.
    """
    if self._exit_code is not None:
        return

    # Escalation sequence: SIGINT -> SIGTERM -> SIGKILL
    escalation_path = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL]

    if force and signal_to_send in escalation_path:
        # Start from `signal_to_send` and escalate to the end of the escalation path
        escalation_path = escalation_path[escalation_path.index(signal_to_send) :]
    else:
        escalation_path = [signal_to_send]

    for sig in escalation_path:
        try:
            self._process.send_signal(sig)

            # Service subprocess events during the escalation delay.
            # NOTE(review): this performs a single `selector.select()` followed by a
            # zero-timeout exit poll, so socket activity that wakes the selector early is
            # treated as a timeout and escalation may happen before `escalation_delay` has
            # elapsed. A deadline loop would fix that, but callers/tests that count
            # `select()` calls would need to change with it.
            self._service_subprocess(max_wait_time=escalation_delay, raise_on_timeout=True)
            if self._exit_code is not None:
                log.info("Process exited", pid=self.pid, exit_code=self._exit_code, signal=sig.name)
                return
        except psutil.TimeoutExpired:
            msg = "Process did not terminate in time"
            if sig != escalation_path[-1]:
                msg += "; escalating"
            log.warning(msg, pid=self.pid, signal=sig.name)
        except psutil.NoSuchProcess:
            log.debug("Process already terminated", pid=self.pid)
            self._exit_code = -1
            return

    # Reached only when every signal in the sequence timed out.
    log.error("Failed to terminate process after full escalation", pid=self.pid)

def wait(self) -> int:
if self._exit_code is not None:
Expand Down Expand Up @@ -453,20 +492,23 @@ def _monitor_subprocess(self):
)
# Block until events are ready or the timeout is reached
# This listens for activity (e.g., subprocess output) on registered file objects
events = self.selector.select(timeout=max_wait_time)
self._process_file_object_events(events)
self._service_subprocess(max_wait_time=max_wait_time)

self._check_subprocess_exit()
self._send_heartbeat_if_needed()

def _process_file_object_events(self, events: list[tuple[SelectorKey, int]]):
def _service_subprocess(self, max_wait_time: float, raise_on_timeout: bool = False):
"""
Process selector events by invoking handlers for each file object.
Service subprocess events by processing socket activity and checking for process exit.
This method:
- Waits for activity on the registered file objects (via `self.selector.select`).
- Processes any events triggered on these file objects.
- Checks if the subprocess has exited during the wait.
For each file object event, this method retrieves the associated handler and processes
the event. If the handler indicates that the file object no longer needs
monitoring (e.g., EOF or closed), the file object is unregistered and closed.
:param max_wait_time: Maximum time to block while waiting for events, in seconds.
:param raise_on_timeout: If True, raise an exception if the subprocess does not exit within the timeout.
"""
events = self.selector.select(timeout=max_wait_time)
for key, _ in events:
# Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
socket_handler = key.data
Expand All @@ -484,13 +526,18 @@ def _process_file_object_events(self, events: list[tuple[SelectorKey, int]]):
self.selector.unregister(key.fileobj)
key.fileobj.close() # type: ignore[union-attr]

def _check_subprocess_exit(self):
# Check if the subprocess has exited
self._check_subprocess_exit(raise_on_timeout=raise_on_timeout)

def _check_subprocess_exit(self, raise_on_timeout: bool = False):
    """
    Poll the subprocess once with a zero timeout and record its exit code if it has finished.

    Does nothing when an exit code has already been recorded.

    :param raise_on_timeout: If True, re-raise ``psutil.TimeoutExpired`` when the process has not
        exited yet; otherwise that case is ignored.
    """
    if self._exit_code is not None:
        return

    try:
        self._exit_code = self._process.wait(timeout=0)
        log.debug("Task process exited", exit_code=self._exit_code)
    except psutil.TimeoutExpired:
        # Still running: only surface this to callers that asked for it.
        if raise_on_timeout:
            raise

def _send_heartbeat_if_needed(self):
Expand All @@ -514,7 +561,7 @@ def _send_heartbeat_if_needed(self):
detail=e.detail,
status_code=e.response.status_code,
)
self.kill(signal.SIGTERM)
self.kill(signal.SIGTERM, force=True)
else:
# If we get any other error, we'll just log it and try again next time
self._handle_heartbeat_failures()
Expand All @@ -536,7 +583,7 @@ def _handle_heartbeat_failures(self):
log.error(
"Too many failed heartbeats; terminating process", failed_heartbeats=self.failed_heartbeats
)
self.kill(signal.SIGTERM)
self.kill(signal.SIGTERM, force=True)

@property
def final_state(self):
Expand Down
214 changes: 213 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import inspect
import logging
import os
import selectors
import signal
import sys
from io import BytesIO
Expand All @@ -29,6 +30,7 @@
from unittest.mock import MagicMock

import httpx
import psutil
import pytest
from uuid6 import uuid7

Expand Down Expand Up @@ -419,7 +421,7 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t
proc._send_heartbeat_if_needed()

assert proc.failed_heartbeats == max_failed_heartbeats
mock_kill.assert_called_once_with(signal.SIGTERM)
mock_kill.assert_called_once_with(signal.SIGTERM, force=True)
mock_client_heartbeat.assert_called_with(TI_ID, pid=mock_process.pid)
assert {
"event": "Too many failed heartbeats; terminating process",
Expand All @@ -430,6 +432,216 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t
} in captured_logs


class TestWatchedSubprocessKill:
    """Tests for ``WatchedSubprocess.kill`` signal escalation and ``_service_subprocess``."""

    @pytest.fixture
    def mock_process(self, mocker):
        """Return a mock standing in for the ``psutil.Process`` of the task, with a fixed pid."""
        process = mocker.Mock(spec=psutil.Process)
        process.pid = 12345
        return process

    @pytest.fixture
    def watched_subprocess(self, mocker, mock_process):
        """Return a ``WatchedSubprocess`` using the mocked process and a selector with no events."""
        proc = WatchedSubprocess(
            ti_id=TI_ID,
            pid=12345,
            stdin=mocker.Mock(),
            client=mocker.Mock(),
            process=mock_process,
        )
        # Mock the selector
        mock_selector = mocker.Mock(spec=selectors.DefaultSelector)
        mock_selector.select.return_value = []

        # Set the selector on the process
        proc.selector = mock_selector
        return proc

    @pytest.mark.parametrize(
        ["signal_to_send", "wait_side_effect", "expected_signals"],
        [
            pytest.param(
                signal.SIGINT,
                [0],
                [signal.SIGINT],
                id="SIGINT-success-without-escalation",
            ),
            pytest.param(
                signal.SIGINT,
                [psutil.TimeoutExpired(0.1), 0],
                [signal.SIGINT, signal.SIGTERM],
                id="SIGINT-escalates-to-SIGTERM",
            ),
            pytest.param(
                signal.SIGINT,
                [
                    psutil.TimeoutExpired(0.1),  # SIGINT times out
                    psutil.TimeoutExpired(0.1),  # SIGTERM times out
                    0,  # SIGKILL succeeds
                ],
                [signal.SIGINT, signal.SIGTERM, signal.SIGKILL],
                id="SIGINT-escalates-to-SIGTERM-then-SIGKILL",
            ),
            pytest.param(
                signal.SIGTERM,
                [
                    psutil.TimeoutExpired(0.1),  # SIGTERM times out
                    0,  # SIGKILL succeeds
                ],
                [signal.SIGTERM, signal.SIGKILL],
                id="SIGTERM-escalates-to-SIGKILL",
            ),
            pytest.param(
                signal.SIGKILL,
                [0],
                [signal.SIGKILL],
                id="SIGKILL-success-without-escalation",
            ),
        ],
    )
    def test_force_kill_escalation(
        self,
        watched_subprocess,
        mock_process,
        mocker,
        signal_to_send,
        wait_side_effect,
        expected_signals,
        captured_logs,
    ):
        """Test escalation path for SIGINT, SIGTERM, and SIGKILL when force=True."""
        # Mock the process wait method to return the exit code or raise an exception
        mock_process.wait.side_effect = wait_side_effect

        watched_subprocess.kill(signal_to_send=signal_to_send, escalation_delay=0.1, force=True)

        # Check that the correct signals were sent
        mock_process.send_signal.assert_has_calls([mocker.call(sig) for sig in expected_signals])

        # Check that the process was waited on for each signal
        mock_process.wait.assert_has_calls([mocker.call(timeout=0)] * len(expected_signals))

        ## Validate log messages
        # Every signal that timed out (all but the last one sent) must have produced its own
        # escalation warning, not just the one immediately before the successful signal.
        for escalated_signal in expected_signals[:-1]:
            assert {
                "event": "Process did not terminate in time; escalating",
                "level": "warning",
                "logger": "supervisor",
                "pid": 12345,
                "signal": escalated_signal.name,
                "timestamp": mocker.ANY,
            } in captured_logs

        # Regardless of escalation, we should see an info log for the final signal sent
        assert {
            "event": "Process exited",
            "level": "info",
            "logger": "supervisor",
            "pid": 12345,
            "signal": expected_signals[-1].name,
            "exit_code": 0,
            "timestamp": mocker.ANY,
        } in captured_logs

        # Validate `selector.select` calls
        assert watched_subprocess.selector.select.call_count == len(expected_signals)
        watched_subprocess.selector.select.assert_has_calls(
            [mocker.call(timeout=0.1)] * len(expected_signals)
        )

        assert watched_subprocess._exit_code == 0

    def test_force_kill_with_selector_events(self, watched_subprocess, mock_process, mocker):
        """Test force escalation with selector events handled during wait."""
        # Mock selector to return events during escalation
        mock_key = mocker.Mock()
        mock_key.fileobj = mocker.Mock()

        # Simulate EOF
        mock_key.data = mocker.Mock(return_value=False)

        watched_subprocess.selector.select.side_effect = [
            [(mock_key, None)],  # Event during SIGINT
            [],  # No event during SIGTERM
            [(mock_key, None)],  # Event during SIGKILL
        ]

        mock_process.wait.side_effect = [
            psutil.TimeoutExpired(0.1),  # SIGINT times out
            psutil.TimeoutExpired(0.1),  # SIGTERM times out
            0,  # SIGKILL succeeds
        ]

        watched_subprocess.kill(signal.SIGINT, escalation_delay=0.1, force=True)

        # Validate selector interactions
        assert watched_subprocess.selector.select.call_count == 3
        mock_key.data.assert_has_calls([mocker.call(mock_key.fileobj), mocker.call(mock_key.fileobj)])

        # Validate signal escalation
        mock_process.send_signal.assert_has_calls(
            [mocker.call(signal.SIGINT), mocker.call(signal.SIGTERM), mocker.call(signal.SIGKILL)]
        )

        # The exit code returned by the final wait() must have been recorded
        assert watched_subprocess._exit_code == 0

    def test_kill_process_already_exited(self, watched_subprocess, mock_process):
        """Test behavior when the process has already exited."""
        mock_process.wait.side_effect = psutil.NoSuchProcess(pid=1234)

        watched_subprocess.kill(signal.SIGINT, force=True)

        mock_process.send_signal.assert_called_once_with(signal.SIGINT)
        mock_process.wait.assert_called_once()
        assert watched_subprocess._exit_code == -1

    def test_kill_process_custom_signal(self, watched_subprocess, mock_process):
        """Test that the process is killed with the correct signal."""
        mock_process.wait.return_value = 0

        signal_to_send = signal.SIGUSR1
        watched_subprocess.kill(signal_to_send, force=False)

        mock_process.send_signal.assert_called_once_with(signal_to_send)
        mock_process.wait.assert_called_once_with(timeout=0)

    def test_service_subprocess(self, watched_subprocess, mock_process, mocker):
        """Test `_service_subprocess` processes selector events and handles subprocess exit."""
        ## Given

        # Mock file objects and handlers
        mock_stdout = mocker.Mock()
        mock_stderr = mocker.Mock()

        # Handlers for stdout and stderr
        mock_stdout_handler = mocker.Mock(return_value=False)  # Simulate EOF for stdout
        mock_stderr_handler = mocker.Mock(return_value=True)  # Continue processing for stderr

        # Mock selector to return events
        mock_key_stdout = mocker.Mock(fileobj=mock_stdout, data=mock_stdout_handler)
        mock_key_stderr = mocker.Mock(fileobj=mock_stderr, data=mock_stderr_handler)
        watched_subprocess.selector.select.return_value = [(mock_key_stdout, None), (mock_key_stderr, None)]

        # Mock to simulate process exited successfully
        mock_process.wait.return_value = 0

        ## Our actual test
        watched_subprocess._service_subprocess(max_wait_time=1.0)

        ## Validations!
        # Validate selector interactions
        watched_subprocess.selector.select.assert_called_once_with(timeout=1.0)

        # Validate handler calls
        mock_stdout_handler.assert_called_once_with(mock_stdout)
        mock_stderr_handler.assert_called_once_with(mock_stderr)

        # Validate unregistering and closing of EOF file object
        watched_subprocess.selector.unregister.assert_called_once_with(mock_stdout)
        mock_stdout.close.assert_called_once()

        # Validate that `_check_subprocess_exit` is called
        mock_process.wait.assert_called_once_with(timeout=0)


class TestHandleRequest:
@pytest.fixture
def watched_subprocess(self, mocker):
Expand Down

0 comments on commit bac8a14

Please sign in to comment.