diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 66b24b0ad40a9..754f0830f6206 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -59,7 +59,6 @@
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
-    AssetAliasCondition,
     AssetAll,
     AssetAny,
     AssetRef,
@@ -1108,9 +1107,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]:
                     )
                 )
             elif isinstance(obj, AssetAlias):
-                cond = AssetAliasCondition(name=obj.name, group=obj.group)
-
-                deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
+                deps.extend(obj.iter_dag_dependencies(source=task.dag_id, target=""))
         return deps
 
     @staticmethod
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index 57eec884b558a..20e8085fe0d37 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -19,7 +19,6 @@
 from collections.abc import Collection, Sequence
 from typing import TYPE_CHECKING, Any
 
-from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition
 from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
 from airflow.utils import timezone
 
@@ -162,20 +161,11 @@ class AssetTriggeredTimetable(_TrivialTimetable):
     :meta private:
     """
 
-    UNRESOLVED_ALIAS_SUMMARY = "Unresolved AssetAlias"
-
     description: str = "Triggered by assets"
 
     def __init__(self, assets: BaseAsset) -> None:
         super().__init__()
         self.asset_condition = assets
-        if isinstance(self.asset_condition, AssetAlias):
-            self.asset_condition = AssetAliasCondition.from_asset_alias(self.asset_condition)
-
-        if not next(self.asset_condition.iter_assets(), False):
-            self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
-        else:
-            self._summary = "Asset"
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> Timetable:
@@ -185,7 +175,7 @@ def deserialize(cls, data: dict[str, Any]) -> Timetable:
 
     @property
     def summary(self) -> str:
-        return self._summary
+        return "Asset"
 
     def serialize(self) -> dict[str, Any]:
         from airflow.serialization.serialized_objects import encode_asset_condition
diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py
index 3d41e87cf0152..046f836bb3608 100644
--- a/providers/tests/openlineage/plugins/test_utils.py
+++ b/providers/tests/openlineage/plugins/test_utils.py
@@ -337,7 +337,7 @@ def test_serialize_timetable():
         Asset(name="2", uri="test://2", group="test-group"),
         AssetAlias(name="example-alias", group="test-group"),
         Asset(name="3", uri="test://3", group="test-group"),
-        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+        AssetAll(AssetAlias("another"), Asset("4")),
     )
     dag = MagicMock()
     dag.timetable = AssetTriggeredTimetable(asset)
@@ -354,7 +354,11 @@
                 "name": "2",
                 "group": "test-group",
             },
-            {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+            {
+                "__type": DagAttributeTypes.ASSET_ALIAS,
+                "name": "example-alias",
+                "group": "test-group",
+            },
             {
                 "__type": DagAttributeTypes.ASSET,
                 "extra": {},
@@ -365,7 +369,11 @@
             {
                 "__type": DagAttributeTypes.ASSET_ALL,
                 "objects": [
-                    {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                    {
+                        "__type": DagAttributeTypes.ASSET_ALIAS,
+                        "name": "another",
+                        "group": "",
+                    },
                     {
                         "__type": DagAttributeTypes.ASSET,
                         "extra": {},
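Note (illustrative, not part of the patch): with the serialization change above, an AssetAlias that appears in a schedule is now encoded as its own alias node instead of being expanded into an (often empty) AssetAny of resolved assets. A minimal sketch of the new behaviour, assuming a development environment where the Airflow imports below resolve:

    from airflow.sdk.definitions.asset import AssetAlias
    from airflow.serialization.serialized_objects import encode_asset_condition

    alias = AssetAlias(name="example-alias", group="test-group")
    # Per the updated expectations in test_serialize_timetable, the alias now encodes
    # as an ASSET_ALIAS node carrying its name and group, roughly:
    # {"__type": "asset_alias", "name": "example-alias", "group": "test-group"}
    print(encode_asset_condition(alias))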
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index ee5ca25c39e04..787757637a647 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -17,21 +17,12 @@
 
 from __future__ import annotations
 
-import functools
 import logging
 import operator
 import os
 import urllib.parse
 import warnings
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    ClassVar,
-    NamedTuple,
-    cast,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload
 
 import attrs
 
@@ -51,7 +42,6 @@
     "Model",
     "AssetRef",
     "AssetAlias",
-    "AssetAliasCondition",
    "AssetAll",
     "AssetAny",
 ]
@@ -407,24 +397,61 @@ class AssetAlias(BaseAsset):
     name: str = attrs.field(validator=_validate_non_empty_identifier)
     group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier)
 
+    def _resolve_assets(self) -> list[Asset]:
+        from airflow.models.asset import expand_alias_to_assets
+        from airflow.utils.session import create_session
+
+        with create_session() as session:
+            asset_models = expand_alias_to_assets(self.name, session)
+        return [m.to_public() for m in asset_models]
+
+    def as_expression(self) -> Any:
+        """
+        Serialize the asset alias into its scheduling expression.
+
+        :meta private:
+        """
+        return {"alias": {"name": self.name, "group": self.group}}
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return any(x.evaluate(statuses=statuses) for x in self._resolve_assets())
+
     def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
         return iter(())
 
     def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
         yield self.name, self
 
-    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
+    def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
         """
-        Iterate an asset alias as dag dependency.
+        Iterate an asset alias and its resolved assets as dag dependency.
 
         :meta private:
         """
-        yield DagDependency(
-            source=source or "asset-alias",
-            target=target or "asset-alias",
-            dependency_type="asset-alias",
-            dependency_id=self.name,
-        )
+        if not (resolved_assets := self._resolve_assets()):
+            yield DagDependency(
+                source=source or "asset-alias",
+                target=target or "asset-alias",
+                dependency_type="asset-alias",
+                dependency_id=self.name,
+            )
+            return
+        for asset in resolved_assets:
+            asset_name = asset.name
+            # asset
+            yield DagDependency(
+                source=f"asset-alias:{self.name}" if source else "asset",
+                target="asset" if source else f"asset-alias:{self.name}",
+                dependency_type="asset",
+                dependency_id=asset_name,
+            )
+            # asset alias
+            yield DagDependency(
+                source=source or f"asset:{asset_name}",
+                target=target or f"asset:{asset_name}",
+                dependency_type="asset-alias",
+                dependency_id=self.name,
+            )
 
 
 class AssetAliasEvent(TypedDict):
@@ -443,11 +470,7 @@ class _AssetBooleanCondition(BaseAsset):
     def __init__(self, *objects: BaseAsset) -> None:
         if not all(isinstance(o, BaseAsset) for o in objects):
             raise TypeError("expect asset expressions in condition")
-
-        self.objects = [
-            AssetAliasCondition.from_asset_alias(obj) if isinstance(obj, AssetAlias) else obj
-            for obj in objects
-        ]
+        self.objects = objects
 
     def evaluate(self, statuses: dict[str, bool]) -> bool:
         return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
@@ -499,77 +522,6 @@ def as_expression(self) -> dict[str, Any]:
         return {"any": [o.as_expression() for o in self.objects]}
 
 
-class AssetAliasCondition(AssetAny):
-    """
-    Use to expand AssetAlias as AssetAny of its resolved Assets.
-
-    :meta private:
-    """
-
-    def __init__(self, name: str, group: str) -> None:
-        self.name = name
-        self.group = group
-
-    def __repr__(self) -> str:
-        return f"AssetAliasCondition({', '.join(map(str, self.objects))})"
-
-    @functools.cached_property
-    def objects(self) -> list[BaseAsset]:  # type: ignore[override]
-        from airflow.models.asset import expand_alias_to_assets
-        from airflow.utils.session import create_session
-
-        with create_session() as session:
-            asset_models = expand_alias_to_assets(self.name, session)
-        return [m.to_public() for m in asset_models]
-
-    def as_expression(self) -> Any:
-        """
-        Serialize the asset alias into its scheduling expression.
-
-        :meta private:
-        """
-        return {"alias": {"name": self.name, "group": self.group}}
-
-    def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
-        yield self.name, AssetAlias(self.name)
-
-    def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
-        """
-        Iterate an asset alias and its resolved assets as dag dependency.
-
-        :meta private:
-        """
-        if self.objects:
-            for obj in self.objects:
-                asset = cast(Asset, obj)
-                asset_name = asset.name
-                # asset
-                yield DagDependency(
-                    source=f"asset-alias:{self.name}" if source else "asset",
-                    target="asset" if source else f"asset-alias:{self.name}",
-                    dependency_type="asset",
-                    dependency_id=asset_name,
-                )
-                # asset alias
-                yield DagDependency(
-                    source=source or f"asset:{asset_name}",
-                    target=target or f"asset:{asset_name}",
-                    dependency_type="asset-alias",
-                    dependency_id=self.name,
-                )
-        else:
-            yield DagDependency(
-                source=source or "asset-alias",
-                target=target or "asset-alias",
-                dependency_type="asset-alias",
-                dependency_id=self.name,
-            )
-
-    @staticmethod
-    def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition:
-        return AssetAliasCondition(name=asset_alias.name, group=asset_alias.group)
-
-
 class AssetAll(_AssetBooleanCondition):
     """Use to combine assets schedule references in an "and" relationship."""
 
diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py
index 6f6d40fdbe9f9..55439fb1bcd84 100644
--- a/task_sdk/tests/defintions/test_asset.py
+++ b/task_sdk/tests/defintions/test_asset.py
@@ -27,7 +27,6 @@
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
-    AssetAliasCondition,
     AssetAll,
     AssetAny,
     BaseAsset,
@@ -487,78 +486,38 @@ def test_normalize_uri_valid_uri(mock_get_normalized_scheme):
     assert asset.normalized_uri == "valid_aip60_uri"
 
 
-class FakeSession:
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *args, **kwargs):
-        pass
-
-
-FAKE_SESSION = FakeSession()
-
-
-class TestAssetAliasCondition:
+class TestAssetAlias:
     @pytest.fixture
-    def asset_model(self):
+    def asset(self):
         """Example asset links to asset alias resolved_asset_alias_2."""
-        from airflow.models.asset import AssetModel
-
-        return AssetModel(
-            id=1,
-            uri="test://asset1/",
-            name="test_name",
-            group="asset",
-        )
+        return Asset(uri="test://asset1/", name="test_name", group="asset")
 
     @pytest.fixture
     def asset_alias_1(self):
         """Example asset alias links to no assets."""
-        from airflow.models.asset import AssetAliasModel
-
-        return AssetAliasModel(name="test_name", group="test")
+        asset_alias_1 = AssetAlias(name="test_name", group="test")
+        with mock.patch.object(asset_alias_1, "_resolve_assets", return_value=[]):
+            yield asset_alias_1
 
     @pytest.fixture
-    def resolved_asset_alias_2(self, asset_model):
-        """Example asset alias links to asset asset_alias_1."""
-        from airflow.models.asset import AssetAliasModel
-
-        asset_alias_2 = AssetAliasModel(name="test_name_2")
-        asset_alias_2.assets.append(asset_model)
-        return asset_alias_2
-
-    def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
-        for asset_alias in (asset_alias_1, resolved_asset_alias_2):
-            cond = AssetAliasCondition.from_asset_alias(asset_alias)
-            assert cond.as_expression() == {"alias": {"name": asset_alias.name, "group": asset_alias.group}}
-
-    @mock.patch("airflow.models.asset.expand_alias_to_assets")
-    @mock.patch("airflow.utils.session.create_session", return_value=FAKE_SESSION)
-    def test_evalute_empty(
-        self, mock_create_session, mock_expand_alias_to_assets, asset_alias_1, asset_model
-    ):
-        mock_expand_alias_to_assets.return_value = []
-
-        cond = AssetAliasCondition.from_asset_alias(asset_alias_1)
-        assert cond.evaluate({asset_model.uri: True}) is False
-
-        assert mock_expand_alias_to_assets.mock_calls == [mock.call(asset_alias_1.name, FAKE_SESSION)]
-        assert mock_create_session.mock_calls == [mock.call()]
-
-    @mock.patch("airflow.models.asset.expand_alias_to_assets")
-    @mock.patch("airflow.utils.session.create_session", return_value=FAKE_SESSION)
-    def test_evalute_resolved(
-        self, mock_create_session, mock_expand_alias_to_assets, resolved_asset_alias_2, asset_model
-    ):
-        mock_expand_alias_to_assets.return_value = [asset_model]
-
-        cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2)
-        assert cond.evaluate({asset_model.uri: True}) is True
-
-        assert mock_expand_alias_to_assets.mock_calls == [
-            mock.call(resolved_asset_alias_2.name, FAKE_SESSION),
-        ]
-        assert mock_create_session.mock_calls == [mock.call()]
+    def resolved_asset_alias_2(self, asset):
+        """Example asset alias links to asset."""
+        asset_alias_2 = AssetAlias(name="test_name_2")
+        with mock.patch.object(asset_alias_2, "_resolve_assets", return_value=[asset]):
+            yield asset_alias_2
+
+    @pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1", "resolved_asset_alias_2"])
+    def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name):
+        alias = request.getfixturevalue(alias_fixture_name)
+        assert alias.as_expression() == {"alias": {"name": alias.name, "group": alias.group}}
+
+    def test_evalute_empty(self, asset_alias_1, asset):
+        assert asset_alias_1.evaluate({asset.uri: True}) is False
+        assert asset_alias_1._resolve_assets.mock_calls == [mock.call()]
+
+    def test_evalute_resolved(self, resolved_asset_alias_2, asset):
+        assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True
+        assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call()]
 
 
 class TestAssetSubclasses:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 384d76c7548b4..104e3c904940b 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2250,7 +2250,6 @@ def test_dags_needing_dagruns_asset_aliases(self, dag_maker, session):
 
         # add queue records so we'll need a run
         dag_model = dag_maker.dag_model
-        asset_model: AssetModel = dag_model.schedule_assets[0]
         session.add(AssetDagRunQueue(asset_id=asset_model.id, target_dag_id=dag_model.dag_id))
         session.flush()
         query, _ = DagModel.dags_needing_dagruns(session)
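Note (illustrative, not part of the patch): because AssetTriggeredTimetable.summary now always returns "Asset", the test_summary test removed in the next file no longer has a second code path to cover; the summary no longer depends on whether an alias has resolved to any assets. A minimal sketch, assuming a working Airflow development environment:

    from airflow.sdk.definitions.asset import AssetAlias
    from airflow.timetables.simple import AssetTriggeredTimetable

    # No database access is needed any more: __init__ simply stores the condition,
    # and the summary is constant regardless of alias resolution.
    timetable = AssetTriggeredTimetable(AssetAlias("example-alias"))
    assert timetable.summary == "Asset"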
diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py
index d456bf058bc3f..1dc8e6428d927 100644
--- a/tests/timetables/test_assets_timetable.py
+++ b/tests/timetables/test_assets_timetable.py
@@ -19,24 +19,21 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 import pytest
 from pendulum import DateTime
 from sqlalchemy.sql import select
 
-from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel
+from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
 from airflow.models.serialized_dag import SerializedDAG, SerializedDagModel
 from airflow.operators.empty import EmptyOperator
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
+from airflow.sdk.definitions.asset import Asset, AssetAll, AssetAny
 from airflow.timetables.assets import AssetOrTimeSchedule
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
 from airflow.timetables.simple import AssetTriggeredTimetable
 from airflow.utils.types import DagRunType
 
-if TYPE_CHECKING:
-    from sqlalchemy import Session
-
 
 class MockTimetable(Timetable):
     """
@@ -274,25 +271,6 @@ def test_run_ordering_inheritance(asset_timetable: AssetOrTimeSchedule) -> None:
     assert asset_timetable.run_ordering == parent_run_ordering, "run_ordering does not match the parent class"
 
 
-@pytest.mark.db_test
-def test_summary(session: Session) -> None:
-    asset_model = AssetModel(uri="test_asset")
-    asset_alias_model = AssetAliasModel(name="test_asset_alias")
-    session.add_all([asset_model, asset_alias_model])
-    session.commit()
-
-    asset_alias = AssetAlias("test_asset_alias")
-    table = AssetTriggeredTimetable(asset_alias)
-    assert table.summary == "Unresolved AssetAlias"
-
-    asset_alias_model.assets.append(asset_model)
-    session.add(asset_alias_model)
-    session.commit()
-
-    table = AssetTriggeredTimetable(asset_alias)
-    assert table.summary == "Asset"
-
-
 @pytest.mark.db_test
 class TestAssetConditionWithTimetable:
     @pytest.fixture(autouse=True)
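Note (illustrative, not part of the patch): AssetAlias.iter_dag_dependencies now resolves the alias itself, so the behaviour previously provided by AssetAliasCondition can be exercised directly on the alias. A minimal sketch mirroring the new unit tests above; the alias and asset names are made up, and _resolve_assets is patched the same way the tests do to avoid touching the database:

    from unittest import mock

    from airflow.sdk.definitions.asset import Asset, AssetAlias

    alias = AssetAlias(name="example-alias")
    asset = Asset(name="a1", uri="test://a1/")

    # Unresolved alias: a single "asset-alias" dependency keyed by the alias name.
    with mock.patch.object(alias, "_resolve_assets", return_value=[]):
        print(list(alias.iter_dag_dependencies(source="my_dag")))

    # Resolved alias: one "asset" edge plus one "asset-alias" edge per resolved asset.
    with mock.patch.object(alias, "_resolve_assets", return_value=[asset]):
        print(list(alias.iter_dag_dependencies(source="my_dag")))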