Make dag_version_id in TI non-nullable
This is going to be a huge PR on the test side, but it's worth it.
Ensuring that TIs are associated with DagVersions helps maintain
referential integrity and prevents inconsistent data.
ephraimbuddy committed Feb 12, 2025
1 parent 50c4047 commit 9ed6b09
Showing 39 changed files with 554 additions and 215 deletions.
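
The recurring change across these files is that a TaskInstance can no longer be created without a dag_version_id. Below is a minimal sketch of the pattern this commit applies (resolve the latest DagVersion, then pass its id to the constructor), assuming a development environment with the DAG already serialized to the metadata database and an open SQLAlchemy session; make_ti and its arguments are illustrative placeholders, not part of the commit:

from airflow.models.dag_version import DagVersion
from airflow.models.taskinstance import TaskInstance


def make_ti(task, run_id, session):
    # dag_version_id is now a required, non-nullable field on TaskInstance, so
    # resolve the latest DagVersion for the task's DAG first (it must already
    # exist in the metadata database).
    dag_version = DagVersion.get_latest_version(task.dag_id, session=session)
    if dag_version is None:
        raise ValueError(f"No DagVersion found for dag_id={task.dag_id!r}")
    return TaskInstance(task, run_id=run_id, dag_version_id=dag_version.id)
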
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -17,7 +17,7 @@
---
default_stages: [pre-commit, pre-push]
default_language_version:
python: python3
python: python3.12
node: 22.2.0
minimum_pre_commit_version: '3.2.0'
exclude: ^.*/.*_vendor/
2 changes: 0 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -64,7 +64,6 @@
from airflow.exceptions import ParamValidationError
from airflow.listeners.listener import get_listener_manager
from airflow.models import DAG, DagModel, DagRun
from airflow.models.dag_version import DagVersion
from airflow.timetables.base import DataInterval
from airflow.utils import timezone
from airflow.utils.state import DagRunState
@@ -393,7 +392,6 @@ def trigger_dag_run(
run_type=DagRunType.MANUAL,
triggered_by=DagRunTriggeredByType.REST_API,
external_trigger=True,
dag_version=DagVersion.get_latest_version(dag.dag_id),
state=DagRunState.QUEUED,
session=session,
)
8 changes: 3 additions & 5 deletions airflow/api_fastapi/core_api/routes/ui/grid.py
@@ -23,7 +23,6 @@

from fastapi import Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.orm import joinedload

from airflow import DAG
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
@@ -51,6 +50,7 @@
get_task_group_map,
)
from airflow.models import DagRun, TaskInstance
from airflow.models.dag_version import DagVersion

grid_router = AirflowRouter(prefix="/grid", tags=["Grid"])

@@ -94,9 +94,7 @@ def grid_data(
base_query = (
select(DagRun)
.join(DagRun.dag_run_note, isouter=True)
.join(DagRun.dag_version, isouter=True)
.select_from(DagRun)
.options(joinedload(DagRun.dag_version))
.where(DagRun.dag_id == dag.dag_id)
)

@@ -203,7 +201,7 @@ def grid_data(
task_node_map=task_node_map,
session=session,
)

dag_version = DagVersion.get_latest_version(dag.dag_id)
# Aggregate the Task Instances by DAG Run
grid_dag_runs = [
GridDAGRunwithTIs(
@@ -215,7 +213,7 @@
run_type=dag_run.run_type,
data_interval_start=dag_run.data_interval_start,
data_interval_end=dag_run.data_interval_end,
version_number=dag_run.dag_version.version_number if dag_run.dag_version else None,
version_number=dag_version.version_number if dag_version else None,
note=dag_run.note,
task_instances=(
task_instance_summaries[dag_run.run_id] if dag_run.run_id in task_instance_summaries else []
6 changes: 5 additions & 1 deletion airflow/cli/commands/remote_commands/task_command.py
@@ -44,6 +44,7 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models import TaskInstance
from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskReturnCode
from airflow.sdk.definitions.param import ParamsDict
@@ -209,7 +210,10 @@ def _get_ti(
f"run_id or logical_date of {logical_date_or_run_id!r} not found"
)
# TODO: Validate map_index is in range?
ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index)
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index, dag_version_id=dag_version.id)
if dag_run in session:
session.add(ti)
ti.dag_run = dag_run
@@ -141,7 +141,7 @@ def upgrade():
batch_op.add_column(sa.Column("created_at", UtcDateTime(), nullable=False, default=timezone.utcnow))

with op.batch_alter_table("task_instance", schema=None) as batch_op:
batch_op.add_column(sa.Column("dag_version_id", UUIDType(binary=False)))
batch_op.add_column(sa.Column("dag_version_id", UUIDType(binary=False), nullable=False))
batch_op.create_foreign_key(
batch_op.f("task_instance_dag_version_id_fkey"),
"dag_version",
9 changes: 7 additions & 2 deletions airflow/models/abstractoperator.py
@@ -41,6 +41,7 @@

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy_utils import UUIDType

from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DAG as SchedulerDAG
@@ -215,7 +216,9 @@ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc]
return link.get_link(self.unmap(None), ti_key=ti.key)

def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
def expand_mapped_task(
self, run_id: str, *, dag_version_id: UUIDType, session: Session
) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.
@@ -324,7 +327,9 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
ti = TaskInstance(
self, run_id=run_id, map_index=index, state=state, dag_version_id=dag_version_id
)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
11 changes: 9 additions & 2 deletions airflow/models/baseoperator.py
@@ -60,6 +60,7 @@
NotMapped,
)
from airflow.models.base import _sentinel
from airflow.models.dag_version import DagVersion
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
from airflow.models.trigger import TRIGGER_FAIL_REPR, TriggerFailureReason
@@ -623,7 +624,10 @@ def run(
DagRun.logical_date == info.logical_date,
)
).one()
ti = TaskInstance(self, run_id=dag_run.run_id)
dag_version = DagVersion.get_latest_version(self.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(self, run_id=dag_run.run_id, dag_version_id=dag_version.id)
except NoResultFound:
# This is _mostly_ only used in tests
dr = DagRun(
@@ -640,7 +644,10 @@
triggered_by=DagRunTriggeredByType.TEST,
state=DagRunState.RUNNING,
)
ti = TaskInstance(self, run_id=dr.run_id)
dag_version = DagVersion.get_latest_version(self.dag_id, session=session)
if TYPE_CHECKING:
assert dag_version
ti = TaskInstance(self, run_id=dr.run_id, dag_version_id=dag_version.id)
ti.dag_run = dr
session.add(dr)
session.flush()
16 changes: 12 additions & 4 deletions airflow/models/dagrun.py
@@ -1171,7 +1171,9 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
# the db references.
ti.clear_db_references(session=session)
try:
expanded_tis, _ = TaskMap.expand_mapped_task(ti.task, self.run_id, session=session)
expanded_tis, _ = TaskMap.expand_mapped_task(
ti.dag_version_id, ti.task, self.run_id, session=session
)
except NotMapped: # Not a mapped task, nothing needed.
return None
if expanded_tis:
@@ -1205,7 +1207,11 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
# It's enough to revise map index once per task id,
# checking the map index for each mapped task significantly slows down scheduling
if schedulable.task.task_id not in revised_map_index_task_ids:
ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
ready_tis.extend(
self._revise_map_indexes_if_mapped(
schedulable.dag_version_id, schedulable.task, session=session
)
)
revised_map_index_task_ids.add(schedulable.task.task_id)
ready_tis.append(schedulable)

@@ -1561,7 +1567,9 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
def _revise_map_indexes_if_mapped(
self, dag_version_id: UUIDType, task: Operator, *, session: Session
) -> Iterator[TI]:
"""
Check if task increased or reduced in length and handle appropriately.
@@ -1607,7 +1615,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
for index in range(total_length):
if index in existing_indexes:
continue
ti = TI(task, run_id=self.run_id, map_index=index, state=None)
ti = TI(task, run_id=self.run_id, map_index=index, state=None, dag_version_id=dag_version_id)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
7 changes: 4 additions & 3 deletions airflow/models/taskinstance.py
@@ -804,7 +804,6 @@ def _set_ti_attrs(target, source, include_dag_run=False):
target.dag_run.data_interval_start = source.dag_run.data_interval_start
target.dag_run.data_interval_end = source.dag_run.data_interval_end
target.dag_run.last_scheduling_decision = source.dag_run.last_scheduling_decision
target.dag_run.dag_version_id = source.dag_run.dag_version_id
target.dag_run.updated_at = source.dag_run.updated_at
target.dag_run.log_template_id = source.dag_run.log_template_id

@@ -1692,7 +1691,9 @@ class TaskInstance(Base, LoggingMixin):
next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON))

_task_display_property_value = Column("task_display_name", String(2000), nullable=True)
dag_version_id = Column(UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"))
dag_version_id = Column(
UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"), nullable=False
)
dag_version = relationship("DagVersion", back_populates="task_instances")
# If adding new fields here then remember to add them to
# _set_ti_attrs() or they won't display in the UI correctly
@@ -1805,7 +1806,7 @@ def insert_mapping(run_id: str, task: Operator, map_index: int, dag_version_id:
:meta private:
"""
priority_weight = task.weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index)
TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
)

return {
10 changes: 8 additions & 2 deletions airflow/models/taskmap.py
@@ -33,6 +33,7 @@

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy_utils import UUIDType

from airflow.models.dag import DAG as SchedulerDAG
from airflow.models.taskinstance import TaskInstance
@@ -121,7 +122,9 @@ def variant(self) -> TaskMapVariant:
return TaskMapVariant.DICT

@classmethod
def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
def expand_mapped_task(
cls, dag_version_id: UUIDType, task, run_id: str, *, session: Session
) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.
@@ -224,7 +227,10 @@ def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Seq

for index in indexes_to_map:
# TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
ti = TaskInstance(task, run_id=run_id, map_index=index, state=state)

ti = TaskInstance(
task, run_id=run_id, map_index=index, state=state, dag_version_id=dag_version_id
)
task.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
0fe53d323bac2717a1f919100fa3cd8a97b011f01cc37a7752e136e7deba8be9
e8afeff47850a5850dad4f39da6344cd5d9befba56c4578d58cbad241b0f9cff
3 changes: 2 additions & 1 deletion docs/apache-airflow/img/airflow_erd.svg
@@ -27,6 +27,7 @@
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.models import TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DagBag
from airflow.models.taskmap import TaskMap
from airflow.utils.platform import getuser
@@ -92,6 +93,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None):
mapped = MockOperator.partial(task_id="task_2", executor="default").expand(arg2=task1.output)

dr = dag_maker.create_dagrun(run_id=f"run_{dag_id}")
dag_version_id = DagVersion.get_latest_version(dr.dag_id, session=session).id

session.add(
TaskMap(
@@ -119,7 +121,9 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None):
itertools.repeat(TaskInstanceState.RUNNING, dag["running"]),
)
):
ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=state)
ti = TaskInstance(
mapped, run_id=dr.run_id, map_index=index, state=state, dag_version_id=dag_version_id
)
setattr(ti, "start_date", DEFAULT_DATETIME_1)
session.add(ti)

@@ -129,7 +133,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None):
self.app.dag_bag.sync_to_db("dags-folder", None)
session.flush()

TaskMap.expand_mapped_task(mapped, dr.run_id, session=session)
TaskMap.expand_mapped_task(dag_version_id, mapped, dr.run_id, session=session)

@pytest.fixture
def one_task_with_mapped_tis(self, dag_maker, session):