From 8899efede0cd797d84e2aee1390a48ac54f23b83 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:39:34 -0800 Subject: [PATCH] Deferrable sensors can implement sensor timeout (#33718) The goal here is to ensure behavioral parity w.r.t. sensor timeouts between deferrable and non-deferrable sensor operators. With non-deferrable sensors, if there's a sensor timeout, the task fails without retry. But currently, with deferrable sensors, that does not happen. Since there's already a "timeout" capability on triggers, we can use this for sensor timeout. Essentially all that was missing was the ability to distinguish between trigger timeouts and other trigger errors. With this capability, base sensor can distinguish between the two, and reraise deferral timeouts as sensor timeouts. So, here we add a new exception type, TaskDeferralTimeout, which base sensor reraises as AirflowSensorTimeout. Then, to take advantage of this feature, a sensor need only ensure that its timeout is passed when deferring. For convenience, we update the task deferred exception signature to take int and float in addition to timedelta, since that's how `timeout` attr is defined on base sensor. But we do not change the exception attribute type. To keep this PR focused, we update only one sensor to use the timeout functionality, namely, the time delta sensor. Other sensors will have to be done as follow-ups. 
--- airflow/exceptions.py | 14 +++++++-- airflow/jobs/scheduler_job_runner.py | 5 ++-- airflow/models/baseoperator.py | 11 +++++-- airflow/models/taskinstance.py | 1 + airflow/models/trigger.py | 29 +++++++++++++++++-- airflow/sensors/base.py | 5 +++- .../providers/standard/sensors/time_delta.py | 21 +++++++++++++- tests/models/test_baseoperator.py | 12 +++++++- tests/sensors/test_base.py | 10 +++++++ 9 files changed, 95 insertions(+), 13 deletions(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index fee0b5a671d54e..4035488cf87e1d 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -22,13 +22,13 @@ from __future__ import annotations import warnings +from datetime import timedelta from http import HTTPStatus from typing import TYPE_CHECKING, Any, NamedTuple from airflow.utils.trigger_rule import TriggerRule if TYPE_CHECKING: - import datetime from collections.abc import Sized from airflow.models import DagRun @@ -385,14 +385,18 @@ def __init__( trigger, method_name: str, kwargs: dict[str, Any] | None = None, - timeout: datetime.timedelta | None = None, + timeout: timedelta | int | float | None = None, ): super().__init__() self.trigger = trigger self.method_name = method_name self.kwargs = kwargs - self.timeout = timeout + self.timeout: timedelta | None # Check timeout type at runtime + if isinstance(timeout, (int, float)): + self.timeout = timedelta(seconds=timeout) + else: + self.timeout = timeout if self.timeout is not None and not hasattr(self.timeout, "total_seconds"): raise ValueError("Timeout value must be a timedelta") @@ -417,6 +421,10 @@ class TaskDeferralError(AirflowException): """Raised when a task failed during deferral for some reason.""" +class TaskDeferralTimeout(AirflowException): + """Raise when there is a timeout on the deferral.""" + + # The try/except handling is needed after we moved all k8s classes to cncf.kubernetes provider # These two exceptions are used internally by Kubernetes Executor but also by 
PodGenerator, so we need # to leave them here in case older version of cncf.kubernetes provider is used to run KubernetesPodOperator diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 56a65009e2b6aa..0dd6b32f741432 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -65,6 +65,7 @@ from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarning, DagWarningType from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance +from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason from airflow.stats import Stats from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.timetables.simple import AssetTriggeredTimetable @@ -2057,8 +2058,8 @@ def check_trigger_timeouts( ) .values( state=TaskInstanceState.SCHEDULED, - next_method="__fail__", - next_kwargs={"error": "Trigger/execution timeout"}, + next_method=TRIGGER_FAIL_REPR, + next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT}, trigger_id=None, ) ).rowcount diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 13eb787b4f86b6..512eb189cc9a52 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -50,6 +50,7 @@ from airflow.exceptions import ( AirflowException, TaskDeferralError, + TaskDeferralTimeout, TaskDeferred, ) from airflow.lineage import apply_lineage, prepare_lineage @@ -72,6 +73,7 @@ from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin +from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason from airflow.sdk.definitions.baseoperator import ( BaseOperatorMeta as TaskSDKBaseOperatorMeta, get_merged_defaults, @@ -973,7 +975,7 @@ def defer( trigger: BaseTrigger, method_name: str, kwargs: dict[str, Any] | None = None, - 
timeout: timedelta | None = None, + timeout: timedelta | int | float | None = None, ) -> NoReturn: """ Mark this Operator "deferred", suspending its execution until the provided trigger fires an event. @@ -990,12 +992,15 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, """Call this method when a deferred task is resumed.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. - if next_method == "__fail__": + if next_method == TRIGGER_FAIL_REPR: next_kwargs = next_kwargs or {} traceback = next_kwargs.get("traceback") if traceback is not None: self.log.error("Trigger failed:\n%s", "\n".join(traceback)) - raise TaskDeferralError(next_kwargs.get("error", "Unknown")) + if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: + raise TaskDeferralTimeout(error) + else: + raise TaskDeferralError(error) # Grab the callable off the Operator/Task and add in any kwargs execute_callable = getattr(self, next_method) if next_kwargs: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d6b24f34000f4f..705cc797ed11b5 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1538,6 +1538,7 @@ def _defer_task( ) -> TaskInstance: from airflow.models.trigger import Trigger + timeout: timedelta | None if exception is not None: trigger_row = Trigger.from_object(exception.trigger) next_method = exception.method_name diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index b7b6ba9980d51b..f56512cdbc1835 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -18,6 +18,7 @@ import datetime from collections.abc import Iterable +from enum import Enum from traceback import format_exception from typing import TYPE_CHECKING, Any @@ -40,6 +41,27 @@ from airflow.triggers.base import BaseTrigger +TRIGGER_FAIL_REPR = "__fail__" +"""String value to represent trigger failure. 
+ +Internal use only. + +:meta private: +""" + + +class TriggerFailureReason(str, Enum): + """ + Reasons for trigger failures. + + Internal use only. + + :meta private: + """ + + TRIGGER_TIMEOUT = "Trigger timeout" + TRIGGER_FAILURE = "Trigger failure" + class Trigger(Base): """ @@ -229,8 +251,11 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> ): # Add the error and set the next_method to the fail state traceback = format_exception(type(exc), exc, exc.__traceback__) if exc else None - task_instance.next_method = "__fail__" - task_instance.next_kwargs = {"error": "Trigger failure", "traceback": traceback} + task_instance.next_method = TRIGGER_FAIL_REPR + task_instance.next_kwargs = { + "error": TriggerFailureReason.TRIGGER_FAILURE, + "traceback": traceback, + } # Remove ourselves as its trigger task_instance.trigger_id = None # Finally, mark it as scheduled so it gets re-queued diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index a593a4519f11f8..e8d89b2365b2d8 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -38,6 +38,7 @@ AirflowSkipException, AirflowTaskTimeout, TaskDeferralError, + TaskDeferralTimeout, ) from airflow.executors.executor_loader import ExecutorLoader from airflow.models.baseoperator import BaseOperator @@ -174,7 +175,7 @@ def __init__( super().__init__(**kwargs) self.poke_interval = self._coerce_poke_interval(poke_interval).total_seconds() self.soft_fail = soft_fail - self.timeout = self._coerce_timeout(timeout).total_seconds() + self.timeout: int | float = self._coerce_timeout(timeout).total_seconds() self.mode = mode self.exponential_backoff = exponential_backoff self.max_wait = self._coerce_max_wait(max_wait) @@ -338,6 +339,8 @@ def run_duration() -> float: def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): try: return super().resume_execution(next_method, next_kwargs, context) + except TaskDeferralTimeout as e: + raise 
AirflowSensorTimeout(*e.args) from e except (AirflowException, TaskDeferralError) as e: if self.soft_fail: raise AirflowSkipException(str(e)) from e diff --git a/providers/src/airflow/providers/standard/sensors/time_delta.py b/providers/src/airflow/providers/standard/sensors/time_delta.py index a0d3189b027fd1..8e0f26ac249dd0 100644 --- a/providers/src/airflow/providers/standard/sensors/time_delta.py +++ b/providers/src/airflow/providers/standard/sensors/time_delta.py @@ -21,6 +21,8 @@ from time import sleep from typing import TYPE_CHECKING, Any, NoReturn +from packaging.version import Version + from airflow.configuration import conf from airflow.exceptions import AirflowSkipException from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger @@ -32,6 +34,12 @@ from airflow.utils.context import Context +def _get_airflow_version(): + from airflow import __version__ as airflow_version + + return Version(Version(airflow_version).base_version) + + class TimeDeltaSensor(BaseSensorOperator): """ Waits for a timedelta after the run's data interval. 
@@ -91,7 +99,18 @@ def execute(self, context: Context) -> bool | NoReturn: raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e raise - self.defer(trigger=trigger, method_name="execute_complete") + # todo: remove backcompat when min airflow version is greater than 2.11 + timeout: int | float | timedelta + if _get_airflow_version() >= Version("2.11.0"): + timeout = self.timeout + else: + timeout = timedelta(seconds=self.timeout) + + self.defer( + trigger=trigger, + method_name="execute_complete", + timeout=timeout, + ) def execute_complete(self, context: Context, event: Any = None) -> None: """Handle the event when the trigger fires and return immediately.""" diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 638f012a3a5a02..e95866d95a5e95 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -29,7 +29,7 @@ import pytest from airflow.decorators import task as task_decorator -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferralTimeout from airflow.lineage.entities import File from airflow.models.baseoperator import ( BaseOperator, @@ -40,6 +40,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance +from airflow.models.trigger import TriggerFailureReason from airflow.providers.common.sql.operators import sql from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup @@ -582,6 +583,15 @@ def test_logging_propogated_by_default(self, caplog): # leaking a lot of state) assert caplog.messages == ["test"] + def test_resume_execution(self): + op = BaseOperator(task_id="hi") + with pytest.raises(TaskDeferralTimeout): + op.resume_execution( + next_method="__fail__", + next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT}, + context={}, + ) + def test_deepcopy(): # Test bug when copying an operator attached to
a DAG diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index b1d398265a7322..9bb4f5b9934d2c 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -45,6 +45,7 @@ from airflow.executors.local_executor import LocalExecutor from airflow.executors.sequential_executor import SequentialExecutor from airflow.models import TaskInstance, TaskReschedule +from airflow.models.trigger import TriggerFailureReason from airflow.models.xcom import XCom from airflow.operators.empty import EmptyOperator from airflow.providers.celery.executors.celery_executor import CeleryExecutor @@ -1061,6 +1062,15 @@ def test_prepare_for_execution(self, executor_cls_mode): task = sensor.prepare_for_execution() assert task.mode == mode + def test_resume_execution(self): + op = BaseSensorOperator(task_id="hi") + with pytest.raises(AirflowSensorTimeout): + op.resume_execution( + next_method="__fail__", + next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT}, + context={}, + ) + @poke_mode_only class DummyPokeOnlySensor(BaseSensorOperator):