Skip to content

Commit

Permalink
Fix accidental db tests in Task SDK (apache#44690)
Browse files Browse the repository at this point in the history
* Add hook to forbid db_test marker in SDK tests

* Remove db_test marker and fix tests

* Fix asset tests in task_sdk that use db

Mocks are added to avoid real db access. The db entry point is made
lazy, and function expand_alias_to_assets moved to airflow.models and
tested there properly with db access.

* Refetch ORM DAGs during collection

* Fix mock argument
  • Loading branch information
uranusjr authored and Ohashiro committed Dec 6, 2024
1 parent b80568f commit 9d349c5
Show file tree
Hide file tree
Showing 13 changed files with 319 additions and 243 deletions.
29 changes: 14 additions & 15 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,6 @@
log = logging.getLogger(__name__)


def _find_orm_dags(dag_ids: Iterable[str], *, session: Session) -> dict[str, DagModel]:
    """Look up existing DagModel rows for the given DAG ids, keyed by dag_id.

    Tags and the asset-reference relationships are joined-loaded up front so
    callers can traverse them without extra queries, and the selected rows are
    locked (where the backend supports row locks) against concurrent writers.
    """
    eager_opts = (
        joinedload(DagModel.tags, innerjoin=False),
        joinedload(DagModel.schedule_asset_references),
        joinedload(DagModel.schedule_asset_alias_references),
        joinedload(DagModel.task_outlet_asset_references),
    )
    stmt = select(DagModel).options(*eager_opts).where(DagModel.dag_id.in_(dag_ids))
    locked_stmt = with_row_locks(stmt, of=DagModel, session=session)
    return {model.dag_id: model for model in session.scalars(locked_stmt).unique()}


def _create_orm_dags(dags: Iterable[DAG], *, session: Session) -> Iterator[DagModel]:
for dag in dags:
orm_dag = DagModel(dag_id=dag.dag_id)
Expand Down Expand Up @@ -181,8 +167,21 @@ class DagModelOperation(NamedTuple):

dags: dict[str, DAG]

def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
    """Fetch the existing DagModel row for every DAG id in ``self.dags``.

    Relationship collections (tags plus schedule/outlet asset references) are
    joined-loaded eagerly, and the rows are locked via ``with_row_locks`` so
    concurrent DAG processors do not clobber each other's updates.
    """
    query = (
        select(DagModel)
        .where(DagModel.dag_id.in_(self.dags))
        .options(
            joinedload(DagModel.tags, innerjoin=False),
            joinedload(DagModel.schedule_asset_references),
            joinedload(DagModel.schedule_asset_alias_references),
            joinedload(DagModel.task_outlet_asset_references),
        )
    )
    query = with_row_locks(query, of=DagModel, session=session)
    found: dict[str, DagModel] = {}
    for dag_model in session.scalars(query).unique():
        found[dag_model.dag_id] = dag_model
    return found

def add_dags(self, *, session: Session) -> dict[str, DagModel]:
orm_dags = _find_orm_dags(self.dags, session=session)
orm_dags = self.find_orm_dags(session=session)
orm_dags.update(
(model.dag_id, model)
for model in _create_orm_dags(
Expand Down
17 changes: 12 additions & 5 deletions airflow/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,12 @@
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterable

from sqlalchemy.orm import Session


def _fetch_active_assets_by_name(
names: Sequence[str],
session: Session,
) -> dict[str, Asset]:
def fetch_active_assets_by_name(names: Iterable[str], session: Session) -> dict[str, Asset]:
return {
asset_model[0].name: asset_model[0].to_public()
for asset_model in session.execute(
Expand All @@ -61,6 +58,16 @@ def _fetch_active_assets_by_name(
}


def expand_alias_to_assets(alias_name: str, session: Session) -> Iterable[AssetModel]:
    """Expand asset alias to resolved assets.

    Returns the assets linked to the alias named ``alias_name``, or an empty
    list when no matching alias row is found.
    """
    stmt = select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1)
    alias_row = session.scalar(stmt)
    return list(alias_row.assets) if alias_row else []


alias_association_table = Table(
"asset_alias_asset",
Base.metadata,
Expand Down
1 change: 1 addition & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,7 @@ def bulk_write_to_db(
orm_asset_aliases = asset_op.add_asset_aliases(session=session)
session.flush() # This populates id so we can create fks in later calls.

orm_dags = dag_op.find_orm_dags(session=session) # Refetch so relationship is up to date.
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sqlalchemy import select

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, _fetch_active_assets_by_name
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, fetch_active_assets_by_name
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
Expand Down Expand Up @@ -250,7 +250,7 @@ def __init__(self, inlets: list, *, session: Session) -> None:
_asset_ref_names.append(inlet.name)

if _asset_ref_names:
for asset_name, asset in _fetch_active_assets_by_name(_asset_ref_names, self._session).items():
for asset_name, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items():
self._assets[asset_name] = asset

def __iter__(self) -> Iterator[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,24 @@

if TYPE_CHECKING:
from airflow.auth.managers.models.resource_details import AssetDetails
from airflow.models.asset import expand_alias_to_assets
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAll,
AssetAny,
expand_alias_to_assets,
)
else:
if AIRFLOW_V_3_0_PLUS:
from airflow.auth.managers.models.resource_details import AssetDetails
from airflow.models.asset import expand_alias_to_assets
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAll,
AssetAny,
expand_alias_to_assets,
)
else:
# dataset is renamed to asset since Airflow 3.0
Expand Down
30 changes: 10 additions & 20 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import functools
import logging
import operator
import os
Expand All @@ -33,18 +34,14 @@
)

import attrs
from sqlalchemy import select

from airflow.serialization.dag_dependency import DagDependency
from airflow.typing_compat import TypedDict
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from urllib.parse import SplitResult

from sqlalchemy.orm.session import Session

from airflow.triggers.base import BaseTrigger


Expand Down Expand Up @@ -502,21 +499,6 @@ def as_expression(self) -> dict[str, Any]:
return {"any": [o.as_expression() for o in self.objects]}


@provide_session
def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SESSION) -> list[BaseAsset]:
    """Expand asset alias to resolved assets.

    Accepts either an ``AssetAlias`` or a bare alias name, and returns the
    public form of every asset attached to it (empty list when the alias does
    not exist).
    """
    # Imported locally to avoid pulling the db model layer in at module import.
    from airflow.models.asset import AssetAliasModel

    if isinstance(alias, AssetAlias):
        alias_name = alias.name
    else:
        alias_name = alias

    stmt = select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1)
    alias_row = session.scalar(stmt)
    if not alias_row:
        return []
    return [asset.to_public() for asset in alias_row.assets]


class AssetAliasCondition(AssetAny):
"""
Use to expand AssetAlias as AssetAny of its resolved Assets.
Expand All @@ -527,11 +509,19 @@ class AssetAliasCondition(AssetAny):
def __init__(self, name: str, group: str) -> None:
    """Store the alias identity and eagerly resolve it to assets.

    NOTE(review): ``expand_alias_to_assets`` here is the session-providing
    variant, so construction triggers a database lookup — confirm this eager
    resolution is intended rather than the lazy cached-property variant.
    """
    self.name = name
    self.group = group
    self.objects = expand_alias_to_assets(name)

def __repr__(self) -> str:
    """Render the condition as ``AssetAliasCondition(<resolved objects>)``."""
    joined = ", ".join(str(obj) for obj in self.objects)
    return f"AssetAliasCondition({joined})"

@functools.cached_property
def objects(self) -> list[BaseAsset]:  # type: ignore[override]
    """Lazily resolve the alias into its public assets, caching the result.

    The db-layer imports live inside the property so merely importing the
    Task SDK definitions module never touches the database.
    """
    from airflow.models.asset import expand_alias_to_assets
    from airflow.utils.session import create_session

    with create_session() as session:
        resolved = expand_alias_to_assets(self.name, session)
        return [model.to_public() for model in resolved]

def as_expression(self) -> Any:
"""
Serialize the asset alias into its scheduling expression.
Expand Down
12 changes: 7 additions & 5 deletions task_sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetRef
from airflow.utils.session import create_session

if TYPE_CHECKING:
from collections.abc import Collection, Iterator, Mapping
Expand Down Expand Up @@ -55,15 +54,16 @@ def _iter_kwargs(
yield key, value

def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
    """Build the kwargs for the wrapped callable, resolving asset inlets.

    Collects the names of all ``AssetRef`` inlets (plus this asset's own
    definition name when the callable declares a ``self`` parameter), fetches
    the matching active assets from the database in a single session, and
    passes them through ``_iter_kwargs`` to produce the final kwargs mapping.
    """
    from airflow.models.asset import fetch_active_assets_by_name
    from airflow.utils.session import create_session

    # A set: a name can appear both as an inlet and as the definition itself.
    asset_names = {asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)}
    if "self" in inspect.signature(self.python_callable).parameters:
        asset_names.add(self._definition_name)

    if asset_names:
        with create_session() as session:
            active_assets = fetch_active_assets_by_name(asset_names, session)
    else:
        # Nothing to look up; avoid opening a session at all.
        active_assets = {}
    return dict(self._iter_kwargs(context, active_assets))
Expand Down Expand Up @@ -140,6 +140,8 @@ def __call__(self, f: Callable) -> AssetDefinition:
return AssetDefinition(
name=name,
uri=name if self.uri is None else str(self.uri),
group=self.group,
extra=self.extra,
function=f,
source=self,
)
25 changes: 25 additions & 0 deletions task_sdk/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line("norecursedirs", "tests/test_dags")


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
    """Reject any Task SDK test that carries the ``db_test`` marker."""
    if any(True for _ in item.iter_markers(name="db_test")):
        pytest.fail("Task SDK tests must not use database")


class LogCapture:
# Like structlog.typing.LogCapture, but that doesn't add log_level in to the event dict
entries: list[EventDict]
Expand Down Expand Up @@ -103,3 +109,22 @@ def captured_logs(request):
yield cap.entries
finally:
structlog.configure(processors=cur_processors)


@pytest.fixture(autouse=True, scope="session")
def _disable_ol_plugin():
    """Stop Airflow plugins (notably OpenLineage) loading for the whole session.

    The OpenLineage plugin imports setproctitle, which now makes C-level
    thread calls; on Python 3.12+ that triggers a warning whenever os.fork
    happens. Plugins get loaded when the priority_weight field is set, so we
    pre-empt that by marking the plugin list as already loaded.
    """
    import airflow.plugins_manager

    previous = airflow.plugins_manager.plugins
    # If anything already populated the list, disabling is no longer possible.
    assert previous is None, "Plugins already loaded, too late to stop them being loaded!"

    airflow.plugins_manager.plugins = []

    yield

    airflow.plugins_manager.plugins = None
Loading

0 comments on commit 9d349c5

Please sign in to comment.