Skip to content

Commit

Permalink
fixup! fixup! Make dag_version_id required for TI
Browse files: browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Feb 12, 2025
1 parent d8aebe6 commit 48376ac
Show file tree
Hide file tree
Showing 14 changed files with 182 additions and 48 deletions.
6 changes: 5 additions & 1 deletion airflow/cli/commands/remote_commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models import TaskInstance
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskReturnCode
from airflow.sdk.definitions.param import ParamsDict
Expand Down Expand Up @@ -209,7 +210,10 @@ def _get_ti(
f"run_id or logical_date of {logical_date_or_run_id!r} not found"
)
# TODO: Validate map_index is in range?
ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index)
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index, dag_version_id=dag_version.id)
if dag_run in session:
session.add(ti)
ti.dag_run = dag_run
Expand Down
9 changes: 7 additions & 2 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy_utils import UUIDType

from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DAG as SchedulerDAG
Expand Down Expand Up @@ -215,7 +216,9 @@ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc]
return link.get_link(self.unmap(None), ti_key=ti.key)

def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
def expand_mapped_task(
self, run_id: str, *, dag_version_id: UUIDType, session: Session
) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.
Expand Down Expand Up @@ -324,7 +327,9 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_objects.
ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
ti = TaskInstance(
self, run_id=run_id, map_index=index, state=state, dag_version_id=dag_version_id
)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
Expand Down
11 changes: 9 additions & 2 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
NotMapped,
)
from airflow.models.base import _sentinel
from airflow.models.dag_version import DagVersion
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
Expand Down Expand Up @@ -623,7 +624,10 @@ def run(
DagRun.logical_date == info.logical_date,
)
).one()
ti = TaskInstance(self, run_id=dag_run.run_id)
dag_version = DagVersion.get_latest_version(self.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(self, run_id=dag_run.run_id, dag_version_id=dag_version.id)
except NoResultFound:
# This is _mostly_ only used in tests
dr = DagRun(
Expand All @@ -640,7 +644,10 @@ def run(
triggered_by=DagRunTriggeredByType.TEST,
state=DagRunState.RUNNING,
)
ti = TaskInstance(self, run_id=dr.run_id)
dag_version = DagVersion.get_latest_version(self.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(self, run_id=dr.run_id, dag_version_id=dag_version.id)
ti.dag_run = dr
session.add(dr)
session.flush()
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _create_orm_dagrun(
if not dag_version:
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
if not dag_version:
raise AirflowException(f"Could not find a version for DAG {dag.dag_id}")
raise AirflowException(f"Could not find a version for DAG: {dag.dag_id}")
run.verify_integrity(dag_version_id=dag_version.id, session=session)
return run

Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,7 +1806,7 @@ def insert_mapping(run_id: str, task: Operator, map_index: int, dag_version_id:
:meta private:
"""
priority_weight = task.weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index)
TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
)

return {
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/taskmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy_utils import UUIDType

from airflow.models.dag import DAG as SchedulerDAG
from airflow.models.taskinstance import TaskInstance
Expand Down Expand Up @@ -122,7 +123,7 @@ def variant(self) -> TaskMapVariant:

@classmethod
def expand_mapped_task(
cls, dag_version_id, task, run_id: str, *, session: Session
cls, dag_version_id: UUIDType, task, run_id: str, *, session: Session
) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.
Expand Down
38 changes: 31 additions & 7 deletions tests/api_connexion/endpoints/test_task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import DagRun, TaskInstance, Trigger
from airflow.models.dag_version import DagVersion
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -100,6 +102,8 @@ def create_task_instances(
"""Method to create task instances using kwargs and default arguments"""

dag = self.dagbag.get_dag(dag_id)
dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="testing")
tasks = dag.tasks
counter = len(tasks)
if task_instances is not None:
Expand All @@ -110,6 +114,8 @@ def create_task_instances(
dr = None

tis = []
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
for i in range(counter):
if task_instances is None:
pass
Expand All @@ -132,7 +138,8 @@ def create_task_instances(
state=dag_run_state,
)
session.add(dr)
ti = TaskInstance(task=tasks[i], **self.ti_init)

ti = TaskInstance(task=tasks[i], **self.ti_init, dag_version_id=dag_version.id)
session.add(ti)
ti.dag_run = dr
ti.note = "placeholder-note"
Expand Down Expand Up @@ -366,7 +373,9 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session):
tis = self.create_task_instances(session)
old_ti = tis[0]
for idx in (1, 2):
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
ti = TaskInstance(
task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=tis[0].dag_version_id
)
ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]:
setattr(ti, attr, getattr(old_ti, attr))
Expand Down Expand Up @@ -2112,7 +2121,9 @@ def test_should_update_mapped_task_instance_state(self, session):
NEW_STATE = "failed"
map_index = 1
tis = self.create_task_instances(session)
ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, map_index=map_index)
ti = TaskInstance(
task=tis[0].task, run_id=tis[0].run_id, map_index=map_index, dag_version_id=tis[0].dag_version_id
)
ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
session.add(ti)
session.commit()
Expand Down Expand Up @@ -2334,7 +2345,9 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session):
tis = self.create_task_instances(session)
old_ti = tis[0]
for idx in (1, 2):
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
ti = TaskInstance(
task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=old_ti.dag_version_id
)
ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]:
setattr(ti, attr, getattr(old_ti, attr))
Expand Down Expand Up @@ -2519,7 +2532,13 @@ def test_should_respond_dependencies_mapped(self, session):
)
old_ti = tis[0]

ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=0, state=old_ti.state)
ti = TaskInstance(
task=old_ti.task,
run_id=old_ti.run_id,
map_index=0,
state=old_ti.state,
dag_version_id=old_ti.dag_version_id,
)
session.add(ti)
session.commit()

Expand Down Expand Up @@ -2645,7 +2664,9 @@ def test_should_respond_200_with_mapped_task_at_different_try_numbers(self, try_
tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
old_ti = tis[0]
for idx in (1, 2):
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
ti = TaskInstance(
task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=old_ti.dag_version_id
)
ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
ti.try_number = 1
for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]:
Expand Down Expand Up @@ -2861,9 +2882,12 @@ def test_ti_in_retry_state_not_returned(self, session):

def test_mapped_task_should_respond_200(self, session):
tis = self.create_task_instances(session, task_instances=[{"state": State.FAILED}])
dag_version_id = tis[0].dag_version_id
old_ti = tis[0]
for idx in (1, 2):
ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx)
ti = TaskInstance(
task=old_ti.task, run_id=old_ti.run_id, map_index=idx, dag_version_id=dag_version_id
)
ti.try_number = 1
session.add(ti)
session.commit()
Expand Down
43 changes: 37 additions & 6 deletions tests/api_connexion/endpoints/test_xcom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

import pytest

from airflow.models.dag import DagModel
from airflow.models.dag import DAG, DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -225,6 +227,10 @@ def test_should_raise_403_forbidden(self):
def _create_xcom_entry(
self, dag_id, run_id, logical_date, task_id, xcom_key, xcom_value="TEST_VALUE", *, backend=XCom
):
with DAG(dag_id=dag_id) as dag:
pass
dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="testing")
with create_session() as session:
dagrun = DagRun(
dag_id=dag_id,
Expand All @@ -234,7 +240,9 @@ def _create_xcom_entry(
run_type=DagRunType.MANUAL,
)
session.add(dagrun)
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, dag_version_id=dag_version.id)
ti.dag_id = dag_id
session.add(ti)
backend.set(
Expand Down Expand Up @@ -550,6 +558,7 @@ def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti=
with create_session() as session:
dag = DagModel(dag_id=dag_id)
session.add(dag)
SerializedDagModel.write_dag(DAG(dag_id=dag_id), bundle_name="testing")
dagrun = DagRun(
dag_id=dag_id,
run_id=run_id,
Expand All @@ -558,13 +567,22 @@ def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti=
run_type=DagRunType.MANUAL,
)
session.add(dagrun)
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
if mapped_ti:
for i in [0, 1]:
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i)
ti = TaskInstance(
EmptyOperator(task_id=task_id),
run_id=run_id,
map_index=i,
dag_version_id=dag_version.id,
)
ti.dag_id = dag_id
session.add(ti)
else:
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
ti = TaskInstance(
EmptyOperator(task_id=task_id), run_id=run_id, dag_version_id=dag_version.id
)
ti.dag_id = dag_id
session.add(ti)

Expand All @@ -587,6 +605,7 @@ def _create_invalid_xcom_entries(self, logical_date):
with create_session() as session:
dag = DagModel(dag_id="invalid_dag")
session.add(dag)
SerializedDagModel.write_dag(DAG(dag_id="invalid_dag"), bundle_name="testing")
dagrun = DagRun(
dag_id="invalid_dag",
run_id="invalid_run_id",
Expand All @@ -595,6 +614,8 @@ def _create_invalid_xcom_entries(self, logical_date):
run_type=DagRunType.MANUAL,
)
session.add(dagrun)
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
dagrun1 = DagRun(
dag_id="invalid_dag",
run_id="not_this_run_id",
Expand All @@ -603,7 +624,9 @@ def _create_invalid_xcom_entries(self, logical_date):
run_type=DagRunType.MANUAL,
)
session.add(dagrun1)
ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id")
ti = TaskInstance(
EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id", dag_version_id=dag_version.id
)
ti.dag_id = "invalid_dag"
session.add(ti)
for i in [1, 2]:
Expand Down Expand Up @@ -683,6 +706,10 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids):
f"/api/v1/dags/{self.dag_id}/dagRuns/{self.run_id}/taskInstances/{self.task_id}/xcomEntries"
f"?{query_params}"
)
with DAG(dag_id=self.dag_id) as dag:
...
dag.sync_to_db()
SerializedDagModel.write_dag(dag, bundle_name="testing")
with create_session() as session:
dagrun = DagRun(
dag_id=self.dag_id,
Expand All @@ -692,7 +719,11 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids):
run_type=DagRunType.MANUAL,
)
session.add(dagrun)
ti = TaskInstance(EmptyOperator(task_id=self.task_id), run_id=self.run_id)
dag_version = DagVersion.get_latest_version(dag.dag_id)
assert dag_version
ti = TaskInstance(
EmptyOperator(task_id=self.task_id), run_id=self.run_id, dag_version_id=dag_version.id
)
ti.dag_id = self.dag_id
session.add(ti)

Expand Down
7 changes: 6 additions & 1 deletion tests/cli/commands/remote_commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from airflow.exceptions import AirflowException, DagRunNotFound
from airflow.executors.local_executor import LocalExecutor
from airflow.models import DagBag, DagRun, Pool, TaskInstance
from airflow.models.dag_version import DagVersion
from airflow.models.serialized_dag import SerializedDagModel
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
Expand Down Expand Up @@ -619,6 +620,8 @@ def test_task_states_for_dag_run(self):
default_date2 = timezone.datetime(2016, 1, 9)
dag2.clear()
data_interval = dag2.timetable.infer_manual_data_interval(run_after=default_date2)
dag2.sync_to_db()
SerializedDagModel.write_dag(dag2, bundle_name="testing")
dagrun = dag2.create_dagrun(
run_id="test",
state=State.RUNNING,
Expand All @@ -630,7 +633,9 @@ def test_task_states_for_dag_run(self):
dag_version=None,
triggered_by=DagRunTriggeredByType.CLI,
)
ti2 = TaskInstance(task2, run_id=dagrun.run_id)
dag_version = DagVersion.get_latest_version(dag2.dag_id)
assert dag_version
ti2 = TaskInstance(task2, run_id=dagrun.run_id, dag_version_id=dag_version.id)
ti2.set_state(State.SUCCESS)
ti_start = ti2.start_date
ti_end = ti2.end_date
Expand Down
Loading

0 comments on commit 48376ac

Please sign in to comment.