Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deferrable sensors can implement sensor timeout #33718

Merged
merged 10 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from __future__ import annotations

import warnings
from datetime import timedelta
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NamedTuple

from airflow.utils.trigger_rule import TriggerRule

if TYPE_CHECKING:
import datetime
from collections.abc import Sized

from airflow.models import DagRun
Expand Down Expand Up @@ -385,14 +385,18 @@ def __init__(
trigger,
method_name: str,
kwargs: dict[str, Any] | None = None,
timeout: datetime.timedelta | None = None,
timeout: timedelta | int | float | None = None,
):
super().__init__()
self.trigger = trigger
self.method_name = method_name
self.kwargs = kwargs
self.timeout = timeout
self.timeout: timedelta | None
# Check timeout type at runtime
if isinstance(timeout, (int, float)):
self.timeout = timedelta(seconds=timeout)
else:
self.timeout = timeout
if self.timeout is not None and not hasattr(self.timeout, "total_seconds"):
raise ValueError("Timeout value must be a timedelta")

Expand All @@ -417,6 +421,10 @@ class TaskDeferralError(AirflowException):
"""Raised when a task failed during deferral for some reason."""


class TaskDeferralTimeout(AirflowException):
"""Raise when there is a timeout on the deferral."""


# The try/except handling is needed after we moved all k8s classes to cncf.kubernetes provider
# These two exceptions are used internally by Kubernetes Executor but also by PodGenerator, so we need
# to leave them here in case older version of cncf.kubernetes provider is used to run KubernetesPodOperator
Expand Down
5 changes: 3 additions & 2 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
from airflow.stats import Stats
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.timetables.simple import AssetTriggeredTimetable
Expand Down Expand Up @@ -2057,8 +2058,8 @@ def check_trigger_timeouts(
)
.values(
state=TaskInstanceState.SCHEDULED,
next_method="__fail__",
next_kwargs={"error": "Trigger/execution timeout"},
next_method=TRIGGER_FAIL_REPR,
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
trigger_id=None,
)
).rowcount
Expand Down
11 changes: 8 additions & 3 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from airflow.exceptions import (
AirflowException,
TaskDeferralError,
TaskDeferralTimeout,
TaskDeferred,
)
from airflow.lineage import apply_lineage, prepare_lineage
Expand All @@ -72,6 +73,7 @@
from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
from airflow.sdk.definitions.baseoperator import (
BaseOperatorMeta as TaskSDKBaseOperatorMeta,
get_merged_defaults,
Expand Down Expand Up @@ -973,7 +975,7 @@ def defer(
trigger: BaseTrigger,
method_name: str,
kwargs: dict[str, Any] | None = None,
timeout: timedelta | None = None,
timeout: timedelta | int | float | None = None,
) -> NoReturn:
"""
Mark this Operator "deferred", suspending its execution until the provided trigger fires an event.
Expand All @@ -990,12 +992,15 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None,
"""Call this method when a deferred task is resumed."""
# __fail__ is a special signal value for next_method that indicates
# this task was scheduled specifically to fail.
if next_method == "__fail__":
if next_method == TRIGGER_FAIL_REPR:
next_kwargs = next_kwargs or {}
traceback = next_kwargs.get("traceback")
if traceback is not None:
self.log.error("Trigger failed:\n%s", "\n".join(traceback))
raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT:
raise TaskDeferralTimeout(error)
else:
raise TaskDeferralError(error)
# Grab the callable off the Operator/Task and add in any kwargs
execute_callable = getattr(self, next_method)
if next_kwargs:
Expand Down
1 change: 1 addition & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,7 @@ def _defer_task(
) -> TaskInstance:
from airflow.models.trigger import Trigger

timeout: timedelta | None
if exception is not None:
trigger_row = Trigger.from_object(exception.trigger)
next_method = exception.method_name
Expand Down
29 changes: 27 additions & 2 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import datetime
from collections.abc import Iterable
from enum import Enum
from traceback import format_exception
from typing import TYPE_CHECKING, Any

Expand All @@ -40,6 +41,27 @@

from airflow.triggers.base import BaseTrigger

TRIGGER_FAIL_REPR = "__fail__"
"""String value to represent trigger failure.
Internal use only.
:meta private:
"""


class TriggerFailureReason(str, Enum):
"""
Reasons for trigger failures.
Internal use only.
:meta private:
"""

TRIGGER_TIMEOUT = "Trigger timeout"
TRIGGER_FAILURE = "Trigger failure"


class Trigger(Base):
"""
Expand Down Expand Up @@ -229,8 +251,11 @@ def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) ->
):
# Add the error and set the next_method to the fail state
traceback = format_exception(type(exc), exc, exc.__traceback__) if exc else None
task_instance.next_method = "__fail__"
task_instance.next_kwargs = {"error": "Trigger failure", "traceback": traceback}
task_instance.next_method = TRIGGER_FAIL_REPR
task_instance.next_kwargs = {
"error": TriggerFailureReason.TRIGGER_FAILURE,
"traceback": traceback,
}
# Remove ourselves as its trigger
task_instance.trigger_id = None
# Finally, mark it as scheduled so it gets re-queued
Expand Down
5 changes: 4 additions & 1 deletion airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
AirflowSkipException,
AirflowTaskTimeout,
TaskDeferralError,
TaskDeferralTimeout,
)
from airflow.executors.executor_loader import ExecutorLoader
from airflow.models.baseoperator import BaseOperator
Expand Down Expand Up @@ -174,7 +175,7 @@ def __init__(
super().__init__(**kwargs)
self.poke_interval = self._coerce_poke_interval(poke_interval).total_seconds()
self.soft_fail = soft_fail
self.timeout = self._coerce_timeout(timeout).total_seconds()
self.timeout: int | float = self._coerce_timeout(timeout).total_seconds()
self.mode = mode
self.exponential_backoff = exponential_backoff
self.max_wait = self._coerce_max_wait(max_wait)
Expand Down Expand Up @@ -338,6 +339,8 @@ def run_duration() -> float:
def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
try:
return super().resume_execution(next_method, next_kwargs, context)
except TaskDeferralTimeout as e:
raise AirflowSensorTimeout(*e.args) from e
except (AirflowException, TaskDeferralError) as e:
if self.soft_fail:
raise AirflowSkipException(str(e)) from e
Expand Down
21 changes: 20 additions & 1 deletion providers/src/airflow/providers/standard/sensors/time_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from time import sleep
from typing import TYPE_CHECKING, Any, NoReturn

from packaging.version import Version

from airflow.configuration import conf
from airflow.exceptions import AirflowSkipException
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
Expand All @@ -32,6 +34,12 @@
from airflow.utils.context import Context


def _get_airflow_version():
from airflow import __version__ as airflow_version

return Version(Version(airflow_version).base_version)


class TimeDeltaSensor(BaseSensorOperator):
"""
Waits for a timedelta after the run's data interval.
Expand Down Expand Up @@ -91,7 +99,18 @@ def execute(self, context: Context) -> bool | NoReturn:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
raise

self.defer(trigger=trigger, method_name="execute_complete")
# todo: remove backcompat when min airflow version greater than 2.11
timeout: int | float | timedelta
if _get_airflow_version() >= Version("2.11.0"):
timeout = self.timeout
else:
timeout = timedelta(seconds=self.timeout)

self.defer(
trigger=trigger,
method_name="execute_complete",
timeout=timeout,
)

def execute_complete(self, context: Context, event: Any = None) -> None:
"""Handle the event when the trigger fires and return immediately."""
Expand Down
12 changes: 11 additions & 1 deletion tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import pytest

from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferralTimeout
from airflow.lineage.entities import File
from airflow.models.baseoperator import (
BaseOperator,
Expand All @@ -40,6 +40,7 @@
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.trigger import TriggerFailureReason
from airflow.providers.common.sql.operators import sql
from airflow.utils.edgemodifier import Label
from airflow.utils.task_group import TaskGroup
Expand Down Expand Up @@ -582,6 +583,15 @@ def test_logging_propogated_by_default(self, caplog):
# leaking a lot of state)
assert caplog.messages == ["test"]

def test_resume_execution(self):
op = BaseOperator(task_id="hi")
with pytest.raises(TaskDeferralTimeout):
op.resume_execution(
next_method="__fail__",
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
context={},
)


def test_deepcopy():
# Test bug when copying an operator attached to a DAG
Expand Down
10 changes: 10 additions & 0 deletions tests/sensors/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from airflow.executors.local_executor import LocalExecutor
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.models import TaskInstance, TaskReschedule
from airflow.models.trigger import TriggerFailureReason
from airflow.models.xcom import XCom
from airflow.operators.empty import EmptyOperator
from airflow.providers.celery.executors.celery_executor import CeleryExecutor
Expand Down Expand Up @@ -1061,6 +1062,15 @@ def test_prepare_for_execution(self, executor_cls_mode):
task = sensor.prepare_for_execution()
assert task.mode == mode

def test_resume_execution(self):
op = BaseSensorOperator(task_id="hi")
with pytest.raises(AirflowSensorTimeout):
op.resume_execution(
next_method="__fail__",
next_kwargs={"error": TriggerFailureReason.TRIGGER_TIMEOUT},
context={},
)


@poke_mode_only
class DummyPokeOnlySensor(BaseSensorOperator):
Expand Down