Skip to content

Commit

Permalink
fix(sensors/external_task): respect soft_fail argument when ExternalT…
Browse files Browse the repository at this point in the history
…askSensor runs in deferrable mode
  • Loading branch information
Lee-W committed Aug 14, 2023
1 parent f36ba1d commit d9c385f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
6 changes: 6 additions & 0 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ def reschedule(self):
def get_serialized_fields(cls):
return super().get_serialized_fields() | {"reschedule"}

def raise_failed_or_skiping_exception(self, failed_message: str, skipping_message: str = "") -> None:
"""Raise AirflowSkipException if self.soft_fail is set to True. Otherwise raise AirflowException."""
if self.soft_fail:
raise AirflowSkipException(skipping_message or failed_message)
raise AirflowException(failed_message)


def poke_mode_only(cls):
"""
Expand Down
46 changes: 26 additions & 20 deletions airflow/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval

self._skipping_message_postfix = " Skipping due to soft_fail."

def _get_dttm_filter(self, context):
if self.execution_delta:
dttm = context["logical_date"] - self.execution_delta
Expand Down Expand Up @@ -274,32 +276,28 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool:
# Fail if anything in the list has failed.
if count_failed > 0:
if self.external_task_ids:
if self.soft_fail:
raise AirflowSkipException(
f"Some of the external tasks {self.external_task_ids} "
f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail."
)
raise AirflowException(
failed_message = (
f"Some of the external tasks {self.external_task_ids} "
f"in DAG {self.external_dag_id} failed."
)

self.raise_failed_or_skiping_exception(
failed_message=failed_message,
skipping_message=f"{failed_message}{self._skipping_message_postfix}",
)
elif self.external_task_group_id:
if self.soft_fail:
raise AirflowSkipException(
self.raise_failed_or_skiping_exception(
failed_message=(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail."
f"in DAG '{self.external_dag_id}' failed."
)
raise AirflowException(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed."
)

else:
if self.soft_fail:
raise AirflowSkipException(
f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail."
)
raise AirflowException(f"The external DAG {self.external_dag_id} failed.")
failed_message = f"The external DAG {self.external_dag_id} failed."
self.raise_failed_or_skiping_exception(
failed_message=failed_message,
skipping_message=f"{failed_message}{self._skipping_message_postfix}",
)

count_skipped = -1
if self.skipped_states:
Expand Down Expand Up @@ -354,12 +352,20 @@ def execute_complete(self, context, event=None):
self.log.info("External task %s has executed successfully.", self.external_task_id)
return None
elif event["status"] == "timeout":
raise AirflowException("Dag was not started within 1 minute, assuming fail.")
failed_message = "Dag was not started within 1 minute, assuming fail."
self.raise_failed_or_skiping_exception(
failed_message=failed_message,
skipping_message=f"{failed_message}{self._skipping_message_postfix}",
)
else:
raise AirflowException(
failed_message = (
"Error occurred while trying to retrieve task status. Please, check the "
"name of executed task and Dag."
)
self.raise_failed_or_skiping_exception(
failed_message=failed_message,
skipping_message=f"{failed_message}{self._skipping_message_postfix}",
)

def _check_for_existence(self, session) -> None:
dag_to_wait = DagModel.get_current(self.external_dag_id, session)
Expand Down

0 comments on commit d9c385f

Please sign in to comment.