Skip to content

Commit

Permalink
Add functions required to make use of Airflow TaskGroups and test them
Browse files Browse the repository at this point in the history
  • Loading branch information
jdddog authored Oct 3, 2023
1 parent b94c968 commit 9a6ad37
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import socket
import socketserver
import threading
import time
import unittest
import uuid
from dataclasses import dataclass
Expand All @@ -82,15 +83,14 @@
import paramiko
import pendulum
import requests
import time
from airflow import DAG, settings
from airflow.exceptions import AirflowException
from airflow.models import DagBag
from airflow.models.connection import Connection
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.variable import Variable
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.empty import EmptyOperator
from airflow.utils import db
from airflow.utils.state import State
from airflow.utils.types import DagRunType
Expand All @@ -103,10 +103,10 @@
from google.cloud import bigquery, storage
from google.cloud.exceptions import NotFound
from pendulum import DateTime
from sftpserver.stub_sftp import StubServer, StubSFTPServer
from pyftpdlib.authorizers import DummyAuthorizer
from pyftpdlib.handlers import FTPHandler
from pyftpdlib.servers import ThreadedFTPServer
from sftpserver.stub_sftp import StubServer, StubSFTPServer

from observatory.api.testing import ObservatoryApiEnvironment
from observatory.platform.bigquery import bq_create_dataset
Expand Down Expand Up @@ -442,6 +442,21 @@ def run_task(self, task_id: str) -> TaskInstance:

return ti

def get_task_instance(self, task_id: str) -> TaskInstance:
    """Get an up-to-date TaskInstance for a task in the current DAG run.

    A new TaskInstance is built for the task, using the run id of the DAG run held in
    self.dag_run, and refresh_from_db() is called on it before it is returned.

    :param task_id: the task id.
    :return: up-to-date TaskInstance instance.
    :raises AssertionError: if self.dag_run is None.
    """

    # Raise explicitly instead of using an assert statement, so that the check is not stripped
    # when Python runs with -O. The exception type and message are the same as before.
    if self.dag_run is None:
        raise AssertionError("with create_dag_run must be called before get_task_instance")

    run_id = self.dag_run.run_id
    task = self.dag_run.dag.get_task(task_id=task_id)
    ti = TaskInstance(task, run_id=run_id)
    ti.refresh_from_db()
    return ti

@contextlib.contextmanager
def create_dag_run(
self,
Expand Down Expand Up @@ -1154,7 +1169,7 @@ def make_dummy_dag(dag_id: str, execution_date: pendulum.DateTime) -> DAG:
default_args={"owner": "airflow", "start_date": execution_date},
catchup=False,
) as dag:
task1 = DummyOperator(task_id="dummy_task")
task1 = EmptyOperator(task_id="dummy_task")

return dag

Expand Down
20 changes: 20 additions & 0 deletions observatory-platform/observatory/platform/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def __init__(
"start_date": self.start_date,
"on_failure_callback": on_failure_callback,
"retries": self.max_retries,
"queue": self.queue,
}
self.description = self.__doc__
self.dag = DAG(
Expand Down Expand Up @@ -516,6 +517,25 @@ def add_task(
op = PythonOperator(python_callable=partial(self.task_callable, func), **kwargs_)
self.add_operator(op)

def make_python_operator(
    self,
    func: Callable,
    task_id: str,
    **kwargs,
):
    """Build a PythonOperator that runs func through self.task_callable.

    The operator is only constructed and returned; this method does not call
    self.add_operator.

    :param func: the function that will be called by the PythonOperator task.
    :param task_id: the task id.
    :param kwargs: additional keyword arguments, forwarded to the PythonOperator constructor
    together with task_id.
    :return: PythonOperator instance.
    """

    # Bind func as the first argument of self.task_callable, which becomes the operator's callable.
    wrapped_callable = partial(self.task_callable, func)

    # Merge into a new dict so that the incoming kwargs mapping is left untouched.
    operator_kwargs = {**kwargs, "task_id": task_id}
    return PythonOperator(python_callable=wrapped_callable, **operator_kwargs)

@contextlib.contextmanager
def parallel_tasks(self):
"""When called, all tasks added to the workflow within the `with` block will run in parallel.
Expand Down

0 comments on commit 9a6ad37

Please sign in to comment.