Make PreviousDagRunSensor more robust (#663)
jdddog authored Aug 16, 2024
1 parent 830079e commit bd7dd9b
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions observatory-platform/observatory/platform/airflow.py
@@ -21,26 +21,26 @@
 import textwrap
 import traceback
 from datetime import timedelta
+from functools import partial
 from pydoc import locate
-from typing import List, Union
-from typing import Optional
+from typing import List, Optional, Union
 
 import pendulum
 import six
 import validators
 from airflow import AirflowException
 from airflow.hooks.base import BaseHook
-from airflow.models import TaskInstance, DagBag, Variable, XCom, DagRun
+from airflow.models import DagBag, DagRun, TaskInstance, Variable, XCom
 from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
 from airflow.sensors.external_task import ExternalTaskSensor
 from airflow.utils.db import provide_session
 from airflow.utils.state import State
 from dateutil.relativedelta import relativedelta
 from sqlalchemy import and_
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import scoped_session, Session
 
 from observatory.platform.config import AirflowConns, AirflowVars
-from observatory.platform.observatory_config import Workflow, json_string_to_workflows
+from observatory.platform.observatory_config import json_string_to_workflows, Workflow
 
 ScheduleInterval = Union[str, timedelta, relativedelta]
 
@@ -415,6 +415,7 @@ def __init__(
             external_dag_id=dag_id,
             external_task_id=external_task_id,
             allowed_states=allowed_states,
+            execution_date_fn=partial(get_previous_dag_runs, dag_id),
             *args,
             **kwargs,
         )
@@ -432,3 +433,30 @@ def poke(self, context, session=None):
             return True
 
         return super().poke(context, session=session)
+
+
+@provide_session
+def get_previous_dag_runs(
+    dag_id: str, logical_date: pendulum.DateTime, session: scoped_session = None, **context
+) -> List[pendulum.DateTime]:
+    """Get previous logical dates for a given DAG.
+
+    :param dag_id: the DAG ID of the DAG we are waiting for.
+    :param logical_date: the logical date of the waiting DAG.
+    :param session: the SQL Alchemy session.
+    :param context: the Airflow context.
+    :return: all logical dates of previous DAG runs.
+    """
+
+    dag_runs = (
+        session.query(DagRun)
+        .filter(
+            DagRun.dag_id == dag_id,
+            DagRun.execution_date < logical_date,
+        )
+        .order_by(DagRun.execution_date.desc())
+        .all()
+    )
+    dates = [d.logical_date for d in dag_runs]
+
+    return dates
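
For context, the following is a minimal, hypothetical sketch (not part of this commit) of how PreviousDagRunSensor might be wired into a DAG after this change. The DAG id "my_dag", the task ids, the "@weekly" schedule, and the assumption of Airflow 2.4+ with keyword arguments such as task_id forwarded to ExternalTaskSensor via **kwargs are all illustrative; the only behaviour taken from the diff above is that the sensor passes execution_date_fn=partial(get_previous_dag_runs, dag_id), so it pokes every earlier logical date of the same DAG rather than a single one.

# Illustrative sketch only, not from the commit. Assumes Airflow 2.4+ and that
# PreviousDagRunSensor forwards extra keyword arguments (e.g. task_id) to
# ExternalTaskSensor, as the __init__ call in the diff suggests.
import pendulum
from airflow.decorators import dag, task
from airflow.operators.empty import EmptyOperator

from observatory.platform.airflow import PreviousDagRunSensor


@dag(
    dag_id="my_dag",  # hypothetical DAG id
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule="@weekly",
    catchup=False,
)
def my_dag():
    # Wait for earlier runs of "my_dag": execution_date_fn returns every previous
    # logical date, so the sensor checks each of those runs, not just one.
    wait_for_previous = PreviousDagRunSensor(
        dag_id="my_dag",
        external_task_id="dag_run_complete",
        task_id="wait_for_previous_dag_run",
    )

    @task
    def process():
        pass  # the actual work of the DAG would go here

    # Marker task that the sensors of later runs wait on.
    dag_run_complete = EmptyOperator(task_id="dag_run_complete")

    wait_for_previous >> process() >> dag_run_complete


my_dag()

Because get_previous_dag_runs returns all earlier logical dates, the sensor presumably only lets a run proceed once every previous run of the DAG has reached an allowed state, which is what makes it more robust than waiting on just the immediately preceding run.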
