From 91d8208c9ddeb060c5de40a61346fb309229e37b Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 16 Sep 2024 14:29:13 +0800 Subject: [PATCH 1/9] Refactor bulk_save_to_db This function collects information from DAG objects, and creates/updates database rows against them. However, it handles A LOT of information, reading a lot of objects, touching a lot of models. The function is not very readable. A new module has been introduced in airflow.dag_processing.collection to encapsulate the logic previously in bulk_save_to_db. Some loops are broken down into multiple loops, so each loop does not do too much (which leads to a lot of long-living variables that reduce readability). Not much is changed aside from that, just mostly splitting things into separate steps to make things clearer. --- airflow/dag_processing/collection.py | 402 +++++++++++++++++++++++++++ airflow/datasets/__init__.py | 16 +- airflow/models/dag.py | 343 ++--------------------- airflow/timetables/base.py | 5 +- tests/datasets/test_dataset.py | 2 +- 5 files changed, 444 insertions(+), 324 deletions(-) create mode 100644 airflow/dag_processing/collection.py diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py new file mode 100644 index 0000000000000..e7a658539e5ca --- /dev/null +++ b/airflow/dag_processing/collection.py @@ -0,0 +1,402 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Utility code that write DAGs in bulk into the database. + +This should generally only be called by internal methods such as +``DagBag._sync_to_db``, ``DAG.bulk_write_to_db``. + +:meta private: +""" + +from __future__ import annotations + +import itertools +import logging +from typing import TYPE_CHECKING, NamedTuple + +from sqlalchemy import func, or_, select +from sqlalchemy.orm import joinedload, load_only +from sqlalchemy.sql import expression + +from airflow.datasets import Dataset, DatasetAlias +from airflow.datasets.manager import dataset_manager +from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag +from airflow.models.dagrun import DagRun +from airflow.models.dataset import ( + DagScheduleDatasetAliasReference, + DagScheduleDatasetReference, + DatasetAliasModel, + DatasetModel, + TaskOutletDatasetReference, +) +from airflow.utils.sqlalchemy import with_row_locks +from airflow.utils.timezone import utcnow +from airflow.utils.types import DagRunType + +if TYPE_CHECKING: + from collections.abc import Collection, Iterable, Iterator + + from sqlalchemy.orm import Session + from sqlalchemy.sql import Select + + from airflow.typing_compat import Self + +log = logging.getLogger(__name__) + + +def collect_orm_dags(dags: dict[str, DAG], *, session: Session) -> dict[str, DagModel]: + """ + Collect DagModel objects from DAG objects. 
+ + An existing DagModel is fetched if there's a matching ID in the database. + Otherwise, a new DagModel is created and added to the session. + """ + stmt = ( + select(DagModel) + .options(joinedload(DagModel.tags, innerjoin=False)) + .where(DagModel.dag_id.in_(dags)) + .options(joinedload(DagModel.schedule_dataset_references)) + .options(joinedload(DagModel.schedule_dataset_alias_references)) + .options(joinedload(DagModel.task_outlet_dataset_references)) + ) + stmt = with_row_locks(stmt, of=DagModel, session=session) + existing_orm_dags = {dm.dag_id: dm for dm in session.scalars(stmt).unique()} + + for dag_id, dag in dags.items(): + if dag_id in existing_orm_dags: + continue + orm_dag = DagModel(dag_id=dag_id) + if dag.is_paused_upon_creation is not None: + orm_dag.is_paused = dag.is_paused_upon_creation + orm_dag.tags = [] + log.info("Creating ORM DAG for %s", dag_id) + session.add(orm_dag) + existing_orm_dags[dag_id] = orm_dag + + return existing_orm_dags + + +def create_orm_dag(dag: DAG, session: Session) -> DagModel: + orm_dag = DagModel(dag_id=dag.dag_id) + if dag.is_paused_upon_creation is not None: + orm_dag.is_paused = dag.is_paused_upon_creation + orm_dag.tags = [] + log.info("Creating ORM DAG for %s", dag.dag_id) + session.add(orm_dag) + return orm_dag + + +def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select: + """Build a select statement for retrieve the last automated run for each dag.""" + if len(dag_ids) == 1: # Index optimized fast path to avoid more complicated & slower groupby queryplan. + (dag_id,) = dag_ids + last_automated_runs_subq = ( + select(func.max(DagRun.execution_date).label("max_execution_date")) + .where( + DagRun.dag_id == dag_id, + or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED), + ) + .scalar_subquery() + ) + query = select(DagRun).where( + DagRun.dag_id == dag_id, + DagRun.execution_date == last_automated_runs_subq, + ) + else: + last_automated_runs_subq = ( + select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date")) + .where( + DagRun.dag_id.in_(dag_ids), + DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)), + ) + .group_by(DagRun.dag_id) + .subquery() + ) + query = select(DagRun).where( + DagRun.dag_id == last_automated_runs_subq.c.dag_id, + DagRun.execution_date == last_automated_runs_subq.c.max_execution_date, + ) + return query.options( + load_only( + DagRun.dag_id, + DagRun.execution_date, + DagRun.data_interval_start, + DagRun.data_interval_end, + ) + ) + + +class _RunInfo(NamedTuple): + latest_runs: dict[str, DagRun] + num_active_runs: dict[str, int] + + @classmethod + def calculate(cls, dags: dict[str, DAG], *, session: Session) -> Self: + # Skip these queries entirely if no DAGs can be scheduled to save time. + if not any(dag.timetable.can_be_scheduled for dag in dags.values()): + return cls({}, {}) + return cls( + {run.dag_id: run for run in session.scalars(_get_latest_runs_stmt(dag_ids=dags))}, + DagRun.active_runs_of_dags(dag_ids=dags, session=session), + ) + + +def update_orm_dags( + source_dags: dict[str, DAG], + target_dags: dict[str, DagModel], + *, + processor_subdir: str | None = None, + session: Session, +) -> None: + """ + Apply DAG attributes to DagModel objects. + + Objects in ``target_dags`` are modified in-place. 
+ """ + run_info = _RunInfo.calculate(source_dags, session=session) + + for dag_id, dm in sorted(target_dags.items()): + dag = source_dags[dag_id] + dm.fileloc = dag.fileloc + dm.owners = dag.owner + dm.is_active = True + dm.has_import_errors = False + dm.last_parsed_time = utcnow() + dm.default_view = dag.default_view + dm._dag_display_property_value = dag._dag_display_property_value + dm.description = dag.description + dm.max_active_tasks = dag.max_active_tasks + dm.max_active_runs = dag.max_active_runs + dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs + dm.has_task_concurrency_limits = any( + t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None for t in dag.tasks + ) + dm.timetable_summary = dag.timetable.summary + dm.timetable_description = dag.timetable.description + dm.dataset_expression = dag.timetable.dataset_condition.as_expression() + dm.processor_subdir = processor_subdir + + last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id) + if last_automated_run is None: + last_automated_data_interval = None + else: + last_automated_data_interval = dag.get_run_data_interval(last_automated_run) + if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs: + dm.next_dagrun_create_after = None + else: + dm.calculate_dagrun_date_fields(dag, last_automated_data_interval) + + if not dag.timetable.dataset_condition: + dm.schedule_dataset_references = [] + dm.schedule_dataset_alias_references = [] + # FIXME: STORE NEW REFERENCES. + + dag_tags = set(dag.tags or ()) + for orm_tag in (dm_tags := list(dm.tags or [])): + if orm_tag.name not in dag_tags: + session.delete(orm_tag) + dm.tags.remove(orm_tag) + orm_tag_names = {t.name for t in dm_tags} + for dag_tag in dag_tags: + if dag_tag not in orm_tag_names: + dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id) + dm.tags.append(dag_tag_orm) + session.add(dag_tag_orm) + + dm_links = dm.dag_owner_links or [] + for dm_link in dm_links: + if dm_link not in dag.owner_links: + session.delete(dm_link) + for owner_name, owner_link in dag.owner_links.items(): + dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link) + session.add(dag_owner_orm) + + +def _find_all_datasets(dags: Iterable[DAG]) -> Iterator[Dataset]: + for dag in dags: + for _, dataset in dag.timetable.dataset_condition.iter_datasets(): + yield dataset + for task in dag.task_dict.values(): + for obj in itertools.chain(task.inlets, task.outlets): + if isinstance(obj, Dataset): + yield obj + + +def _find_all_dataset_aliases(dags: Iterable[DAG]) -> Iterator[DatasetAlias]: + for dag in dags: + for _, alias in dag.timetable.dataset_condition.iter_dataset_aliases(): + yield alias + for task in dag.task_dict.values(): + for obj in itertools.chain(task.inlets, task.outlets): + if isinstance(obj, DatasetAlias): + yield obj + + +class DatasetCollection(NamedTuple): + """Datasets collected from DAGs.""" + + schedule_dataset_references: dict[str, list[Dataset]] + schedule_dataset_alias_references: dict[str, list[DatasetAlias]] + outlet_references: dict[str, list[tuple[str, Dataset]]] + datasets: dict[str, Dataset] + dataset_aliases: dict[str, DatasetAlias] + + @classmethod + def collect(cls, dags: dict[str, DAG]) -> Self: + coll = cls( + schedule_dataset_references={ + dag_id: [dataset for _, dataset in dag.timetable.dataset_condition.iter_datasets()] + for dag_id, dag in dags.items() + }, + schedule_dataset_alias_references={ + dag_id: [alias for _, alias in 
dag.timetable.dataset_condition.iter_dataset_aliases()] + for dag_id, dag in dags.items() + }, + outlet_references={ + dag_id: [ + (task_id, outlet) + for task_id, task in dag.task_dict.items() + for outlet in task.outlets + if isinstance(outlet, Dataset) + ] + for dag_id, dag in dags.items() + }, + datasets={dataset.uri: dataset for dataset in _find_all_datasets(dags.values())}, + dataset_aliases={alias.name: alias for alias in _find_all_dataset_aliases(dags.values())}, + ) + return coll + + def write_datasets(self, *, session: Session) -> dict[str, DatasetModel]: + orm_datasets: dict[str, DatasetModel] = { + dm.uri: dm + for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets))) + } + + def _resolve_dataset_addition() -> Iterator[DatasetModel]: + for uri, dataset in self.datasets.items(): + try: + dm = orm_datasets[uri] + except KeyError: + dm = orm_datasets[uri] = DatasetModel.from_public(dataset) + yield dm + else: + # The orphaned flag was bulk-set to True before parsing, so we + # don't need to handle rows in the db without a public entry. + dm.is_orphaned = expression.false() + dm.extra = dataset.extra + + dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session) + return orm_datasets + + def write_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]: + orm_aliases: dict[str, DatasetAliasModel] = { + da.name: da + for da in session.scalars( + select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases)) + ) + } + for name, alias in self.dataset_aliases.items(): + try: + da = orm_aliases[name] + except KeyError: + da = orm_aliases[name] = DatasetAliasModel.from_public(alias) + session.add(da) + return orm_aliases + + def write_dag_dataset_references( + self, + dags: dict[str, DagModel], + datasets: dict[str, DatasetModel], + *, + session: Session, + ) -> None: + for dag_id, references in self.schedule_dataset_references.items(): + # Optimization: no references at all, just clear everything. + if not references: + dags[dag_id].schedule_dataset_references = [] + continue + referenced_dataset_ids = {dataset.id for dataset in (datasets[r.uri] for r in references)} + orm_refs = {r.dataset_id: r for r in dags[dag_id].schedule_dataset_references} + for dataset_id, ref in orm_refs.items(): + if dataset_id not in referenced_dataset_ids: + session.delete(ref) + session.bulk_save_objects( + ( + DagScheduleDatasetReference(dataset_id=dataset_id, dag_id=dag_id) + for dataset_id in referenced_dataset_ids + if dataset_id not in orm_refs + ), + preserve_order=False, + ) + + def write_dag_dataset_alias_references( + self, + dags: dict[str, DagModel], + aliases: dict[str, DatasetAliasModel], + *, + session: Session, + ) -> None: + for dag_id, references in self.schedule_dataset_alias_references.items(): + # Optimization: no references at all, just clear everything. 
+ if not references: + dags[dag_id].schedule_dataset_alias_references = [] + continue + referenced_alias_ids = {alias.id for alias in (aliases[r.name] for r in references)} + orm_refs = {a.alias_id: a for a in dags[dag_id].schedule_dataset_alias_references} + for alias_id, ref in orm_refs.items(): + if alias_id not in referenced_alias_ids: + session.delete(ref) + session.bulk_save_objects( + ( + DagScheduleDatasetAliasReference(alias_id=alias_id, dag_id=dag_id) + for alias_id in referenced_alias_ids + if alias_id not in orm_refs + ), + preserve_order=False, + ) + + def write_task_dataset_references( + self, + dags: dict[str, DagModel], + datasets: dict[str, DatasetModel], + *, + session: Session, + ) -> None: + for dag_id, references in self.outlet_references.items(): + # Optimization: no references at all, just clear everything. + if not references: + dags[dag_id].task_outlet_dataset_references = [] + continue + referenced_outlets = { + (task_id, dataset.id) + for task_id, dataset in ((task_id, datasets[d.uri]) for task_id, d in references) + } + orm_refs = {(r.task_id, r.dataset_id): r for r in dags[dag_id].task_outlet_dataset_references} + for key, ref in orm_refs.items(): + if key not in referenced_outlets: + session.delete(ref) + session.bulk_save_objects( + ( + TaskOutletDatasetReference(dataset_id=dataset_id, dag_id=dag_id, task_id=task_id) + for task_id, dataset_id in referenced_outlets + if (task_id, dataset_id) not in orm_refs + ), + preserve_order=False, + ) diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index cd57078095b7b..6f7ae99ff7417 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -194,7 +194,7 @@ def evaluate(self, statuses: dict[str, bool]) -> bool: def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: raise NotImplementedError - def iter_dataset_aliases(self) -> Iterator[DatasetAlias]: + def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: raise NotImplementedError def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: @@ -212,6 +212,12 @@ class DatasetAlias(BaseDataset): name: str + def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: + return iter(()) + + def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: + yield self.name, self + def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ Iterate a dataset alias as dag dependency. 
@@ -294,7 +300,7 @@ def as_expression(self) -> Any: def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: yield self.uri, self - def iter_dataset_aliases(self) -> Iterator[DatasetAlias]: + def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: return iter(()) def evaluate(self, statuses: dict[str, bool]) -> bool: @@ -339,7 +345,7 @@ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: yield k, v seen.add(k) - def iter_dataset_aliases(self) -> Iterator[DatasetAlias]: + def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: """Filter dataest aliases in the condition.""" for o in self.objects: yield from o.iter_dataset_aliases() @@ -399,8 +405,8 @@ def as_expression(self) -> Any: """ return {"alias": self.name} - def iter_dataset_aliases(self) -> Iterator[DatasetAlias]: - yield DatasetAlias(self.name) + def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: + yield self.name, DatasetAlias(self.name) def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: """ diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 6545293ccf89c..cdb1fdb6f8fd0 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -74,7 +74,7 @@ ) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import backref, joinedload, load_only, relationship +from sqlalchemy.orm import backref, relationship from sqlalchemy.sql import Select, expression import airflow.templates @@ -82,7 +82,6 @@ from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.datasets import BaseDataset, Dataset, DatasetAlias, DatasetAll -from airflow.datasets.manager import dataset_manager from airflow.exceptions import ( AirflowException, DuplicateTaskIdFound, @@ -100,11 +99,7 @@ from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import RUN_ID_REGEX, DagRun -from airflow.models.dataset import ( - DatasetAliasModel, - DatasetDagRunQueue, - DatasetModel, -) +from airflow.models.dataset import DatasetDagRunQueue from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import ( Context, @@ -2637,7 +2632,7 @@ def bulk_write_to_db( cls, dags: Collection[DAG], processor_subdir: str | None = None, - session=NEW_SESSION, + session: Session = NEW_SESSION, ): """ Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB. 
@@ -2648,323 +2643,37 @@ def bulk_write_to_db( if not dags: return - log.info("Sync %s DAGs", len(dags)) - dag_by_ids = {dag.dag_id: dag for dag in dags} - - dag_ids = set(dag_by_ids) - query = ( - select(DagModel) - .options(joinedload(DagModel.tags, innerjoin=False)) - .where(DagModel.dag_id.in_(dag_ids)) - .options(joinedload(DagModel.schedule_dataset_references)) - .options(joinedload(DagModel.schedule_dataset_alias_references)) - .options(joinedload(DagModel.task_outlet_dataset_references)) - ) - query = with_row_locks(query, of=DagModel, session=session) - orm_dags: list[DagModel] = session.scalars(query).unique().all() - existing_dags: dict[str, DagModel] = {x.dag_id: x for x in orm_dags} - missing_dag_ids = dag_ids.difference(existing_dags.keys()) - - for missing_dag_id in missing_dag_ids: - orm_dag = DagModel(dag_id=missing_dag_id) - dag = dag_by_ids[missing_dag_id] - if dag.is_paused_upon_creation is not None: - orm_dag.is_paused = dag.is_paused_upon_creation - orm_dag.tags = [] - log.info("Creating ORM DAG for %s", dag.dag_id) - session.add(orm_dag) - orm_dags.append(orm_dag) - - latest_runs: dict[str, DagRun] = {} - num_active_runs: dict[str, int] = {} - # Skip these queries entirely if no DAGs can be scheduled to save time. - if any(dag.timetable.can_be_scheduled for dag in dags): - # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query) - query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys())) - latest_runs = {run.dag_id: run for run in session.scalars(query)} - - # Get number of active dagruns for all dags we are processing as a single query. - num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session) - - filelocs = [] - - for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id): - dag = dag_by_ids[orm_dag.dag_id] - filelocs.append(dag.fileloc) - orm_dag.fileloc = dag.fileloc - orm_dag.owners = dag.owner - orm_dag.is_active = True - orm_dag.has_import_errors = False - orm_dag.last_parsed_time = timezone.utcnow() - orm_dag.default_view = dag.default_view - orm_dag._dag_display_property_value = dag._dag_display_property_value - orm_dag.description = dag.description - orm_dag.max_active_tasks = dag.max_active_tasks - orm_dag.max_active_runs = dag.max_active_runs - orm_dag.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs - orm_dag.has_task_concurrency_limits = any( - t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None - for t in dag.tasks - ) - orm_dag.timetable_summary = dag.timetable.summary - orm_dag.timetable_description = dag.timetable.description - orm_dag.dataset_expression = dag.timetable.dataset_condition.as_expression() - - orm_dag.processor_subdir = processor_subdir - - last_automated_run: DagRun | None = latest_runs.get(dag.dag_id) - if last_automated_run is None: - last_automated_data_interval = None - else: - last_automated_data_interval = dag.get_run_data_interval(last_automated_run) - if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs: - orm_dag.next_dagrun_create_after = None - else: - orm_dag.calculate_dagrun_date_fields(dag, last_automated_data_interval) - - dag_tags = set(dag.tags or {}) - orm_dag_tags = list(orm_dag.tags or []) - for orm_tag in orm_dag_tags: - if orm_tag.name not in dag_tags: - session.delete(orm_tag) - orm_dag.tags.remove(orm_tag) - orm_tag_names = {t.name for t in orm_dag_tags} - for dag_tag in dag_tags: - if dag_tag not in orm_tag_names: - dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id) - 
orm_dag.tags.append(dag_tag_orm) - session.add(dag_tag_orm) - - orm_dag_links = orm_dag.dag_owner_links or [] - for orm_dag_link in orm_dag_links: - if orm_dag_link not in dag.owner_links: - session.delete(orm_dag_link) - for owner_name, owner_link in dag.owner_links.items(): - dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link) - session.add(dag_owner_orm) - - DagCode.bulk_sync_to_db(filelocs, session=session) - - from airflow.datasets import Dataset - from airflow.models.dataset import ( - DagScheduleDatasetAliasReference, - DagScheduleDatasetReference, - DatasetModel, - TaskOutletDatasetReference, - ) - - dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set) - outlet_references = defaultdict(set) - # We can't use a set here as we want to preserve order - outlet_dataset_models: dict[DatasetModel, None] = {} - input_dataset_models: dict[DatasetModel, None] = {} - outlet_dataset_alias_models: set[DatasetAliasModel] = set() - input_dataset_alias_models: set[DatasetAliasModel] = set() - - # here we go through dags and tasks to check for dataset references - # if there are now None and previously there were some, we delete them - # if there are now *any*, we add them to the above data structures, and - # later we'll persist them to the database. - for dag in dags: - curr_orm_dag = existing_dags.get(dag.dag_id) - if not (dataset_condition := dag.timetable.dataset_condition): - if curr_orm_dag: - if curr_orm_dag.schedule_dataset_references: - curr_orm_dag.schedule_dataset_references = [] - if curr_orm_dag.schedule_dataset_alias_references: - curr_orm_dag.schedule_dataset_alias_references = [] - else: - for _, dataset in dataset_condition.iter_datasets(): - dag_references[dag.dag_id].add(("dataset", dataset.uri)) - input_dataset_models[DatasetModel.from_public(dataset)] = None - - for dataset_alias in dataset_condition.iter_dataset_aliases(): - dag_references[dag.dag_id].add(("dataset-alias", dataset_alias.name)) - input_dataset_alias_models.add(DatasetAliasModel.from_public(dataset_alias)) - - curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references - for task in dag.tasks: - dataset_outlets: list[Dataset] = [] - dataset_alias_outlets: list[DatasetAlias] = [] - for outlet in task.outlets: - if isinstance(outlet, Dataset): - dataset_outlets.append(outlet) - elif isinstance(outlet, DatasetAlias): - dataset_alias_outlets.append(outlet) - - if not dataset_outlets: - if curr_outlet_references: - this_task_outlet_refs = [ - x - for x in curr_outlet_references - if x.dag_id == dag.dag_id and x.task_id == task.task_id - ] - for ref in this_task_outlet_refs: - curr_outlet_references.remove(ref) - - for d in dataset_outlets: - outlet_dataset_models[DatasetModel.from_public(d)] = None - outlet_references[(task.dag_id, task.task_id)].add(d.uri) - - for d_a in dataset_alias_outlets: - outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a)) - - all_dataset_models = outlet_dataset_models - all_dataset_models.update(input_dataset_models) - - # store datasets - stored_dataset_models: dict[str, DatasetModel] = {} - new_dataset_models: list[DatasetModel] = [] - for dataset in all_dataset_models: - stored_dataset_model = session.scalar( - select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1) - ) - if stored_dataset_model: - # Some datasets may have been previously unreferenced, and therefore orphaned by the - # scheduler. 
But if we're here, then we have found that dataset again in our DAGs, which - # means that it is no longer an orphan, so set is_orphaned to False. - stored_dataset_model.is_orphaned = expression.false() - stored_dataset_models[stored_dataset_model.uri] = stored_dataset_model - else: - new_dataset_models.append(dataset) - dataset_manager.create_datasets(dataset_models=new_dataset_models, session=session) - stored_dataset_models.update( - {dataset_model.uri: dataset_model for dataset_model in new_dataset_models} + from airflow.dag_processing.collection import ( + DatasetCollection, + collect_orm_dags, + create_orm_dag, + update_orm_dags, ) - del new_dataset_models - del all_dataset_models - - # store dataset aliases - all_datasets_alias_models = input_dataset_alias_models | outlet_dataset_alias_models - stored_dataset_alias_models: dict[str, DatasetAliasModel] = {} - new_dataset_alias_models: set[DatasetAliasModel] = set() - if all_datasets_alias_models: - all_dataset_alias_names = { - dataset_alias_model.name for dataset_alias_model in all_datasets_alias_models - } - - stored_dataset_alias_models = { - dsa_m.name: dsa_m - for dsa_m in session.scalars( - select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names)) - ).fetchall() - } - - if stored_dataset_alias_models: - new_dataset_alias_models = { - dataset_alias_model - for dataset_alias_model in all_datasets_alias_models - if dataset_alias_model.name not in stored_dataset_alias_models.keys() - } - else: - new_dataset_alias_models = all_datasets_alias_models - - session.add_all(new_dataset_alias_models) - session.flush() - stored_dataset_alias_models.update( - { - dataset_alias_model.name: dataset_alias_model - for dataset_alias_model in new_dataset_alias_models - } + log.info("Sync %s DAGs", len(dags)) + dags_by_ids = {dag.dag_id: dag for dag in dags} + del dags + + orm_dags = collect_orm_dags(dags_by_ids, session=session) + orm_dags.update( + (dag_id, create_orm_dag(dag, session=session)) + for dag_id, dag in dags_by_ids.items() + if dag_id not in orm_dags ) - del new_dataset_alias_models - del all_datasets_alias_models + update_orm_dags(dags_by_ids, orm_dags, processor_subdir=processor_subdir, session=session) + DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session) - # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references - for dag_id, base_dataset_list in dag_references.items(): - dag_refs_needed = { - DagScheduleDatasetReference( - dataset_id=stored_dataset_models[base_dataset_identifier].id, dag_id=dag_id - ) - if base_dataset_type == "dataset" - else DagScheduleDatasetAliasReference( - alias_id=stored_dataset_alias_models[base_dataset_identifier].id, dag_id=dag_id - ) - for base_dataset_type, base_dataset_identifier in base_dataset_list - } + dataset_collection = DatasetCollection.collect(dags_by_ids) + orm_datasets = dataset_collection.write_datasets(session=session) + orm_dataset_aliases = dataset_collection.write_dataset_aliases(session=session) + session.flush() # This populates id so we can start creating references. 
- # if isinstance(base_dataset, Dataset) - - dag_refs_stored = ( - set(existing_dags.get(dag_id).schedule_dataset_references) # type: ignore - | set(existing_dags.get(dag_id).schedule_dataset_alias_references) # type: ignore - if existing_dags.get(dag_id) - else set() - ) - dag_refs_to_add = dag_refs_needed - dag_refs_stored - session.bulk_save_objects(dag_refs_to_add) - for obj in dag_refs_stored - dag_refs_needed: - session.delete(obj) - - existing_task_outlet_refs_dict = defaultdict(set) - for dag_id, orm_dag in existing_dags.items(): - for todr in orm_dag.task_outlet_dataset_references: - existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) - - # reconcile task-outlet-dataset references - for (dag_id, task_id), uri_list in outlet_references.items(): - task_refs_needed = { - TaskOutletDatasetReference( - dataset_id=stored_dataset_models[uri].id, dag_id=dag_id, task_id=task_id - ) - for uri in uri_list - } - task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)] - task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored} - session.bulk_save_objects(task_refs_to_add) - for obj in task_refs_stored - task_refs_needed: - session.delete(obj) - - # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller - # decide when to commit + dataset_collection.write_dag_dataset_references(orm_dags, orm_datasets, session=session) + dataset_collection.write_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) + dataset_collection.write_task_dataset_references(orm_dags, orm_datasets, session=session) session.flush() - @classmethod - def _get_latest_runs_stmt(cls, dags: list[str]) -> Select: - """ - Build a select statement for retrieve the last automated run for each dag. 
- - :param dags: dags to query - """ - if len(dags) == 1: - # Index optimized fast path to avoid more complicated & slower groupby queryplan - existing_dag_id = dags[0] - last_automated_runs_subq = ( - select(func.max(DagRun.execution_date).label("max_execution_date")) - .where( - DagRun.dag_id == existing_dag_id, - DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)), - ) - .scalar_subquery() - ) - query = select(DagRun).where( - DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq - ) - else: - last_automated_runs_subq = ( - select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date")) - .where( - DagRun.dag_id.in_(dags), - DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)), - ) - .group_by(DagRun.dag_id) - .subquery() - ) - query = select(DagRun).where( - DagRun.dag_id == last_automated_runs_subq.c.dag_id, - DagRun.execution_date == last_automated_runs_subq.c.max_execution_date, - ) - return query.options( - load_only( - DagRun.dag_id, - DagRun.execution_date, - DagRun.data_interval_start, - DagRun.data_interval_end, - ) - ) - @provide_session def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION): """ diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py index ce701794a4c63..5d97591856b5a 100644 --- a/airflow/timetables/base.py +++ b/airflow/timetables/base.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from pendulum import DateTime - from airflow.datasets import Dataset + from airflow.datasets import Dataset, DatasetAlias from airflow.serialization.dag_dependency import DagDependency from airflow.utils.types import DagRunType @@ -57,6 +57,9 @@ def evaluate(self, statuses: dict[str, bool]) -> bool: def iter_datasets(self) -> Iterator[tuple[str, Dataset]]: return iter(()) + def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]: + return iter(()) + def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]: return iter(()) diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 940e445669cce..8221a5aea8aa3 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -145,7 +145,7 @@ def test_dataset_iter_dataset_aliases(): DatasetAll(DatasetAlias("example-alias-5"), Dataset("5")), ) assert list(base_dataset.iter_dataset_aliases()) == [ - DatasetAlias(f"example-alias-{i}") for i in range(1, 6) + (f"example-alias-{i}", DatasetAlias(f"example-alias-{i}")) for i in range(1, 6) ] From e449c3e1ff5257d656b2c22f1a4f3429f1796ac2 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 18 Sep 2024 13:49:23 +0800 Subject: [PATCH 2/9] Don't do an extra flush for Dataset The dataset manager already does this. 
--- airflow/dag_processing/collection.py | 1 + airflow/models/dag.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index e7a658539e5ca..3f6b2f1495f70 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -318,6 +318,7 @@ def write_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasMo except KeyError: da = orm_aliases[name] = DatasetAliasModel.from_public(alias) session.add(da) + session.flush() return orm_aliases def write_dag_dataset_references( diff --git a/airflow/models/dag.py b/airflow/models/dag.py index cdb1fdb6f8fd0..bf85273f848df 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2667,7 +2667,6 @@ def bulk_write_to_db( dataset_collection = DatasetCollection.collect(dags_by_ids) orm_datasets = dataset_collection.write_datasets(session=session) orm_dataset_aliases = dataset_collection.write_dataset_aliases(session=session) - session.flush() # This populates id so we can start creating references. dataset_collection.write_dag_dataset_references(orm_dags, orm_datasets, session=session) dataset_collection.write_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) From afc116a94258a00114c3f702e26ce9c77171f195 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 18 Sep 2024 17:55:17 +0800 Subject: [PATCH 3/9] Improve session handling a bit better --- airflow/dag_processing/collection.py | 28 +++++++++------------------- airflow/datasets/manager.py | 2 -- airflow/models/dag.py | 1 + 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 3f6b2f1495f70..c41b2b2f3a7a0 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -318,7 +318,6 @@ def write_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasMo except KeyError: da = orm_aliases[name] = DatasetAliasModel.from_public(alias) session.add(da) - session.flush() return orm_aliases def write_dag_dataset_references( @@ -339,12 +338,9 @@ def write_dag_dataset_references( if dataset_id not in referenced_dataset_ids: session.delete(ref) session.bulk_save_objects( - ( - DagScheduleDatasetReference(dataset_id=dataset_id, dag_id=dag_id) - for dataset_id in referenced_dataset_ids - if dataset_id not in orm_refs - ), - preserve_order=False, + DagScheduleDatasetReference(dataset_id=dataset_id, dag_id=dag_id) + for dataset_id in referenced_dataset_ids + if dataset_id not in orm_refs ) def write_dag_dataset_alias_references( @@ -365,12 +361,9 @@ def write_dag_dataset_alias_references( if alias_id not in referenced_alias_ids: session.delete(ref) session.bulk_save_objects( - ( - DagScheduleDatasetAliasReference(alias_id=alias_id, dag_id=dag_id) - for alias_id in referenced_alias_ids - if alias_id not in orm_refs - ), - preserve_order=False, + DagScheduleDatasetAliasReference(alias_id=alias_id, dag_id=dag_id) + for alias_id in referenced_alias_ids + if alias_id not in orm_refs ) def write_task_dataset_references( @@ -394,10 +387,7 @@ def write_task_dataset_references( if key not in referenced_outlets: session.delete(ref) session.bulk_save_objects( - ( - TaskOutletDatasetReference(dataset_id=dataset_id, dag_id=dag_id, task_id=task_id) - for task_id, dataset_id in referenced_outlets - if (task_id, dataset_id) not in orm_refs - ), - preserve_order=False, + TaskOutletDatasetReference(dataset_id=dataset_id, 
dag_id=dag_id, task_id=task_id) + for task_id, dataset_id in referenced_outlets + if (task_id, dataset_id) not in orm_refs ) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 058eef6ab8922..19f6913fffbeb 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -62,8 +62,6 @@ def create_datasets(self, dataset_models: list[DatasetModel], session: Session) """Create new datasets.""" for dataset_model in dataset_models: session.add(dataset_model) - session.flush() - for dataset_model in dataset_models: self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra)) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index bf85273f848df..d682cbcfe60e9 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2667,6 +2667,7 @@ def bulk_write_to_db( dataset_collection = DatasetCollection.collect(dags_by_ids) orm_datasets = dataset_collection.write_datasets(session=session) orm_dataset_aliases = dataset_collection.write_dataset_aliases(session=session) + session.flush() dataset_collection.write_dag_dataset_references(orm_dags, orm_datasets, session=session) dataset_collection.write_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) From 96c89c5e21ccd178163c18a22e2ce6ebdd7f77e8 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 18 Sep 2024 18:46:47 +0800 Subject: [PATCH 4/9] Move test to accomodate function move --- tests/dag_processing/test_collection.py | 64 +++++++++++++++++++++++++ tests/models/test_dag.py | 41 ---------------- 2 files changed, 64 insertions(+), 41 deletions(-) create mode 100644 tests/dag_processing/test_collection.py diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py new file mode 100644 index 0000000000000..4d5a6736ad07e --- /dev/null +++ b/tests/dag_processing/test_collection.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import warnings + +from sqlalchemy.exc import SAWarning + +from airflow.dag_processing.collection import _get_latest_runs_stmt + + +def test_statement_latest_runs_one_dag(): + with warnings.catch_warnings(): + warnings.simplefilter("error", category=SAWarning) + + stmt = _get_latest_runs_stmt(["fake-dag"]) + compiled_stmt = str(stmt.compile()) + actual = [x.strip() for x in compiled_stmt.splitlines()] + expected = [ + "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, " + "dag_run.data_interval_start, dag_run.data_interval_end", + "FROM dag_run", + "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = (" + "SELECT max(dag_run.logical_date) AS max_execution_date", + "FROM dag_run", + "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))", + ] + assert actual == expected, compiled_stmt + + +def test_statement_latest_runs_many_dag(): + with warnings.catch_warnings(): + warnings.simplefilter("error", category=SAWarning) + + stmt = _get_latest_runs_stmt(["fake-dag-1", "fake-dag-2"]) + compiled_stmt = str(stmt.compile()) + actual = [x.strip() for x in compiled_stmt.splitlines()] + expected = [ + "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, " + "dag_run.data_interval_start, dag_run.data_interval_end", + "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, " + "max(dag_run.logical_date) AS max_execution_date", + "FROM dag_run", + "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) " + "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1", + "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date", + ] + assert actual == expected, compiled_stmt diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index cb6d4d4ed4c41..093fdcae2f12f 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -23,7 +23,6 @@ import os import pickle import re -import warnings import weakref from datetime import timedelta from importlib import reload @@ -37,7 +36,6 @@ import pytest import time_machine from sqlalchemy import inspect, select -from sqlalchemy.exc import SAWarning from airflow import settings from airflow.configuration import conf @@ -3992,42 +3990,3 @@ def test_validate_setup_teardown_trigger_rule(self): Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS." 
): dag.validate_setup_teardown() - - -def test_statement_latest_runs_one_dag(): - with warnings.catch_warnings(): - warnings.simplefilter("error", category=SAWarning) - - stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"]) - compiled_stmt = str(stmt.compile()) - actual = [x.strip() for x in compiled_stmt.splitlines()] - expected = [ - "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, " - "dag_run.data_interval_start, dag_run.data_interval_end", - "FROM dag_run", - "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = (" - "SELECT max(dag_run.logical_date) AS max_execution_date", - "FROM dag_run", - "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))", - ] - assert actual == expected, compiled_stmt - - -def test_statement_latest_runs_many_dag(): - with warnings.catch_warnings(): - warnings.simplefilter("error", category=SAWarning) - - stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"]) - compiled_stmt = str(stmt.compile()) - actual = [x.strip() for x in compiled_stmt.splitlines()] - expected = [ - "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, " - "dag_run.data_interval_start, dag_run.data_interval_end", - "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, " - "max(dag_run.logical_date) AS max_execution_date", - "FROM dag_run", - "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) " - "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1", - "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date", - ] - assert actual == expected, compiled_stmt From d1862e6ad50d79fa10f11e4d552b2e22f66784db Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 07:17:28 +0800 Subject: [PATCH 5/9] Optimize database calls when there are not objects This saves a few calls when no dataset/alias references were found in any DAGs. --- airflow/dag_processing/collection.py | 31 +++++++++++++++++++++------- airflow/models/dag.py | 13 ++++++------ 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index c41b2b2f3a7a0..f990d8473176c 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -283,7 +283,10 @@ def collect(cls, dags: dict[str, DAG]) -> Self: ) return coll - def write_datasets(self, *, session: Session) -> dict[str, DatasetModel]: + def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]: + # Optimization: skip all database calls if no datasets were collected. + if not self.datasets: + return {} orm_datasets: dict[str, DatasetModel] = { dm.uri: dm for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets))) @@ -305,7 +308,10 @@ def _resolve_dataset_addition() -> Iterator[DatasetModel]: dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session) return orm_datasets - def write_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]: + def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]: + # Optimization: skip all database calls if no dataset aliases were collected. 
+ if not self.dataset_aliases: + return {} orm_aliases: dict[str, DatasetAliasModel] = { da.name: da for da in session.scalars( @@ -320,15 +326,18 @@ def write_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasMo session.add(da) return orm_aliases - def write_dag_dataset_references( + def add_dag_dataset_references( self, dags: dict[str, DagModel], datasets: dict[str, DatasetModel], *, session: Session, ) -> None: + # Optimization: No datasets means there are no references to update. + if not datasets: + return for dag_id, references in self.schedule_dataset_references.items(): - # Optimization: no references at all, just clear everything. + # Optimization: no references at all; this is faster than repeated delete(). if not references: dags[dag_id].schedule_dataset_references = [] continue @@ -343,15 +352,18 @@ def write_dag_dataset_references( if dataset_id not in orm_refs ) - def write_dag_dataset_alias_references( + def add_dag_dataset_alias_references( self, dags: dict[str, DagModel], aliases: dict[str, DatasetAliasModel], *, session: Session, ) -> None: + # Optimization: No aliases means there are no references to update. + if not aliases: + return for dag_id, references in self.schedule_dataset_alias_references.items(): - # Optimization: no references at all, just clear everything. + # Optimization: no references at all; this is faster than repeated delete(). if not references: dags[dag_id].schedule_dataset_alias_references = [] continue @@ -366,15 +378,18 @@ def write_dag_dataset_alias_references( if alias_id not in orm_refs ) - def write_task_dataset_references( + def add_task_dataset_references( self, dags: dict[str, DagModel], datasets: dict[str, DatasetModel], *, session: Session, ) -> None: + # Optimization: No datasets means there are no references to update. + if not datasets: + return for dag_id, references in self.outlet_references.items(): - # Optimization: no references at all, just clear everything. + # Optimization: no references at all; this is faster than repeated delete(). if not references: dags[dag_id].task_outlet_dataset_references = [] continue diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d682cbcfe60e9..9a42fe8d2abb4 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2665,13 +2665,14 @@ def bulk_write_to_db( DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session) dataset_collection = DatasetCollection.collect(dags_by_ids) - orm_datasets = dataset_collection.write_datasets(session=session) - orm_dataset_aliases = dataset_collection.write_dataset_aliases(session=session) - session.flush() - dataset_collection.write_dag_dataset_references(orm_dags, orm_datasets, session=session) - dataset_collection.write_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) - dataset_collection.write_task_dataset_references(orm_dags, orm_datasets, session=session) + orm_datasets = dataset_collection.add_datasets(session=session) + orm_dataset_aliases = dataset_collection.add_dataset_aliases(session=session) + session.flush() # This populates id so we can create fks in later calls. 
+ + dataset_collection.add_dag_dataset_references(orm_dags, orm_datasets, session=session) + dataset_collection.add_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) + dataset_collection.add_task_dataset_references(orm_dags, orm_datasets, session=session) session.flush() @provide_session From 7d1e961d617230c9bc1489141f0b4695c58f1d62 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 08:07:39 +0800 Subject: [PATCH 6/9] IN with constants is likely better than OR --- airflow/dag_processing/collection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index f990d8473176c..9e6ac8cc7504c 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -31,7 +31,7 @@ import logging from typing import TYPE_CHECKING, NamedTuple -from sqlalchemy import func, or_, select +from sqlalchemy import func, select from sqlalchemy.orm import joinedload, load_only from sqlalchemy.sql import expression @@ -111,7 +111,7 @@ def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select: select(func.max(DagRun.execution_date).label("max_execution_date")) .where( DagRun.dag_id == dag_id, - or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED), + DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)), ) .scalar_subquery() ) From 375520296556607264d71c3e41722922d46dc698 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 09:10:24 +0800 Subject: [PATCH 7/9] Flush to accomodate removel from create_datasets --- airflow/models/taskinstance.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f93c90b638464..954e5ed4d0c80 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2925,6 +2925,7 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se dataset_obj = DatasetModel(uri=uri) dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session) self.log.warning("Created a new %r as it did not exist.", dataset_obj) + session.flush() dataset_objs_cache[uri] = dataset_obj for alias in alias_names: From ddadee389b7bb91575b653d40d7382dc38632085 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 13:23:39 +0800 Subject: [PATCH 8/9] Better operation collection name --- airflow/dag_processing/collection.py | 4 ++-- airflow/models/dag.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 9e6ac8cc7504c..9721abf7d7094 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -249,8 +249,8 @@ def _find_all_dataset_aliases(dags: Iterable[DAG]) -> Iterator[DatasetAlias]: yield obj -class DatasetCollection(NamedTuple): - """Datasets collected from DAGs.""" +class DatasetModelOperation(NamedTuple): + """Collect dataset/alias objects from DAGs and perform database operations for them.""" schedule_dataset_references: dict[str, list[Dataset]] schedule_dataset_alias_references: dict[str, list[DatasetAlias]] diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 9a42fe8d2abb4..6447f7be154f8 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2644,7 +2644,7 @@ def bulk_write_to_db( return from airflow.dag_processing.collection import ( - DatasetCollection, + DatasetModelOperation, 
collect_orm_dags, create_orm_dag, update_orm_dags, @@ -2664,15 +2664,15 @@ def bulk_write_to_db( update_orm_dags(dags_by_ids, orm_dags, processor_subdir=processor_subdir, session=session) DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session) - dataset_collection = DatasetCollection.collect(dags_by_ids) + dataset_op = DatasetModelOperation.collect(dags_by_ids) - orm_datasets = dataset_collection.add_datasets(session=session) - orm_dataset_aliases = dataset_collection.add_dataset_aliases(session=session) + orm_datasets = dataset_op.add_datasets(session=session) + orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session) session.flush() # This populates id so we can create fks in later calls. - dataset_collection.add_dag_dataset_references(orm_dags, orm_datasets, session=session) - dataset_collection.add_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) - dataset_collection.add_task_dataset_references(orm_dags, orm_datasets, session=session) + dataset_op.add_dag_dataset_references(orm_dags, orm_datasets, session=session) + dataset_op.add_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session) + dataset_op.add_task_dataset_references(orm_dags, orm_datasets, session=session) session.flush() @provide_session From 149aed79be537760f41a868bf8485bac7a1e569e Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 19 Sep 2024 16:03:05 +0800 Subject: [PATCH 9/9] Fix docstring typo Co-authored-by: Ephraim Anierobi --- airflow/dag_processing/collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 9721abf7d7094..3f75e0b23bbfd 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -104,7 +104,7 @@ def create_orm_dag(dag: DAG, session: Session) -> DagModel: def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select: - """Build a select statement for retrieve the last automated run for each dag.""" + """Build a select statement to retrieve the last automated run for each dag.""" if len(dag_ids) == 1: # Index optimized fast path to avoid more complicated & slower groupby queryplan. (dag_id,) = dag_ids last_automated_runs_subq = (
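
For readers skimming the series, the sketch below recaps how the pieces introduced above compose once all nine patches are applied. It mirrors the final body of ``DAG.bulk_write_to_db`` as shown in the diffs; the wrapper function name ``sync_dags_to_db``, the ``create_session`` usage, and the example comments are illustrative assumptions added here, not part of the patches.

from airflow.dag_processing.collection import (
    DatasetModelOperation,
    collect_orm_dags,
    create_orm_dag,
    update_orm_dags,
)
from airflow.models.dagcode import DagCode
from airflow.utils.session import create_session


def sync_dags_to_db(dags, processor_subdir=None):
    """Illustrative recap of the orchestration in DAG.bulk_write_to_db after this series."""
    dags_by_ids = {dag.dag_id: dag for dag in dags}
    with create_session() as session:  # illustrative; the real method receives a session
        # Fetch existing, row-locked DagModel rows, then create any missing ones.
        orm_dags = collect_orm_dags(dags_by_ids, session=session)
        orm_dags.update(
            (dag_id, create_orm_dag(dag, session=session))
            for dag_id, dag in dags_by_ids.items()
            if dag_id not in orm_dags
        )
        # Copy DAG attributes onto the ORM rows and sync the DAG source code.
        update_orm_dags(dags_by_ids, orm_dags, processor_subdir=processor_subdir, session=session)
        DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session)

        # Collect dataset and alias references once, then write them in separate steps.
        dataset_op = DatasetModelOperation.collect(dags_by_ids)
        orm_datasets = dataset_op.add_datasets(session=session)
        orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session)
        session.flush()  # Populate primary keys so the reference rows below can use them.

        dataset_op.add_dag_dataset_references(orm_dags, orm_datasets, session=session)
        dataset_op.add_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session)
        dataset_op.add_task_dataset_references(orm_dags, orm_datasets, session=session)

In practice a caller such as ``DagBag._sync_to_db`` passes the parsed ``DAG`` objects and lets the existing ``@provide_session`` machinery supply the session, exactly as before the refactor.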