Skip to content

Commit

Permalink
Use async db calls in WorkflowTrigger (#38689)
Browse files Browse the repository at this point in the history
* Use async db calls in WorkflowTrigger

* address PR comments

* deprecate TaskStateTrigger with proper category
  • Loading branch information
stevenschaerer authored Apr 4, 2024
1 parent b80e17c commit e6eec0c
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 78 deletions.
46 changes: 25 additions & 21 deletions airflow/triggers/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
from typing import Any

from asgiref.sync import sync_to_async
from deprecated import deprecated
from sqlalchemy import func

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.sensor_helper import _get_count
Expand Down Expand Up @@ -98,44 +100,46 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
"""Check periodically tasks, task group or dag status."""
while True:
if self.failed_states:
failed_count = _get_count(
self.execution_dates,
self.external_task_ids,
self.external_task_group_id,
self.external_dag_id,
self.failed_states,
)
failed_count = await self._get_count(self.failed_states)
if failed_count > 0:
yield TriggerEvent({"status": "failed"})
return
else:
yield TriggerEvent({"status": "success"})
return
if self.skipped_states:
skipped_count = _get_count(
self.execution_dates,
self.external_task_ids,
self.external_task_group_id,
self.external_dag_id,
self.skipped_states,
)
skipped_count = await self._get_count(self.skipped_states)
if skipped_count > 0:
yield TriggerEvent({"status": "skipped"})
return
allowed_count = _get_count(
self.execution_dates,
self.external_task_ids,
self.external_task_group_id,
self.external_dag_id,
self.allowed_states,
)
allowed_count = await self._get_count(self.allowed_states)
if allowed_count == len(self.execution_dates):
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)

@sync_to_async
def _get_count(self, states: typing.Iterable[str] | None) -> int:
    """
    Count the records that match this trigger's filters and are in one of ``states``.

    Awaitable wrapper around the module-level ``_get_count`` helper: the
    ``sync_to_async`` decorator lets ``run`` ``await`` the call instead of
    invoking the synchronous helper directly.

    :param states: task or dag states
    :return: The count of records.
    """
    # Everything except ``states`` comes from the trigger's own configuration.
    trigger_filters = {
        "dttm_filter": self.execution_dates,
        "external_task_ids": self.external_task_ids,
        "external_task_group_id": self.external_task_group_id,
        "external_dag_id": self.external_dag_id,
    }
    # The name below resolves to the module-level helper, not to this method.
    return _get_count(**trigger_filters, states=states)


@deprecated(
reason="TaskStateTrigger has been deprecated and will be removed in future.",
category=RemovedInAirflow3Warning,
)
class TaskStateTrigger(BaseTrigger):
"""
Waits asynchronously for a task in a different DAG to complete for a specific logical date.
Expand Down
2 changes: 1 addition & 1 deletion contributing-docs/testing/unit_tests.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ For avoid this make sure:
.. code-block:: python
def test_deprecated_argument():
with pytest.warn(AirflowProviderDeprecationWarning, match="expected warning pattern"):
with pytest.warns(AirflowProviderDeprecationWarning, match="expected warning pattern"):
SomeDeprecatedClass(foo="bar", spam="egg")
Expand Down
153 changes: 97 additions & 56 deletions tests/triggers/test_external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

import asyncio
import datetime
import time
from unittest import mock

import pytest
from sqlalchemy.exc import SQLAlchemyError

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
Expand All @@ -41,11 +43,10 @@ class TestWorkflowTrigger:
STATES = ["success", "fail"]

@mock.patch("airflow.triggers.external_task._get_count")
@mock.patch("asyncio.sleep")
@pytest.mark.asyncio
async def test_task_workflow_trigger_success(self, mock_sleep, mock_get_count):
async def test_task_workflow_trigger_success(self, mock_get_count):
"""check the db count get called correctly."""
mock_get_count.return_value = 1
mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
Expand All @@ -54,19 +55,29 @@ async def test_task_workflow_trigger_success(self, mock_sleep, mock_get_count):
poke_interval=0.2,
)

generator = trigger.run()
await generator.asend(None)
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
assert fake_task.done() # confirm that get_count is done in an async fashion
assert trigger_task.done()
result = trigger_task.result()
assert result.payload == {"status": "success"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
await generator.__anext__()
await gen.__anext__()

@mock.patch("airflow.triggers.external_task._get_count")
@pytest.mark.asyncio
async def test_task_workflow_trigger_failed(self, mock_get_count):
mock_get_count.return_value = 1
mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
Expand All @@ -77,13 +88,19 @@ async def test_task_workflow_trigger_failed(self, mock_get_count):

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
assert trigger_task.done() is True
assert fake_task.done() # confirm that get_count is done in an async fashion
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "failed"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
Expand All @@ -104,12 +121,16 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count):
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
assert trigger_task.done() is True
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "success"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
Expand All @@ -118,7 +139,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count):
@mock.patch("airflow.triggers.external_task._get_count")
@pytest.mark.asyncio
async def test_task_workflow_trigger_skipped(self, mock_get_count):
mock_get_count.return_value = 1
mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
Expand All @@ -129,13 +150,19 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count):

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
assert trigger_task.done() is True
assert fake_task.done() # confirm that get_count is done in an async fashion
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "skipped"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)

@mock.patch("airflow.triggers.external_task._get_count")
Expand All @@ -153,7 +180,7 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_co
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
assert trigger_task.done() is True
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "success"}
Expand Down Expand Up @@ -222,14 +249,15 @@ async def test_task_state_trigger_success(self, session):
session.add(instance)
session.commit()

trigger = TaskStateTrigger(
dag_id=dag.dag_id,
task_id=instance.task_id,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)
with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
trigger = TaskStateTrigger(
dag_id=dag.dag_id,
task_id=instance.task_id,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)

task = asyncio.create_task(trigger.run().__anext__())
await asyncio.sleep(0.5)
Expand All @@ -252,14 +280,15 @@ async def test_task_state_trigger_timeout(self, mock_utcnow):
trigger_start_time = utcnow()
mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=61)

trigger = TaskStateTrigger(
dag_id="dag1",
task_id="task1",
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)
with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
trigger = TaskStateTrigger(
dag_id="dag1",
task_id="task1",
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)

trigger.count_running_dags = mock.AsyncMock()
trigger.count_running_dags.return_value = 0
Expand All @@ -284,14 +313,15 @@ async def test_task_state_trigger_timeout_sleep_success(self, mock_sleep, mock_u
trigger_start_time = utcnow()
mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20)

trigger = TaskStateTrigger(
dag_id="dag1",
task_id="task1",
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)
with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
trigger = TaskStateTrigger(
dag_id="dag1",
task_id="task1",
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)

trigger.count_running_dags = mock.AsyncMock()
trigger.count_running_dags.return_value = 0
Expand Down Expand Up @@ -331,14 +361,15 @@ async def test_task_state_trigger_failed_exception(self, mock_sleep, mock_utcnow
trigger_start_time + datetime.timedelta(seconds=20),
]

trigger = TaskStateTrigger(
dag_id="dag1",
task_id="task1",
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)
with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
trigger = TaskStateTrigger(
dag_id="dag1",
task_id="task1",
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
trigger_start_time=trigger_start_time,
)

trigger.count_running_dags = mock.AsyncMock()
trigger.count_running_dags.side_effect = [SQLAlchemyError]
Expand All @@ -358,14 +389,15 @@ def test_serialization(self):
and classpath.
"""
trigger_start_time = utcnow()
trigger = TaskStateTrigger(
dag_id=self.DAG_ID,
task_id=self.TASK_ID,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=5,
trigger_start_time=trigger_start_time,
)
with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
trigger = TaskStateTrigger(
dag_id=self.DAG_ID,
task_id=self.TASK_ID,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=5,
trigger_start_time=trigger_start_time,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
assert kwargs == {
Expand Down Expand Up @@ -438,3 +470,12 @@ def test_serialization(self):
"execution_dates": [timezone.datetime(2022, 1, 1)],
"poll_interval": 5,
}


def mocked_get_count(*args, **kwargs):
    """
    Stand-in for ``_get_count`` used as a mock ``side_effect``.

    Accepts and ignores any arguments, blocks the calling thread for a very
    short time, then reports a count of one.
    """
    blocking_delay_seconds = 0.0001
    time.sleep(blocking_delay_seconds)
    record_count = 1
    return record_count


async def fake_async_fun():
    """
    Suspend briefly on the event loop and finish without a result.

    The tests schedule this as a side task and assert that it is done once the
    trigger has produced its event, to confirm that the count lookup was
    performed in an async fashion.
    """
    pause_seconds = 0.00005
    await asyncio.sleep(pause_seconds)

0 comments on commit e6eec0c

Please sign in to comment.