From d8aebe678554f8814e10d2f286892d48f63c0af3 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 11 Feb 2025 20:45:06 +0100 Subject: [PATCH] fixup! Make dag_version_id required for TI --- tests/jobs/test_local_task_job.py | 22 ++++++++++---- tests/jobs/test_scheduler_job.py | 9 ++---- tests/jobs/test_triggerer_job.py | 26 ++++++++++++---- tests/models/test_cleartasks.py | 49 +++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 16 deletions(-) diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 4bc128f9fe3e52..9b06bc494d952e 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -39,7 +39,9 @@ from airflow.jobs.local_task_job_runner import SIGSEGV_MESSAGE, LocalTaskJobRunner from airflow.listeners.listener import get_listener_manager from airflow.models.dag import DAG +from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DagBag +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator @@ -303,6 +305,8 @@ def test_heartbeat_failed_fast(self): dag = self.dagbag.get_dag(dag_id) task = dag.get_task(task_id) data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dr = dag.create_dagrun( run_id="test_heartbeat_failed_fast_run", run_type=DagRunType.MANUAL, @@ -372,7 +376,8 @@ def test_localtaskjob_double_trigger(self): dag = self.dagbag.dags.get("test_localtaskjob_double_trigger") task = dag.get_task("test_localtaskjob_double_trigger_task") data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) - + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") session = settings.Session() dag.clear() @@ -414,8 +419,9 @@ def 
test_localtaskjob_double_trigger(self): def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, create_dummy_dag): dag, task = create_dummy_dag("test_localtaskjob_code") dag_run = dag.get_last_dagrun() - - ti_run = TaskInstance(task=task, run_id=dag_run.run_id) + dag_version = DagVersion.get_latest_version(dag.dag_id) + assert dag_version + ti_run = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) ti_run.refresh_from_db() job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor()) job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run) @@ -445,8 +451,10 @@ def test_local_task_return_code_metric(self, mock_stats_incr, mock_return_code, def test_localtaskjob_maintain_heart_rate(self, mock_return_code, caplog, create_dummy_dag): dag, task = create_dummy_dag("test_localtaskjob_double_trigger") dag_run = dag.get_last_dagrun() + dag_version = DagVersion.get_latest_version(dag.dag_id) + assert dag_version - ti_run = TaskInstance(task=task, run_id=dag_run.run_id) + ti_run = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) ti_run.refresh_from_db() job1 = Job(dag_id=ti_run.dag_id, executor=SequentialExecutor()) job_runner = LocalTaskJobRunner(job=job1, task_instance=ti_run) @@ -1011,6 +1019,8 @@ def task_function(): logger.addHandler(tmpfile_handler) data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) + dag.sync_to_db() + SerializedDagModel.write_dag(dag, bundle_name="testing") dag_run = dag.create_dagrun( run_type=DagRunType.MANUAL, state=State.RUNNING, @@ -1020,7 +1030,9 @@ def task_function(): run_after=DEFAULT_LOGICAL_DATE, triggered_by=DagRunTriggeredByType.TEST, ) - ti = TaskInstance(task=task, run_id=dag_run.run_id) + dag_version = DagVersion.get_latest_version(dag.dag_id) + assert dag_version + ti = TaskInstance(task=task, run_id=dag_run.run_id, dag_version_id=dag_version.id) ti.refresh_from_db() job = Job(executor=SequentialExecutor(), dag_id=ti.dag_id) 
job_runner = LocalTaskJobRunner(job=job, task_instance=ti, ignore_ti_state=True) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index d266f5ced21011..ee2656cc68c255 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -64,7 +64,6 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.definitions.asset import Asset -from airflow.serialization.serialized_objects import SerializedDAG from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.session import create_session, provide_session @@ -520,9 +519,8 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker) dag_id = "SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute" task_id_1 = "dummy_task" - with dag_maker(dag_id=dag_id, session=session) as dag: + with dag_maker(dag_id=dag_id, session=session): EmptyOperator(task_id=task_id_1) - assert isinstance(dag, SerializedDAG) scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job) @@ -5542,10 +5540,10 @@ def test_zombie_message(self, session, create_dagrun): # We will provision 2 tasks so we can check we only find zombies from this scheduler tasks_to_setup = ["branching", "run_this_first"] - + dag_version = DagVersion.get_latest_version(dag.dag_id) for task_id in tasks_to_setup: task = dag.get_task(task_id=task_id) - ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING) + ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING, dag_version_id=dag_version.id) ti.queued_by_job_id = 999 session.add(ti) @@ -5875,7 +5873,6 @@ def test_catchup_works_correctly(self, dag_maker, testing_dag_bundle): session.flush() dag.catchup = False - DAG.bulk_write_to_db("testing", None, [dag]) assert not dag.catchup dm = DagModel.get_dagmodel(dag.dag_id) diff --git 
a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py index 89f6947fe17adb..5764fca9e69748 100644 --- a/tests/jobs/test_triggerer_job.py +++ b/tests/jobs/test_triggerer_job.py @@ -35,6 +35,8 @@ from airflow.models import DagModel, DagRun, TaskInstance, Trigger from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG +from airflow.models.dag_version import DagVersion +from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger @@ -108,9 +110,11 @@ def create_trigger_in_db(session, trigger, operator=None): operator.dag = dag else: operator = BaseOperator(task_id="test_ti", dag=dag) - task_instance = TaskInstance(operator, run_id=run.run_id) - task_instance.trigger_id = trigger_orm.id session.add(dag_model) + SerializedDagModel.write_dag(dag, bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) + task_instance = TaskInstance(operator, run_id=run.run_id, dag_version_id=dag_version.id) + task_instance.trigger_id = trigger_orm.id session.add(run) session.add(trigger_orm) session.add(task_instance) @@ -341,15 +345,19 @@ async def test_trigger_create_race_condition_38599(session, tmp_path): session.add(trigger_orm) dag = DagModel(dag_id="test-dag") + session.add(dag) dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none") + SerializedDagModel.write_dag(DAG(dag_id=dag.dag_id), bundle_name="testing") + dag_version = DagVersion.get_latest_version(dag.dag_id) ti = TaskInstance( PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id, state=TaskInstanceState.DEFERRED, + dag_version_id=dag_version.id, ) ti.dag_id = dag.dag_id ti.trigger_id = 1 - session.add(dag) + session.add(dag_run) session.add(ti) @@ -473,11 +481,19 @@ def handle_events(self): 
session.add(trigger_orm) dag = DagModel(dag_id="test-dag") + session.add(dag) + SerializedDagModel.write_dag(DAG(dag_id=dag.dag_id), bundle_name="testing") dag_run = DagRun(dag.dag_id, run_id="abc", run_type="none") - ti = TaskInstance(PythonOperator(task_id="dummy-task", python_callable=print), run_id=dag_run.run_id) + dag_version = DagVersion.get_latest_version(dag.dag_id) + assert dag_version + ti = TaskInstance( + PythonOperator(task_id="dummy-task", python_callable=print), + run_id=dag_run.run_id, + dag_version_id=dag_version.id, + ) ti.dag_id = dag.dag_id ti.trigger_id = 1 - session.add(dag) + session.add(dag_run) session.add(ti) diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index 9f23cf29613b95..e1dd5b84c3bfeb 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -324,6 +324,55 @@ def test_clear_task_instances_without_task(self, dag_maker): assert ti1.try_number == 1 assert ti1.max_tries == 2 + @pytest.mark.xfail( + reason="Remove this and update clear_task_instances as tasks cannot exist without a DAG" + ) + def test_clear_task_instances_without_dag(self, dag_maker): + # Don't write DAG to the database, so no DAG is found by clear_task_instances(). 
+ with dag_maker( + "test_clear_task_instances_without_dag", + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) as dag: + task0 = EmptyOperator(task_id="task0") + task1 = EmptyOperator(task_id="task1", retries=2) + + dr = dag_maker.create_dagrun( + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + + ti0, ti1 = sorted(dr.task_instances, key=lambda ti: ti.task_id) + ti0.refresh_from_task(task0) + ti1.refresh_from_task(task1) + + with create_session() as session: + # do the incrementing of try_number ordinarily handled by scheduler + ti0.try_number += 1 + ti1.try_number += 1 + session.merge(ti0) + session.merge(ti1) + session.commit() + + ti0.run() + ti1.run() + + with create_session() as session: + # we use order_by(task_id) here because for the test DAG structure of ours + # this is equivalent to topological sort. It would not work in the general case + # but it works for our case because we specifically constructed the test DAGs + # in such a way that the two sort orders are equivalent + qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all() + clear_task_instances(qry, session) + + # When no DAG is found, max_tries will be maximum of original max_tries or try_number. + ti0.refresh_from_db() + ti1.refresh_from_db() + assert ti0.try_number == 1 + assert ti0.max_tries == 1 + assert ti1.try_number == 1 + assert ti1.max_tries == 2 + def test_clear_task_instances_without_dag_param(self, dag_maker, session): with dag_maker( "test_clear_task_instances_without_dag_param",