Skip to content

Commit

Permalink
fixup! Make dag_version_id required for TI
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Feb 12, 2025
1 parent eb69cd9 commit d8aebe6
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 16 deletions.
22 changes: 17 additions & 5 deletions tests/jobs/test_local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
from airflow.jobs.local_task_job_runner import SIGSEGV_MESSAGE, LocalTaskJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models.dag import DAG
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DagBag
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
Expand Down Expand Up @@ -303,6 +305,8 @@ def test_heartbeat_failed_fast(self):
dag = self.dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="testing")
dr = dag.create_dagrun(
run_id="test_heartbeat_failed_fast_run",
run_type=DagRunType.MANUAL,
Expand Down Expand Up @@ -372,7 +376,8 @@ def test_localtaskjob_double_trigger(self):
dag = self.dagbag.dags.get("test_localtaskjob_double_trigger")
task = dag.get_task("test_localtaskjob_double_trigger_task")
data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)

dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="testing")
session = settings.Session()

dag.clear()
Expand Down Expand Up @@ -414,8 +419,9 @@ def test_localtaskjob_double_trigger(self):
def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, create_dummy_dag):
dag, task = create_dummy_dag("test_localtaskjob_code")
dag_run = dag.get_last_dagrun()

ti_run = TaskInstance(task=task, run_id=dag_run.run_id)
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
ti_run = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id)
ti_run.refresh_from_db()
job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor())
job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run)
Expand Down Expand Up @@ -445,8 +451,10 @@ def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code,
def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create_dummy_dag):
dag, task = create_dummy_dag("test_localtaskjob_double_trigger")
dag_run = dag.get_last_dagrun()
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version

ti_run = TaskInstance(task=task, run_id=dag_run.run_id)
ti_run = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id)
ti_run.refresh_from_db()
job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor())
job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run)
Expand Down Expand Up @@ -1011,6 +1019,8 @@ def task_function():
logger.addHandler(tmpfile_handler)

data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="testing")
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
state=State.RUNNING,
Expand All @@ -1020,7 +1030,9 @@ def task_function():
run_after=DEFAULT_LOGICAL_DATE,
triggered_by=DagRunTriggeredByType.TEST,
)
ti = TaskInstance(task=task, run_id=dag_run.run_id)
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
ti = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id)
ti.refresh_from_db()
job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id)
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True)
Expand Down
9 changes: 3 additions & 6 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk.definitions.asset import Asset
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.timetables.base import DataInterval
from airflow.utils import timezone
from airflow.utils.session import create_session, provide_session
Expand Down Expand Up @@ -520,9 +519,8 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker)
dag_id = "SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute"
task_id_1 = "dummy_task"

with dag_maker(dag_id=dag_id, session=session) as dag:
with dag_maker(dag_id=dag_id, session=session):
EmptyOperator(task_id=task_id_1)
assert isinstance(dag, SerializedDAG)

scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
Expand Down Expand Up @@ -5542,10 +5540,10 @@ def test_zombie_message(self, session, create_dagrun):

# We will provision 2 tasks so we can check we only find zombies from this scheduler
tasks_to_setup = ["branching", "run_this_first"]

dag_version = DagVersion.get_latest_version(dag.dag_id)
for task_id in tasks_to_setup:
task = dag.get_task(task_id=task_id)
ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING)
ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING, dag_version_id=dag_version.id)
ti.queued_by_job_id = 999

session.add(ti)
Expand Down Expand Up @@ -5875,7 +5873,6 @@ def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle):
session.flush()

dag.catchup = False
DAG.bulk_write_to_db("testing", None, [dag])
assert not dag.catchup

dm = DagModel.get_dagmodel(dag.dag_id)
Expand Down
26 changes: 21 additions & 5 deletions tests/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from airflow.models import DagModel, DagRun, TaskInstance, Trigger
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dag_version import DagVersion
from airflow.models.serialized_dag import SerializedDagModel
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
Expand Down Expand Up @@ -108,9 +110,11 @@ def create_trigger_in_db(session, trigger, operator=None):
operator.dag = dag
else:
operator = BaseOperator(task_id="test_ti", dag=dag)
task_instance = TaskInstance(operator, run_id=run.run_id)
task_instance.trigger_id = trigger_orm.id
session.add(dag_model)
SerializedDagModel.write_dag(dag, bundle_name="testing")
dag_version = DagVersion.get_latest_version(dag.dag_id)
task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id)
task_instance.trigger_id = trigger_orm.id
session.add(run)
session.add(trigger_orm)
session.add(task_instance)
Expand Down Expand Up @@ -341,15 +345,19 @@ async def test_trigger_create_race_condition_38599(session, tmp_path):
session.add(trigger_orm)

dag = DagModel(dag_id="test-dag")
session.add(dag)
dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none")
SerializedDagModel.write_dag(DAG(dag_id=dag.dag_id), bundle_name="testing")
dag_version = DagVersion.get_latest_version(dag.dag_id)
ti = TaskInstance(
PythonOperator(task_id="dummy-task", python_callable=print),
run_id=dag_run.run_id,
state=TaskInstanceState.DEFERRED,
dag_version_id=dag_version.id,
)
ti.dag_id = dag.dag_id
ti.trigger_id = 1
session.add(dag)

session.add(dag_run)
session.add(ti)

Expand Down Expand Up @@ -473,11 +481,19 @@ def handle_events(self):
session.add(trigger_orm)

dag = DagModel(dag_id="test-dag")
session.add(dag)
SerializedDagModel.write_dag(DAG(dag_id=dag.dag_id), bundle_name="testing")
dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none")
ti = TaskInstance(PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id)
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
ti = TaskInstance(
PythonOperator(task_id="dummy-task", python_callable=print),
run_id=dag_run.run_id,
dag_version_id=dag_version.id,
)
ti.dag_id = dag.dag_id
ti.trigger_id = 1
session.add(dag)

session.add(dag_run)
session.add(ti)

Expand Down
49 changes: 49 additions & 0 deletions tests/models/test_cleartasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,55 @@ def test_clear_task_instances_without_task(self, dag_maker):
assert ti1.try_number == 1
assert ti1.max_tries == 2

@pytest.mark.xfail(
    reason="Remove this and update clear_task_instances as tasks cannot exist without DAG"
)
def test_clear_task_instances_without_dag(self, dag_maker):
    """Verify clearing behavior when the DAG is absent from the database.

    The DAG is deliberately never serialized/written to the DB, so
    clear_task_instances() cannot look it up. In that case max_tries for
    each cleared TI falls back to max(original max_tries, try_number).
    """
    # Don't write DAG to the database, so no DAG is found by clear_task_instances().
    with dag_maker(
        "test_clear_task_instances_without_dag",
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE + datetime.timedelta(days=10),
    ) as dag:
        task0 = EmptyOperator(task_id="task0")
        task1 = EmptyOperator(task_id="task1", retries=2)

    dr = dag_maker.create_dagrun(
        state=State.RUNNING,
        run_type=DagRunType.SCHEDULED,
    )

    # task_id sort gives (task0, task1) deterministically for this DAG.
    ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
    ti0.refresh_from_task(task0)
    ti1.refresh_from_task(task1)

    with create_session() as session:
        # do the incrementing of try_number ordinarily handled by scheduler
        ti0.try_number += 1
        ti1.try_number += 1
        session.merge(ti0)
        session.merge(ti1)
        session.commit()

    ti0.run()
    ti1.run()

    with create_session() as session:
        # we use order_by(task_id) here because for the test DAG structure of ours
        # this is equivalent to topological sort. It would not work in general case
        # but it works for our case because we specifically constructed test DAGS
        # in the way that those two sort methods are equivalent
        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
        clear_task_instances(qry, session)

    # When no DAG is found, max_tries will be maximum of original max_tries or try_number.
    ti0.refresh_from_db()
    ti1.refresh_from_db()
    assert ti0.try_number == 1
    assert ti0.max_tries == 1  # task0 had retries=0, so max_tries == try_number
    assert ti1.try_number == 1
    assert ti1.max_tries == 2  # task1's original retries=2 exceeds try_number

def test_clear_task_instances_without_dag_param(self, dag_maker, session):
with dag_maker(
"test_clear_task_instances_without_dag_param",
Expand Down

0 comments on commit d8aebe6

Please sign in to comment.