fix test_external_task test
gopidesupavan committed Nov 22, 2024
1 parent 789b746 commit 97368da
Showing 2 changed files with 38 additions and 64 deletions.
94 changes: 36 additions & 58 deletions providers/tests/standard/triggers/test_external_task.py
@@ -31,24 +31,11 @@

from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS, AIRFLOW_V_3_0_PLUS


def _workflow_trigger(external_dag_id, dates, external_task_ids, allowed_states, poke_interval):
if AIRFLOW_V_3_0_PLUS:
return WorkflowTrigger(
external_dag_id=external_dag_id,
logical_dates=dates,
external_task_ids=external_task_ids,
allowed_states=allowed_states,
poke_interval=poke_interval,
)
else:
return WorkflowTrigger(
external_dag_id=external_dag_id,
execution_dates=dates,
external_task_ids=external_task_ids,
allowed_states=allowed_states,
poke_interval=poke_interval,
)
_DATES = (
{"logical_dates": [timezone.datetime(2022, 1, 1)]}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": [timezone.datetime(2022, 1, 1)]}
)


@pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Test requires Airflow 2.9+")
@@ -64,9 +51,10 @@ class TestWorkflowTrigger:
async def test_task_workflow_trigger_success(self, mock_get_count):
"""check the db count get called correctly."""
mock_get_count.side_effect = mocked_get_count
trigger = _workflow_trigger(

trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
dates=[timezone.datetime(2022, 1, 1)],
**_DATES,
external_task_ids=[self.TASK_ID],
allowed_states=self.STATES,
poke_interval=0.2,
@@ -97,9 +85,9 @@ async def test_task_workflow_trigger_success(self, mock_get_count):
async def test_task_workflow_trigger_failed(self, mock_get_count):
mock_get_count.side_effect = mocked_get_count

trigger = _workflow_trigger(
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
dates=[timezone.datetime(2022, 1, 1)],
**_DATES,
external_task_ids=[self.TASK_ID],
failed_states=self.STATES,
poke_interval=0.2,
@@ -129,11 +117,12 @@ async def test_task_workflow_trigger_failed(self, mock_get_count):
@pytest.mark.asyncio
async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count):
mock_get_count.return_value = 0
trigger = _workflow_trigger(

trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
dates=[timezone.datetime(2022, 1, 1)],
**_DATES,
external_task_ids=[self.TASK_ID],
allowed_states=self.STATES,
failed_states=self.STATES,
poke_interval=0.2,
)

@@ -160,9 +149,10 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count):
@pytest.mark.asyncio
async def test_task_workflow_trigger_skipped(self, mock_get_count):
mock_get_count.side_effect = mocked_get_count
trigger = _workflow_trigger(

trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
dates=[timezone.datetime(2022, 1, 1)],
**_DATES,
external_task_ids=[self.TASK_ID],
skipped_states=self.STATES,
poke_interval=0.2,
@@ -190,13 +180,14 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count):
@pytest.mark.asyncio
async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_count):
mock_get_count.side_effect = [0, 1]
trigger = _workflow_trigger(

trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
dates=[timezone.datetime(2022, 1, 1)],
**_DATES,
external_task_ids=[self.TASK_ID],
allowed_states=self.STATES,
poke_interval=0.2,
)

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
@@ -218,24 +209,18 @@ def test_serialization(self):
"""
Asserts that the WorkflowTrigger correctly serializes its arguments and classpath.
"""
trigger = _workflow_trigger(
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
dates=[timezone.datetime(2022, 1, 1)],
**_DATES,
external_task_ids=[self.TASK_ID],
allowed_states=self.STATES,
poke_interval=5,
)

classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.standard.triggers.external_task.WorkflowTrigger"
_dates = (
{"logical_dates": [timezone.datetime(2022, 1, 1)]}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": [timezone.datetime(2022, 1, 1)]}
)
assert kwargs == {
"external_dag_id": self.DAG_ID,
**_dates,
**_DATES,
"external_task_ids": [self.TASK_ID],
"external_task_group_id": None,
"failed_states": None,
@@ -260,24 +245,24 @@ async def test_dag_state_trigger(self, session):
reaches an allowed state (i.e. SUCCESS).
"""
dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1))
logical_date_or_execution_date = (
{"logical_date": timezone.datetime(2022, 1, 1)}
if AIRFLOW_V_3_0_PLUS
else {"execution_date": timezone.datetime(2022, 1, 1)}
)
dag_run = DagRun(
dag_id=dag.dag_id,
run_type="manual",
logical_date=timezone.datetime(2022, 1, 1),
**logical_date_or_execution_date,
run_id=self.RUN_ID,
)
session.add(dag_run)
session.commit()
_dates = (
{"logical_dates": [timezone.datetime(2022, 1, 1)]}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": [timezone.datetime(2022, 1, 1)]}
)

trigger = DagStateTrigger(
dag_id=dag.dag_id,
states=self.STATES,
**_dates,
**_DATES,
poll_interval=0.2,
)

@@ -298,28 +283,21 @@ async def test_dag_state_trigger(self, session):

def test_serialization(self):
"""Asserts that the DagStateTrigger correctly serializes its arguments and classpath."""
_dates = (
{"logical_dates": [timezone.datetime(2022, 1, 1)]}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": [timezone.datetime(2022, 1, 1)]}
)
trigger = DagStateTrigger(
dag_id=self.DAG_ID,
states=self.STATES,
**_dates,
**_DATES,
poll_interval=5,
)
expected_kwargs = {
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.standard.triggers.external_task.DagStateTrigger"
assert kwargs == {
"dag_id": self.DAG_ID,
"states": self.STATES,
**_dates,
**_DATES,
"poll_interval": 5,
}

classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.standard.triggers.external_task.DagStateTrigger"
assert kwargs == expected_kwargs


def mocked_get_count(*args, **kwargs):
time.sleep(0.0001)
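Note on the pattern in the diff above: instead of routing every call through a _workflow_trigger wrapper, the tests now build the version-dependent date keyword once at module level (_DATES) and unpack it with ** at each call site. A minimal, self-contained sketch of the same idea, with a hypothetical make_trigger standing in for the real WorkflowTrigger and a plain boolean standing in for AIRFLOW_V_3_0_PLUS:

import datetime

# Hypothetical stand-ins for illustration only; the real tests import
# AIRFLOW_V_3_0_PLUS and WorkflowTrigger from the Airflow codebase.
AIRFLOW_V_3_0_PLUS = True

# Pick the right keyword name once, at import time.
_DATES = (
    {"logical_dates": [datetime.datetime(2022, 1, 1)]}
    if AIRFLOW_V_3_0_PLUS
    else {"execution_dates": [datetime.datetime(2022, 1, 1)]}
)


def make_trigger(**kwargs):
    # Placeholder constructor; simply records the keyword arguments it received.
    return kwargs


# Each call site unpacks the prebuilt mapping instead of branching on the version.
trigger = make_trigger(external_dag_id="external_dag", **_DATES, allowed_states=["success"])
print(sorted(trigger))  # ['allowed_states', 'external_dag_id', 'logical_dates']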
8 changes: 2 additions & 6 deletions providers/tests/standard/triggers/test_temporal.py
@@ -75,9 +75,7 @@ def test_timedelta_trigger_serialization():
assert -2 < (kwargs["moment"] - expected_moment).total_seconds() < 2


pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Only for Airflow 2.10+")


@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Only for Airflow 2.10+")
@pytest.mark.parametrize(
"tz, end_from_trigger",
[
@@ -117,9 +115,7 @@ async def test_datetime_trigger_timing_airflow_2_10_plus(tz, end_from_trigger):
assert result.payload == expected_payload


pytest.mark.skipif(AIRFLOW_V_2_10_PLUS, reason="Only for Airflow < 2.10+")


@pytest.mark.skipif(AIRFLOW_V_2_10_PLUS, reason="Only for Airflow < 2.10+")
@pytest.mark.parametrize(
"tz",
[
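The test_temporal.py hunks above also fix a subtle marker bug: a bare pytest.mark.skipif(...) statement builds a marker object but never attaches it to any test, so it has no effect on collection; only the decorator form @pytest.mark.skipif(...) actually applies the skip. A minimal sketch of the difference, assuming only that pytest is installed (feature_enabled is a made-up flag for illustration):

import pytest

feature_enabled = False  # hypothetical flag standing in for AIRFLOW_V_2_10_PLUS

pytest.mark.skipif(not feature_enabled, reason="no effect")  # marker created, then discarded


def test_not_actually_skipped():
    # Still collected and run, because the bare expression above marked nothing.
    assert True


@pytest.mark.skipif(not feature_enabled, reason="requires the feature")
def test_properly_skipped():
    # Reported as skipped whenever feature_enabled is False.
    assert feature_enabled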
