Refactor bulk_save_to_db (#42245)
Co-authored-by: Ephraim Anierobi <splendidzigy24@gmail.com>
uranusjr and ephraimbuddy authored Sep 20, 2024
1 parent 3464633 commit 8d816fb
Showing 9 changed files with 515 additions and 366 deletions.
408 changes: 408 additions & 0 deletions airflow/dag_processing/collection.py

Large diffs are not rendered by default.
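The 408-line diff for this new module is not rendered, but the tests added below import _get_latest_runs_stmt from it and pin down its compiled SQL (the two tests removed from tests/models/test_dag.py at the bottom of this commit reappear there against the new module). For orientation, here is a speculative sketch consistent with that SQL; the identifier names, the run_type filter, and the load_only column list are inferred from the test assertions, not copied from the diff:

# A speculative sketch of _get_latest_runs_stmt, reconstructed from the
# compiled SQL asserted in tests/dag_processing/test_collection.py below.
# The real implementation lives in the unrendered diff above; treat names
# and the exact run_type filter as assumptions.
from __future__ import annotations

from sqlalchemy import func, select
from sqlalchemy.orm import load_only

from airflow.models.dagrun import DagRun
from airflow.utils.types import DagRunType


def _get_latest_runs_stmt(dag_ids: list[str]):
    if len(dag_ids) == 1:
        # Single DAG: a correlated scalar subquery avoids the GROUP BY plan.
        (dag_id,) = dag_ids
        max_logical_date = (
            select(func.max(DagRun.logical_date).label("max_execution_date"))
            .where(
                DagRun.dag_id == dag_id,
                DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
            )
            .scalar_subquery()
        )
        query = select(DagRun).where(
            DagRun.dag_id == dag_id, DagRun.logical_date == max_logical_date
        )
    else:
        # Many DAGs: aggregate per dag_id in a derived table, then join back.
        max_logical_dates = (
            select(DagRun.dag_id, func.max(DagRun.logical_date).label("max_execution_date"))
            .where(
                DagRun.dag_id.in_(dag_ids),
                DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
            )
            .group_by(DagRun.dag_id)
            .subquery()
        )
        query = select(DagRun).where(
            DagRun.dag_id == max_logical_dates.c.dag_id,
            DagRun.logical_date == max_logical_dates.c.max_execution_date,
        )
    # load_only() trims the SELECT list to the columns the tests expect.
    return query.options(
        load_only(
            DagRun.logical_date,
            DagRun.dag_id,
            DagRun.data_interval_start,
            DagRun.data_interval_end,
        )
    )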

16 changes: 11 additions & 5 deletions airflow/datasets/__init__.py
@@ -194,7 +194,7 @@ def evaluate(self, statuses: dict[str, bool]) -> bool:
     def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
         raise NotImplementedError
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
         raise NotImplementedError
 
     def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
@@ -212,6 +212,12 @@ class DatasetAlias(BaseDataset):
 
     name: str
 
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        return iter(())
+
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+        yield self.name, self
+
     def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
         """
         Iterate a dataset alias as dag dependency.
@@ -294,7 +300,7 @@ def as_expression(self) -> Any:
     def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
         yield self.uri, self
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
         return iter(())
 
     def evaluate(self, statuses: dict[str, bool]) -> bool:
@@ -339,7 +345,7 @@ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
                 yield k, v
                 seen.add(k)
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
         """Filter dataset aliases in the condition."""
         for o in self.objects:
             yield from o.iter_dataset_aliases()
@@ -399,8 +405,8 @@ def as_expression(self) -> Any:
         """
         return {"alias": self.name}
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
-        yield DatasetAlias(self.name)
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+        yield self.name, DatasetAlias(self.name)
 
     def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
         """
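The recurring change in this file: iter_dataset_aliases now yields (name, DatasetAlias) pairs, mirroring the (uri, Dataset) pairs from iter_datasets. A small illustrative snippet of the new contract (the condition built here is invented for the example, but matches what tests/datasets/test_dataset.py below asserts):

from airflow.datasets import Dataset, DatasetAlias, DatasetAll

# Hypothetical condition, just to show the new iteration contract.
condition = DatasetAll(Dataset("s3://bucket/a"), DatasetAlias("example-alias"))

for name, alias in condition.iter_dataset_aliases():
    # Each item is now a (name, DatasetAlias) pair, matching the
    # (uri, Dataset) pairs yielded by iter_datasets().
    print(name, alias)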
2 changes: 0 additions & 2 deletions airflow/datasets/manager.py
@@ -62,8 +62,6 @@ def create_datasets(self, dataset_models: list[DatasetModel], session: Session)
         """Create new datasets."""
         for dataset_model in dataset_models:
             session.add(dataset_model)
-        session.flush()
-
         for dataset_model in dataset_models:
             self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
 
342 changes: 26 additions & 316 deletions airflow/models/dag.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions airflow/models/taskinstance.py
@@ -2925,6 +2925,7 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se
                 dataset_obj = DatasetModel(uri=uri)
                 dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
                 self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+                session.flush()
             dataset_objs_cache[uri] = dataset_obj
 
         for alias in alias_names:
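Read together with the datasets/manager.py hunk above: create_datasets no longer flushes the session itself, and this call site now flushes once after creating, so the model has a primary key before it is cached. Presumably this is what lets the new bulk code in collection.py batch a single flush over many models. A sketch of the resulting pattern, with a hypothetical helper name:

from sqlalchemy.orm import Session

from airflow.datasets.manager import dataset_manager
from airflow.models.dataset import DatasetModel


def _ensure_dataset_model(uri: str, session: Session) -> DatasetModel:
    # Hypothetical helper: create_datasets() now only session.add()s the
    # models and fires notifications; the caller decides when to flush,
    # so bulk callers can issue one flush for many models.
    dataset_obj = DatasetModel(uri=uri)
    dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
    session.flush()  # assigns dataset_obj.id before it is cached/referenced
    return dataset_obj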
5 changes: 4 additions & 1 deletion airflow/timetables/base.py
@@ -24,7 +24,7 @@
 if TYPE_CHECKING:
     from pendulum import DateTime
 
-    from airflow.datasets import Dataset
+    from airflow.datasets import Dataset, DatasetAlias
     from airflow.serialization.dag_dependency import DagDependency
     from airflow.utils.types import DagRunType
 
@@ -57,6 +57,9 @@ def evaluate(self, statuses: dict[str, bool]) -> bool:
     def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
         return iter(())
 
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+        return iter(())
+
     def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]:
         return iter(())
 
64 changes: 64 additions & 0 deletions tests/dag_processing/test_collection.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import warnings
+
+from sqlalchemy.exc import SAWarning
+
+from airflow.dag_processing.collection import _get_latest_runs_stmt
+
+
+def test_statement_latest_runs_one_dag():
+    with warnings.catch_warnings():
+        warnings.simplefilter("error", category=SAWarning)
+
+        stmt = _get_latest_runs_stmt(["fake-dag"])
+        compiled_stmt = str(stmt.compile())
+        actual = [x.strip() for x in compiled_stmt.splitlines()]
+        expected = [
+            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
+            "dag_run.data_interval_start, dag_run.data_interval_end",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
+            "SELECT max(dag_run.logical_date) AS max_execution_date",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
+        ]
+        assert actual == expected, compiled_stmt
+
+
+def test_statement_latest_runs_many_dag():
+    with warnings.catch_warnings():
+        warnings.simplefilter("error", category=SAWarning)
+
+        stmt = _get_latest_runs_stmt(["fake-dag-1", "fake-dag-2"])
+        compiled_stmt = str(stmt.compile())
+        actual = [x.strip() for x in compiled_stmt.splitlines()]
+        expected = [
+            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
+            "dag_run.data_interval_start, dag_run.data_interval_end",
+            "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
+            "max(dag_run.logical_date) AS max_execution_date",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
+            "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
+            "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
+        ]
+        assert actual == expected, compiled_stmt
2 changes: 1 addition & 1 deletion tests/datasets/test_dataset.py
@@ -145,7 +145,7 @@ def test_dataset_iter_dataset_aliases():
         DatasetAll(DatasetAlias("example-alias-5"), Dataset("5")),
     )
     assert list(base_dataset.iter_dataset_aliases()) == [
-        DatasetAlias(f"example-alias-{i}") for i in range(1, 6)
+        (f"example-alias-{i}", DatasetAlias(f"example-alias-{i}")) for i in range(1, 6)
     ]
 
 
41 changes: 0 additions & 41 deletions tests/models/test_dag.py
@@ -23,7 +23,6 @@
 import os
 import pickle
 import re
-import warnings
 import weakref
 from datetime import timedelta
 from importlib import reload
@@ -37,7 +36,6 @@
 import pytest
 import time_machine
 from sqlalchemy import inspect, select
-from sqlalchemy.exc import SAWarning
 
 from airflow import settings
 from airflow.configuration import conf
@@ -3992,42 +3990,3 @@ def test_validate_setup_teardown_trigger_rule(self):
             Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
         ):
             dag.validate_setup_teardown()
-
-
-def test_statement_latest_runs_one_dag():
-    with warnings.catch_warnings():
-        warnings.simplefilter("error", category=SAWarning)
-
-        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
-        compiled_stmt = str(stmt.compile())
-        actual = [x.strip() for x in compiled_stmt.splitlines()]
-        expected = [
-            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
-            "dag_run.data_interval_start, dag_run.data_interval_end",
-            "FROM dag_run",
-            "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
-            "SELECT max(dag_run.logical_date) AS max_execution_date",
-            "FROM dag_run",
-            "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
-        ]
-        assert actual == expected, compiled_stmt
-
-
-def test_statement_latest_runs_many_dag():
-    with warnings.catch_warnings():
-        warnings.simplefilter("error", category=SAWarning)
-
-        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
-        compiled_stmt = str(stmt.compile())
-        actual = [x.strip() for x in compiled_stmt.splitlines()]
-        expected = [
-            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
-            "dag_run.data_interval_start, dag_run.data_interval_end",
-            "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
-            "max(dag_run.logical_date) AS max_execution_date",
-            "FROM dag_run",
-            "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
-            "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
-            "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
-        ]
-        assert actual == expected, compiled_stmt