From d9c385f9ca0d3d0d586a0faea5575d344903d8df Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Tue, 8 Aug 2023 17:22:29 +0800
Subject: [PATCH 1/2] fix(sensors/external_task): respect soft_fail argument when ExternalTaskSensor runs in deferrable mode

---
 airflow/sensors/base.py          |  6 +++++
 airflow/sensors/external_task.py | 46 ++++++++++++++++++--------------
 2 files changed, 32 insertions(+), 20 deletions(-)

diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index e217727785a50..6d0abfd999e8f 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -330,6 +330,12 @@ def reschedule(self):
     def get_serialized_fields(cls):
         return super().get_serialized_fields() | {"reschedule"}
 
+    def raise_failed_or_skiping_exception(self, failed_message: str, skipping_message: str = "") -> None:
+        """Raise AirflowSkipException if self.soft_fail is set to True. Otherwise raise AirflowException."""
+        if self.soft_fail:
+            raise AirflowSkipException(skipping_message or failed_message)
+        raise AirflowException(failed_message)
+
 
 def poke_mode_only(cls):
     """
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index 5e42820ffe997..9e48d3e1406ea 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -222,6 +222,8 @@ def __init__(
         self.deferrable = deferrable
         self.poll_interval = poll_interval
 
+        self._skipping_message_postfix = " Skipping due to soft_fail."
+
     def _get_dttm_filter(self, context):
         if self.execution_delta:
             dttm = context["logical_date"] - self.execution_delta
@@ -274,32 +276,28 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool:
         # Fail if anything in the list has failed.
         if count_failed > 0:
             if self.external_task_ids:
-                if self.soft_fail:
-                    raise AirflowSkipException(
-                        f"Some of the external tasks {self.external_task_ids} "
-                        f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail."
-                    )
-                raise AirflowException(
+                failed_message = (
                     f"Some of the external tasks {self.external_task_ids} "
                     f"in DAG {self.external_dag_id} failed."
                 )
+
+                self.raise_failed_or_skiping_exception(
+                    failed_message=failed_message,
+                    skipping_message=f"{failed_message}{self._skipping_message_postfix}",
+                )
             elif self.external_task_group_id:
-                if self.soft_fail:
-                    raise AirflowSkipException(
+                self.raise_failed_or_skiping_exception(
+                    failed_message=(
                         f"The external task_group '{self.external_task_group_id}' "
-                        f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail."
+                        f"in DAG '{self.external_dag_id}' failed."
                     )
-                raise AirflowException(
-                    f"The external task_group '{self.external_task_group_id}' "
-                    f"in DAG '{self.external_dag_id}' failed."
                 )
-
             else:
-                if self.soft_fail:
-                    raise AirflowSkipException(
-                        f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail."
-                    )
-                raise AirflowException(f"The external DAG {self.external_dag_id} failed.")
+                failed_message = f"The external DAG {self.external_dag_id} failed."
+                self.raise_failed_or_skiping_exception(
+                    failed_message=failed_message,
+                    skipping_message=f"{failed_message}{self._skipping_message_postfix}",
+                )
 
         count_skipped = -1
         if self.skipped_states:
@@ -354,12 +352,20 @@ def execute_complete(self, context, event=None):
             self.log.info("External task %s has executed successfully.", self.external_task_id)
             return None
         elif event["status"] == "timeout":
-            raise AirflowException("Dag was not started within 1 minute, assuming fail.")
+            failed_message = "Dag was not started within 1 minute, assuming fail."
+            self.raise_failed_or_skiping_exception(
+                failed_message=failed_message,
+                skipping_message=f"{failed_message}{self._skipping_message_postfix}",
+            )
         else:
-            raise AirflowException(
+            failed_message = (
                 "Error occurred while trying to retrieve task status. Please, check the "
                 "name of executed task and Dag."
             )
+            self.raise_failed_or_skiping_exception(
+                failed_message=failed_message,
+                skipping_message=f"{failed_message}{self._skipping_message_postfix}",
+            )
 
     def _check_for_existence(self, session) -> None:
         dag_to_wait = DagModel.get_current(self.external_dag_id, session)

From 1d46a5330247dbc37378acb362e2b6bca1331848 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Mon, 14 Aug 2023 12:05:46 +0800
Subject: [PATCH 2/2] feat(sensor/base): make raise_failed_or_skiping_exception arguments keyword-only arguments

---
 airflow/sensors/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index 6d0abfd999e8f..e23894a28de12 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -330,7 +330,7 @@ def reschedule(self):
     def get_serialized_fields(cls):
         return super().get_serialized_fields() | {"reschedule"}
 
-    def raise_failed_or_skiping_exception(self, failed_message: str, skipping_message: str = "") -> None:
+    def raise_failed_or_skiping_exception(self, *, failed_message: str, skipping_message: str = "") -> None:
         """Raise AirflowSkipException if self.soft_fail is set to True. Otherwise raise AirflowException."""
         if self.soft_fail:
             raise AirflowSkipException(skipping_message or failed_message)