From bd7dd9bd8b2e0ccf988812ed63dbfca94dc9a410 Mon Sep 17 00:00:00 2001
From: Jamie Diprose <5715104+jdddog@users.noreply.github.com>
Date: Fri, 16 Aug 2024 12:20:29 +1200
Subject: [PATCH] Make PreviousDagRunSensor more robust (#663)

---
 .../observatory/platform/airflow.py | 38 ++++++++++++++++---
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/observatory-platform/observatory/platform/airflow.py b/observatory-platform/observatory/platform/airflow.py
index 6082b8c53..609ed77f1 100644
--- a/observatory-platform/observatory/platform/airflow.py
+++ b/observatory-platform/observatory/platform/airflow.py
@@ -21,26 +21,26 @@
 import textwrap
 import traceback
 from datetime import timedelta
+from functools import partial
 from pydoc import locate
-from typing import List, Union
-from typing import Optional
+from typing import List, Optional, Union
 
 import pendulum
 import six
 import validators
 from airflow import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import TaskInstance, DagBag, Variable, XCom, DagRun
+from airflow.models import DagBag, DagRun, TaskInstance, Variable, XCom
 from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
 from airflow.sensors.external_task import ExternalTaskSensor
 from airflow.utils.db import provide_session
 from airflow.utils.state import State
 from dateutil.relativedelta import relativedelta
 from sqlalchemy import and_
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import scoped_session, Session
 
 from observatory.platform.config import AirflowConns, AirflowVars
-from observatory.platform.observatory_config import Workflow, json_string_to_workflows
+from observatory.platform.observatory_config import json_string_to_workflows, Workflow
 
 ScheduleInterval = Union[str, timedelta, relativedelta]
 
@@ -415,6 +415,7 @@ def __init__(
             external_dag_id=dag_id,
             external_task_id=external_task_id,
             allowed_states=allowed_states,
+            execution_date_fn=partial(get_previous_dag_runs, dag_id),
             *args,
             **kwargs,
         )
@@ -432,3 +433,30 @@ def poke(self, context, session=None):
             return True
 
         return super().poke(context, session=session)
+
+
+@provide_session
+def get_previous_dag_runs(
+    dag_id: str, logical_date: pendulum.DateTime, session: scoped_session = None, **context
+) -> List[pendulum.DateTime]:
+    """Get previous logical dates for a given DAG.
+
+    :param dag_id: the DAG ID of the DAG we are waiting for.
+    :param logical_date: the logical date of the waiting DAG.
+    :param session: the SQLAlchemy session.
+    :param context: the Airflow context.
+    :return: all logical dates of previous DAG runs.
+    """
+
+    dag_runs = (
+        session.query(DagRun)
+        .filter(
+            DagRun.dag_id == dag_id,
+            DagRun.execution_date < logical_date,
+        )
+        .order_by(DagRun.execution_date.desc())
+        .all()
+    )
+    dates = [d.logical_date for d in dag_runs]
+
+    return dates
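
For reference, below is a minimal sketch of how the patched sensor might be wired into a DAG. It is hypothetical: the DAG ID, schedule and task names are invented, the schedule argument assumes Airflow 2.4+ style, and PreviousDagRunSensor is assumed to forward extra keyword arguments such as task_id on to ExternalTaskSensor via the **kwargs shown in the second hunk. The behaviour this patch adds is that execution_date_fn=partial(get_previous_dag_runs, dag_id) makes Airflow call get_previous_dag_runs(dag_id, logical_date, **context) on each poke, so the sensor waits on every logical date returned, i.e. on all earlier runs of the DAG rather than only the immediately preceding one.

# Minimal usage sketch (hypothetical DAG ID, schedule and task names).
# Assumes PreviousDagRunSensor forwards extra kwargs such as task_id to
# ExternalTaskSensor, per the **kwargs in its __init__ above, and that
# the Airflow 2.4+ "schedule" argument is available.
import pendulum
from airflow.decorators import dag, task

from observatory.platform.airflow import PreviousDagRunSensor


@dag(
    dag_id="my_workflow",
    schedule="@weekly",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    catchup=False,
)
def my_workflow():
    # With this patch applied, the sensor waits for all earlier runs of
    # "my_workflow" to reach an allowed state, not just the most recent
    # one, because execution_date_fn returns every previous logical date.
    wait_for_prev = PreviousDagRunSensor(
        dag_id="my_workflow",
        task_id="wait_for_prev_dag_run",
    )

    @task
    def process():
        pass

    wait_for_prev >> process()


my_workflow()

Because get_previous_dag_runs is decorated with @provide_session, Airflow injects the SQLAlchemy session automatically when the sensor invokes the callable, so callers never pass a session explicitly.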