Respect "soft_fail" argument when "poke" is called (#33401)
* feat(sensors/base): raise AirflowSkipException if soft_fail is set to True and an exception occurs while running poke()

* test(sensors/base): add a test case for respecting the soft_fail option when other kinds of exceptions are raised

(cherry picked from commit d91c481)
Lee-W authored and ephraimbuddy committed Aug 28, 2023
1 parent 6996d30 commit dfe129f
Showing 2 changed files with 31 additions and 6 deletions.
7 changes: 6 additions & 1 deletion airflow/sensors/base.py
@@ -240,14 +240,19 @@ def run_duration() -> float:
             except (
                 AirflowSensorTimeout,
                 AirflowTaskTimeout,
-                AirflowSkipException,
                 AirflowFailException,
             ) as e:
+                if self.soft_fail:
+                    raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
                 raise e
+            except AirflowSkipException as e:
+                raise e
             except Exception as e:
                 if self.silent_fail:
                     logging.error("Sensor poke failed: \n %s", traceback.format_exc())
                     poke_return = False
+                elif self.soft_fail:
+                    raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
                 else:
                     raise e
 
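For orientation, here is a minimal sketch (not part of this commit) of what the new behaviour means for a DAG author: with soft_fail=True, a sensor whose poke() raises AirflowFailException is now marked SKIPPED instead of FAILED, while AirflowSkipException raised from poke() is still re-raised unchanged, as the dedicated except branch above shows. The sensor class, DAG id, and dates below are hypothetical.

# Hypothetical example of the behaviour added by this commit; names are made up.
import pendulum

from airflow import DAG
from airflow.exceptions import AirflowFailException
from airflow.sensors.base import BaseSensorOperator


class ExplodingSensor(BaseSensorOperator):
    """Sensor whose poke() always raises, to demonstrate soft_fail handling."""

    def poke(self, context):
        raise AirflowFailException("upstream system rejected the request")


with DAG(
    dag_id="soft_fail_demo",
    start_date=pendulum.datetime(2023, 8, 1, tz="UTC"),
    schedule=None,
) as dag:
    # Before this change, the AirflowFailException propagated and failed the task
    # even with soft_fail=True; with it, the task instance ends up SKIPPED.
    wait = ExplodingSensor(task_id="wait_for_thing", soft_fail=True, timeout=30)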
30 changes: 25 additions & 5 deletions tests/sensors/test_base.py
@@ -25,9 +25,11 @@

 from airflow.exceptions import (
     AirflowException,
+    AirflowFailException,
     AirflowRescheduleException,
     AirflowSensorTimeout,
     AirflowSkipException,
+    AirflowTaskTimeout,
 )
 from airflow.executors.debug_executor import DebugExecutor
 from airflow.executors.executor_constants import (
@@ -48,9 +50,7 @@
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
 from airflow.providers.celery.executors.celery_kubernetes_executor import CeleryKubernetesExecutor
 from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor
-from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor import (
-    LocalKubernetesExecutor,
-)
+from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor import LocalKubernetesExecutor
 from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
 from airflow.utils import timezone
@@ -176,6 +176,28 @@ def test_soft_fail(self, make_sensor):
             if ti.task_id == DUMMY_OP:
                 assert ti.state == State.NONE
 
+    @pytest.mark.parametrize(
+        "exception_cls",
+        (
+            AirflowSensorTimeout,
+            AirflowTaskTimeout,
+            AirflowFailException,
+            Exception,
+        ),
+    )
+    def test_soft_fail_with_non_skip_exception(self, make_sensor, exception_cls):
+        sensor, dr = make_sensor(False, soft_fail=True)
+        sensor.poke = Mock(side_effect=[exception_cls(None)])
+
+        self._run(sensor)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                assert ti.state == State.SKIPPED
+            if ti.task_id == DUMMY_OP:
+                assert ti.state == State.NONE
+
     def test_soft_fail_with_retries(self, make_sensor):
         sensor, dr = make_sensor(
             return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1)
@@ -518,7 +540,6 @@ def run_duration():
         assert sensor._get_next_poke_interval(started_at, run_duration, 2) == sensor.poke_interval
 
     def test_sensor_with_exponential_backoff_on(self):
-
         sensor = DummySensor(
             task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60, exponential_backoff=True
         )
@@ -575,7 +596,6 @@ def run_duration():
         assert intervals[0] == intervals[-1]
 
     def test_sensor_with_exponential_backoff_on_and_max_wait(self):
-
         sensor = DummySensor(
             task_id=SENSOR_OP,
             return_value=None,
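As a complement to the parametrized test added above, which asserts the final task-instance state, here is a hedged unit-level sketch (not part of this commit) that asserts the exception type directly. It assumes default poke mode and the 2.7-era execute() loop, in which the context is only forwarded to poke() before the first exception can occur; the class and test names are hypothetical.

# Hypothetical unit-level check: with soft_fail=True, a non-skip exception raised
# inside poke() should surface as AirflowSkipException.
import pytest

from airflow.exceptions import AirflowFailException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator


class FailingSensor(BaseSensorOperator):
    def poke(self, context):
        raise AirflowFailException("boom")


def test_soft_fail_converts_non_skip_exception_to_skip():
    sensor = FailingSensor(task_id="failing_sensor", soft_fail=True, timeout=10)
    # An empty context dict is enough here because poke() raises before the
    # context is otherwise used (assumption based on the current execute() loop).
    with pytest.raises(AirflowSkipException):
        sensor.execute(context={})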
