diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index c41b2b2f3a7a0..f990d8473176c 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -283,7 +283,10 @@ def collect(cls, dags: dict[str, DAG]) -> Self:
         )
         return coll
 
-    def write_datasets(self, *, session: Session) -> dict[str, DatasetModel]:
+    def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]:
+        # Optimization: skip all database calls if no datasets were collected.
+        if not self.datasets:
+            return {}
         orm_datasets: dict[str, DatasetModel] = {
             dm.uri: dm
             for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
@@ -305,7 +308,10 @@ def _resolve_dataset_addition() -> Iterator[DatasetModel]:
         dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session)
         return orm_datasets
 
-    def write_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
+    def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
+        # Optimization: skip all database calls if no dataset aliases were collected.
+        if not self.dataset_aliases:
+            return {}
         orm_aliases: dict[str, DatasetAliasModel] = {
             da.name: da
             for da in session.scalars(
@@ -320,15 +326,18 @@ def write_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
             session.add(da)
         return orm_aliases
 
-    def write_dag_dataset_references(
+    def add_dag_dataset_references(
         self,
         dags: dict[str, DagModel],
         datasets: dict[str, DatasetModel],
         *,
         session: Session,
     ) -> None:
+        # Optimization: No datasets means there are no references to update.
+        if not datasets:
+            return
         for dag_id, references in self.schedule_dataset_references.items():
-            # Optimization: no references at all, just clear everything.
+            # Optimization: no references at all; this is faster than repeated delete().
             if not references:
                 dags[dag_id].schedule_dataset_references = []
                 continue
@@ -343,15 +352,18 @@ def write_dag_dataset_references(
                 if dataset_id not in orm_refs
             )
 
-    def write_dag_dataset_alias_references(
+    def add_dag_dataset_alias_references(
         self,
         dags: dict[str, DagModel],
         aliases: dict[str, DatasetAliasModel],
         *,
         session: Session,
     ) -> None:
+        # Optimization: No aliases means there are no references to update.
+        if not aliases:
+            return
         for dag_id, references in self.schedule_dataset_alias_references.items():
-            # Optimization: no references at all, just clear everything.
+            # Optimization: no references at all; this is faster than repeated delete().
             if not references:
                 dags[dag_id].schedule_dataset_alias_references = []
                 continue
@@ -366,15 +378,18 @@ def write_dag_dataset_alias_references(
                 if alias_id not in orm_refs
             )
 
-    def write_task_dataset_references(
+    def add_task_dataset_references(
         self,
         dags: dict[str, DagModel],
         datasets: dict[str, DatasetModel],
         *,
         session: Session,
     ) -> None:
+        # Optimization: No datasets means there are no references to update.
+        if not datasets:
+            return
         for dag_id, references in self.outlet_references.items():
-            # Optimization: no references at all, just clear everything.
+            # Optimization: no references at all; this is faster than repeated delete().
             if not references:
                 dags[dag_id].task_outlet_dataset_references = []
                 continue
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index d682cbcfe60e9..9a42fe8d2abb4 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2665,13 +2665,14 @@ def bulk_write_to_db(
 
         DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session)
 
         dataset_collection = DatasetCollection.collect(dags_by_ids)
-        orm_datasets = dataset_collection.write_datasets(session=session)
-        orm_dataset_aliases = dataset_collection.write_dataset_aliases(session=session)
-        session.flush()
-        dataset_collection.write_dag_dataset_references(orm_dags, orm_datasets, session=session)
-        dataset_collection.write_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session)
-        dataset_collection.write_task_dataset_references(orm_dags, orm_datasets, session=session)
+        orm_datasets = dataset_collection.add_datasets(session=session)
+        orm_dataset_aliases = dataset_collection.add_dataset_aliases(session=session)
+        session.flush()  # This populates id so we can create fks in later calls.
+
+        dataset_collection.add_dag_dataset_references(orm_dags, orm_datasets, session=session)
+        dataset_collection.add_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session)
+        dataset_collection.add_task_dataset_references(orm_dags, orm_datasets, session=session)
         session.flush()
 
     @provide_session
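
The comment added next to the first `session.flush()` in `bulk_write_to_db` points at why the two-phase ordering works: flushing emits the pending INSERTs inside the open transaction, which populates autogenerated primary keys, so the reference rows created afterwards can carry real foreign key values before anything is committed. Below is a self-contained sketch of that SQLAlchemy behavior; the `Dataset` and `DagDatasetReference` tables are hypothetical stand-ins written in 2.0-style declarative syntax, not the actual Airflow models.

```python
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):  # Hypothetical stand-in for DatasetModel.
    __tablename__ = "dataset"
    id: Mapped[int] = mapped_column(primary_key=True)
    uri: Mapped[str]


class DagDatasetReference(Base):  # Hypothetical stand-in for the reference models.
    __tablename__ = "dag_dataset_reference"
    dag_id: Mapped[str] = mapped_column(primary_key=True)
    dataset_id: Mapped[int] = mapped_column(ForeignKey("dataset.id"), primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    dataset = Dataset(uri="s3://bucket/key")
    session.add(dataset)
    assert dataset.id is None  # Pending: no INSERT has been emitted yet.

    session.flush()  # Emits the INSERT; the autogenerated id is now available.
    assert dataset.id is not None

    # Only after the flush can the foreign key be filled in by value.
    session.add(DagDatasetReference(dag_id="example_dag", dataset_id=dataset.id))
    session.commit()
```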
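The early returns in `add_datasets` and `add_dataset_aliases` are worth calling out too: without them, `DatasetModel.uri.in_(self.datasets)` is still sent to the database even when nothing was collected, because SQLAlchemy renders an empty `IN ()` as an always-false predicate rather than skipping the query, so the guard saves a round trip per call on every parse loop for DAG files that use no datasets. A minimal sketch of the pattern in isolation; `fetch_by_uri` is a hypothetical helper written for illustration, not a function from this patch.

```python
from __future__ import annotations

from collections.abc import Collection

from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.models.dataset import DatasetModel


def fetch_by_uri(uris: Collection[str], *, session: Session) -> dict[str, DatasetModel]:
    """Load DatasetModel rows by URI, mirroring the guard this patch adds."""
    if not uris:
        # Skip the query entirely: an empty IN () still costs a database
        # round trip, and it can never match a row.
        return {}
    stmt = select(DatasetModel).where(DatasetModel.uri.in_(uris))
    return {dm.uri: dm for dm in session.scalars(stmt)}
```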