Refactor bulk_save_to_db #42245

Merged 9 commits on Sep 20, 2024
408 changes: 408 additions & 0 deletions airflow/dag_processing/collection.py

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions airflow/datasets/__init__.py
@@ -194,7 +194,7 @@ def evaluate(self, statuses: dict[str, bool]) -> bool:
def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
raise NotImplementedError

def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
raise NotImplementedError

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
@@ -212,6 +212,12 @@ class DatasetAlias(BaseDataset):

name: str

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
return iter(())

def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
yield self.name, self

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
"""
Iterate a dataset alias as dag dependency.
@@ -294,7 +300,7 @@ def as_expression(self) -> Any:
def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
yield self.uri, self

def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
return iter(())

def evaluate(self, statuses: dict[str, bool]) -> bool:
@@ -339,7 +345,7 @@ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
yield k, v
seen.add(k)

def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
"""Filter dataest aliases in the condition."""
for o in self.objects:
yield from o.iter_dataset_aliases()
@@ -399,8 +405,8 @@ def as_expression(self) -> Any:
"""
return {"alias": self.name}

def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
yield DatasetAlias(self.name)
def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
yield self.name, DatasetAlias(self.name)

def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
"""
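Note on the `iter_dataset_aliases` signature change above: it now yields `(name, alias)` tuples instead of bare `DatasetAlias` objects, mirroring `iter_datasets`, which yields `(uri, dataset)`. A minimal sketch (not part of this diff) of how a consumer can key on the name, for example to deduplicate aliases across a condition:

```python
# Minimal sketch (not from this PR) showing consumption of the new
# (name, alias) tuples; deduplicating by name mirrors how iter_datasets()
# already keys datasets by URI.
from airflow.datasets import Dataset, DatasetAlias, DatasetAll

condition = DatasetAll(
    Dataset("s3://bucket/a"),
    DatasetAlias("example-alias"),
    DatasetAlias("example-alias"),  # duplicate on purpose
)

unique_aliases: dict[str, DatasetAlias] = {}
for name, alias in condition.iter_dataset_aliases():
    # DatasetAll simply yields from each member, so duplicates can appear;
    # the name key makes collapsing them trivial.
    unique_aliases.setdefault(name, alias)

assert list(unique_aliases) == ["example-alias"]
```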
2 changes: 0 additions & 2 deletions airflow/datasets/manager.py
@@ -62,8 +62,6 @@ def create_datasets(self, dataset_models: list[DatasetModel], session: Session)
"""Create new datasets."""
for dataset_model in dataset_models:
session.add(dataset_model)
session.flush()

for dataset_model in dataset_models:
self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))

342 changes: 26 additions & 316 deletions airflow/models/dag.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions airflow/models/taskinstance.py
@@ -2925,6 +2925,7 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se
dataset_obj = DatasetModel(uri=uri)
dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
self.log.warning("Created a new %r as it did not exist.", dataset_obj)
session.flush()
dataset_objs_cache[uri] = dataset_obj

for alias in alias_names:
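Taken together with the `manager.py` change above, the `session.flush()` that used to run inside `DatasetManager.create_datasets` now happens at the call site, so the manager only stages the models and emits notifications. A rough sketch of the resulting caller pattern, assuming the usual module-level `dataset_manager` instance (simplified; the exact code is in `TaskInstance._register_dataset_changes`):

```python
# Rough sketch of the call-site pattern after this change (simplified, not the
# exact Airflow code): create_datasets() no longer flushes, so a caller that
# needs the new row immediately flushes the session itself.
from sqlalchemy.orm import Session

from airflow.datasets.manager import dataset_manager
from airflow.models.dataset import DatasetModel


def ensure_dataset(uri: str, session: Session) -> DatasetModel:
    dataset_obj = DatasetModel(uri=uri)
    dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
    session.flush()  # flush moved from the manager to the caller
    return dataset_obj
```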
5 changes: 4 additions & 1 deletion airflow/timetables/base.py
@@ -24,7 +24,7 @@
if TYPE_CHECKING:
from pendulum import DateTime

from airflow.datasets import Dataset
from airflow.datasets import Dataset, DatasetAlias
from airflow.serialization.dag_dependency import DagDependency
from airflow.utils.types import DagRunType

@@ -57,6 +57,9 @@ def evaluate(self, statuses: dict[str, bool]) -> bool:
def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
return iter(())

def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
return iter(())

def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]:
return iter(())

64 changes: 64 additions & 0 deletions tests/dag_processing/test_collection.py
@@ -0,0 +1,64 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import warnings

from sqlalchemy.exc import SAWarning

from airflow.dag_processing.collection import _get_latest_runs_stmt


def test_statement_latest_runs_one_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = _get_latest_runs_stmt(["fake-dag"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
"SELECT max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
]
assert actual == expected, compiled_stmt


def test_statement_latest_runs_many_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = _get_latest_runs_stmt(["fake-dag-1", "fake-dag-2"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
"max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
"AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
"WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
]
assert actual == expected, compiled_stmt
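
These tests pin the compiled SQL of `_get_latest_runs_stmt`, which this PR moves out of `DAG` and into the new `airflow.dag_processing.collection` module (collapsed above). For orientation, here is a speculative sketch of a SQLAlchemy statement with roughly the single-DAG shape the first test expects; the real implementation lives in the collapsed `collection.py` diff and may differ:

```python
# Speculative sketch only: the actual _get_latest_runs_stmt is in the collapsed
# airflow/dag_processing/collection.py diff above. This illustrates a select()
# that compiles to roughly the SQL asserted in test_statement_latest_runs_one_dag.
from sqlalchemy import func, select

from airflow.models.dagrun import DagRun
from airflow.utils.types import DagRunType


def latest_runs_one_dag_stmt(dag_id: str):
    # Correlated scalar subquery for the latest logical_date of the dag.
    max_logical_date = (
        select(func.max(DagRun.logical_date).label("max_execution_date"))
        .where(
            DagRun.dag_id == dag_id,
            DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
        )
        .scalar_subquery()
    )
    return select(
        DagRun.logical_date,
        DagRun.id,
        DagRun.dag_id,
        DagRun.data_interval_start,
        DagRun.data_interval_end,
    ).where(DagRun.dag_id == dag_id, DagRun.logical_date == max_logical_date)
```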
2 changes: 1 addition & 1 deletion tests/datasets/test_dataset.py
@@ -145,7 +145,7 @@ def test_dataset_iter_dataset_aliases():
DatasetAll(DatasetAlias("example-alias-5"), Dataset("5")),
)
assert list(base_dataset.iter_dataset_aliases()) == [
DatasetAlias(f"example-alias-{i}") for i in range(1, 6)
(f"example-alias-{i}", DatasetAlias(f"example-alias-{i}")) for i in range(1, 6)
]


41 changes: 0 additions & 41 deletions tests/models/test_dag.py
@@ -23,7 +23,6 @@
import os
import pickle
import re
import warnings
import weakref
from datetime import timedelta
from importlib import reload
@@ -37,7 +36,6 @@
import pytest
import time_machine
from sqlalchemy import inspect, select
from sqlalchemy.exc import SAWarning

from airflow import settings
from airflow.configuration import conf
@@ -3992,42 +3990,3 @@ def test_validate_setup_teardown_trigger_rule(self):
Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
):
dag.validate_setup_teardown()


def test_statement_latest_runs_one_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
"SELECT max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
]
assert actual == expected, compiled_stmt


def test_statement_latest_runs_many_dag():
with warnings.catch_warnings():
warnings.simplefilter("error", category=SAWarning)

stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
compiled_stmt = str(stmt.compile())
actual = [x.strip() for x in compiled_stmt.splitlines()]
expected = [
"SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
"dag_run.data_interval_start, dag_run.data_interval_end",
"FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
"max(dag_run.logical_date) AS max_execution_date",
"FROM dag_run",
"WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
"AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
"WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
]
assert actual == expected, compiled_stmt