From eff44cb12f3e7b59a8909df39b35870e87636dea Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 6 Dec 2024 11:05:55 +0800 Subject: [PATCH] Get rid of AssetAliasCondition (#44708) * Get rid of AssetAliasCondition Instead of having a separate class for condition evaluation, we can just use the main AssetAlias class directly. While it technically makes sense to subclass AssetAny, AssetAliasCondition does not really reuse much of its implementation, and we can just implement the missing methods ourselves instead. Whether the class actually is an AssetAny does not really make much of a difference. This actually allows us to simplify quite some code (including tests) a bit since we don't need to rewrap AssetAlias back and forth. * Fix serialization test * Does not need this call * Remove resolution-dependant timetable summary --- airflow/serialization/serialized_objects.py | 5 +- airflow/timetables/simple.py | 12 +- .../tests/openlineage/plugins/test_utils.py | 14 +- .../airflow/sdk/definitions/asset/__init__.py | 142 ++++++------------ task_sdk/tests/defintions/test_asset.py | 89 +++-------- tests/models/test_dag.py | 1 - tests/timetables/test_assets_timetable.py | 28 +--- 7 files changed, 87 insertions(+), 204 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 66b24b0ad40a9..754f0830f6206 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -59,7 +59,6 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, - AssetAliasCondition, AssetAll, AssetAny, AssetRef, @@ -1108,9 +1107,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: ) ) elif isinstance(obj, AssetAlias): - cond = AssetAliasCondition(name=obj.name, group=obj.group) - - deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target="")) + deps.extend(obj.iter_dag_dependencies(source=task.dag_id, target="")) return deps @staticmethod diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index 57eec884b558a..20e8085fe0d37 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -19,7 +19,6 @@ from collections.abc import Collection, Sequence from typing import TYPE_CHECKING, Any -from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.utils import timezone @@ -162,20 +161,11 @@ class AssetTriggeredTimetable(_TrivialTimetable): :meta private: """ - UNRESOLVED_ALIAS_SUMMARY = "Unresolved AssetAlias" - description: str = "Triggered by assets" def __init__(self, assets: BaseAsset) -> None: super().__init__() self.asset_condition = assets - if isinstance(self.asset_condition, AssetAlias): - self.asset_condition = AssetAliasCondition.from_asset_alias(self.asset_condition) - - if not next(self.asset_condition.iter_assets(), False): - self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY - else: - self._summary = "Asset" @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: @@ -185,7 +175,7 @@ def deserialize(cls, data: dict[str, Any]) -> Timetable: @property def summary(self) -> str: - return self._summary + return "Asset" def serialize(self) -> dict[str, Any]: from airflow.serialization.serialized_objects import encode_asset_condition diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index 3d41e87cf0152..046f836bb3608 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -337,7 +337,7 @@ def test_serialize_timetable(): Asset(name="2", uri="test://2", group="test-group"), AssetAlias(name="example-alias", group="test-group"), Asset(name="3", uri="test://3", group="test-group"), - AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")), + AssetAll(AssetAlias("another"), Asset("4")), ) dag = MagicMock() dag.timetable = AssetTriggeredTimetable(asset) @@ -354,7 +354,11 @@ def test_serialize_timetable(): "name": "2", "group": "test-group", }, - {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, + { + "__type": DagAttributeTypes.ASSET_ALIAS, + "name": "example-alias", + "group": "test-group", + }, { "__type": DagAttributeTypes.ASSET, "extra": {}, @@ -365,7 +369,11 @@ def test_serialize_timetable(): { "__type": DagAttributeTypes.ASSET_ALL, "objects": [ - {"__type": DagAttributeTypes.ASSET_ANY, "objects": []}, + { + "__type": DagAttributeTypes.ASSET_ALIAS, + "name": "another", + "group": "", + }, { "__type": DagAttributeTypes.ASSET, "extra": {}, diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index ee5ca25c39e04..787757637a647 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -17,21 +17,12 @@ from __future__ import annotations -import functools import logging import operator import os import urllib.parse import warnings -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - NamedTuple, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload import attrs @@ -51,7 +42,6 @@ "Model", "AssetRef", "AssetAlias", - "AssetAliasCondition", "AssetAll", "AssetAny", ] @@ -407,24 +397,61 @@ class AssetAlias(BaseAsset): name: str = attrs.field(validator=_validate_non_empty_identifier) group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier) + def _resolve_assets(self) -> list[Asset]: + from airflow.models.asset import expand_alias_to_assets + from airflow.utils.session import create_session + + with create_session() as session: + asset_models = expand_alias_to_assets(self.name, session) + return [m.to_public() for m in asset_models] + + def as_expression(self) -> Any: + """ + Serialize the asset alias into its scheduling expression. + + :meta private: + """ + return {"alias": {"name": self.name, "group": self.group}} + + def evaluate(self, statuses: dict[str, bool]) -> bool: + return any(x.evaluate(statuses=statuses) for x in self._resolve_assets()) + def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: yield self.name, self - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: + def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: """ - Iterate an asset alias as dag dependency. + Iterate an asset alias and its resolved assets as dag dependency. :meta private: """ - yield DagDependency( - source=source or "asset-alias", - target=target or "asset-alias", - dependency_type="asset-alias", - dependency_id=self.name, - ) + if not (resolved_assets := self._resolve_assets()): + yield DagDependency( + source=source or "asset-alias", + target=target or "asset-alias", + dependency_type="asset-alias", + dependency_id=self.name, + ) + return + for asset in resolved_assets: + asset_name = asset.name + # asset + yield DagDependency( + source=f"asset-alias:{self.name}" if source else "asset", + target="asset" if source else f"asset-alias:{self.name}", + dependency_type="asset", + dependency_id=asset_name, + ) + # asset alias + yield DagDependency( + source=source or f"asset:{asset_name}", + target=target or f"asset:{asset_name}", + dependency_type="asset-alias", + dependency_id=self.name, + ) class AssetAliasEvent(TypedDict): @@ -443,11 +470,7 @@ class _AssetBooleanCondition(BaseAsset): def __init__(self, *objects: BaseAsset) -> None: if not all(isinstance(o, BaseAsset) for o in objects): raise TypeError("expect asset expressions in condition") - - self.objects = [ - AssetAliasCondition.from_asset_alias(obj) if isinstance(obj, AssetAlias) else obj - for obj in objects - ] + self.objects = objects def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) @@ -499,77 +522,6 @@ def as_expression(self) -> dict[str, Any]: return {"any": [o.as_expression() for o in self.objects]} -class AssetAliasCondition(AssetAny): - """ - Use to expand AssetAlias as AssetAny of its resolved Assets. - - :meta private: - """ - - def __init__(self, name: str, group: str) -> None: - self.name = name - self.group = group - - def __repr__(self) -> str: - return f"AssetAliasCondition({', '.join(map(str, self.objects))})" - - @functools.cached_property - def objects(self) -> list[BaseAsset]: # type: ignore[override] - from airflow.models.asset import expand_alias_to_assets - from airflow.utils.session import create_session - - with create_session() as session: - asset_models = expand_alias_to_assets(self.name, session) - return [m.to_public() for m in asset_models] - - def as_expression(self) -> Any: - """ - Serialize the asset alias into its scheduling expression. - - :meta private: - """ - return {"alias": {"name": self.name, "group": self.group}} - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - yield self.name, AssetAlias(self.name) - - def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: - """ - Iterate an asset alias and its resolved assets as dag dependency. - - :meta private: - """ - if self.objects: - for obj in self.objects: - asset = cast(Asset, obj) - asset_name = asset.name - # asset - yield DagDependency( - source=f"asset-alias:{self.name}" if source else "asset", - target="asset" if source else f"asset-alias:{self.name}", - dependency_type="asset", - dependency_id=asset_name, - ) - # asset alias - yield DagDependency( - source=source or f"asset:{asset_name}", - target=target or f"asset:{asset_name}", - dependency_type="asset-alias", - dependency_id=self.name, - ) - else: - yield DagDependency( - source=source or "asset-alias", - target=target or "asset-alias", - dependency_type="asset-alias", - dependency_id=self.name, - ) - - @staticmethod - def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition: - return AssetAliasCondition(name=asset_alias.name, group=asset_alias.group) - - class AssetAll(_AssetBooleanCondition): """Use to combine assets schedule references in an "or" relationship.""" diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py index 6f6d40fdbe9f9..55439fb1bcd84 100644 --- a/task_sdk/tests/defintions/test_asset.py +++ b/task_sdk/tests/defintions/test_asset.py @@ -27,7 +27,6 @@ from airflow.sdk.definitions.asset import ( Asset, AssetAlias, - AssetAliasCondition, AssetAll, AssetAny, BaseAsset, @@ -487,78 +486,38 @@ def test_normalize_uri_valid_uri(mock_get_normalized_scheme): assert asset.normalized_uri == "valid_aip60_uri" -class FakeSession: - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - pass - - -FAKE_SESSION = FakeSession() - - -class TestAssetAliasCondition: +class TestAssetAlias: @pytest.fixture - def asset_model(self): + def asset(self): """Example asset links to asset alias resolved_asset_alias_2.""" - from airflow.models.asset import AssetModel - - return AssetModel( - id=1, - uri="test://asset1/", - name="test_name", - group="asset", - ) + return Asset(uri="test://asset1/", name="test_name", group="asset") @pytest.fixture def asset_alias_1(self): """Example asset alias links to no assets.""" - from airflow.models.asset import AssetAliasModel - - return AssetAliasModel(name="test_name", group="test") + asset_alias_1 = AssetAlias(name="test_name", group="test") + with mock.patch.object(asset_alias_1, "_resolve_assets", return_value=[]): + yield asset_alias_1 @pytest.fixture - def resolved_asset_alias_2(self, asset_model): - """Example asset alias links to asset asset_alias_1.""" - from airflow.models.asset import AssetAliasModel - - asset_alias_2 = AssetAliasModel(name="test_name_2") - asset_alias_2.assets.append(asset_model) - return asset_alias_2 - - def test_as_expression(self, asset_alias_1, resolved_asset_alias_2): - for asset_alias in (asset_alias_1, resolved_asset_alias_2): - cond = AssetAliasCondition.from_asset_alias(asset_alias) - assert cond.as_expression() == {"alias": {"name": asset_alias.name, "group": asset_alias.group}} - - @mock.patch("airflow.models.asset.expand_alias_to_assets") - @mock.patch("airflow.utils.session.create_session", return_value=FAKE_SESSION) - def test_evalute_empty( - self, mock_create_session, mock_expand_alias_to_assets, asset_alias_1, asset_model - ): - mock_expand_alias_to_assets.return_value = [] - - cond = AssetAliasCondition.from_asset_alias(asset_alias_1) - assert cond.evaluate({asset_model.uri: True}) is False - - assert mock_expand_alias_to_assets.mock_calls == [mock.call(asset_alias_1.name, FAKE_SESSION)] - assert mock_create_session.mock_calls == [mock.call()] - - @mock.patch("airflow.models.asset.expand_alias_to_assets") - @mock.patch("airflow.utils.session.create_session", return_value=FAKE_SESSION) - def test_evalute_resolved( - self, mock_create_session, mock_expand_alias_to_assets, resolved_asset_alias_2, asset_model - ): - mock_expand_alias_to_assets.return_value = [asset_model] - - cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2) - assert cond.evaluate({asset_model.uri: True}) is True - - assert mock_expand_alias_to_assets.mock_calls == [ - mock.call(resolved_asset_alias_2.name, FAKE_SESSION), - ] - assert mock_create_session.mock_calls == [mock.call()] + def resolved_asset_alias_2(self, asset): + """Example asset alias links to asset.""" + asset_alias_2 = AssetAlias(name="test_name_2") + with mock.patch.object(asset_alias_2, "_resolve_assets", return_value=[asset]): + yield asset_alias_2 + + @pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1", "resolved_asset_alias_2"]) + def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name): + alias = request.getfixturevalue(alias_fixture_name) + assert alias.as_expression() == {"alias": {"name": alias.name, "group": alias.group}} + + def test_evalute_empty(self, asset_alias_1, asset): + assert asset_alias_1.evaluate({asset.uri: True}) is False + assert asset_alias_1._resolve_assets.mock_calls == [mock.call()] + + def test_evalute_resolved(self, resolved_asset_alias_2, asset): + assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True + assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call()] class TestAssetSubclasses: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 384d76c7548b4..104e3c904940b 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -2250,7 +2250,6 @@ def test_dags_needing_dagruns_asset_aliases(self, dag_maker, session): # add queue records so we'll need a run dag_model = dag_maker.dag_model - asset_model: AssetModel = dag_model.schedule_assets[0] session.add(AssetDagRunQueue(asset_id=asset_model.id, target_dag_id=dag_model.dag_id)) session.flush() query, _ = DagModel.dags_needing_dagruns(session) diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py index d456bf058bc3f..1dc8e6428d927 100644 --- a/tests/timetables/test_assets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -19,24 +19,21 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING, Any +from typing import Any import pytest from pendulum import DateTime from sqlalchemy.sql import select -from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, AssetEvent, AssetModel +from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel from airflow.models.serialized_dag import SerializedDAG, SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny +from airflow.sdk.definitions.asset import Asset, AssetAll, AssetAny from airflow.timetables.assets import AssetOrTimeSchedule from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import AssetTriggeredTimetable from airflow.utils.types import DagRunType -if TYPE_CHECKING: - from sqlalchemy import Session - class MockTimetable(Timetable): """ @@ -274,25 +271,6 @@ def test_run_ordering_inheritance(asset_timetable: AssetOrTimeSchedule) -> None: assert asset_timetable.run_ordering == parent_run_ordering, "run_ordering does not match the parent class" -@pytest.mark.db_test -def test_summary(session: Session) -> None: - asset_model = AssetModel(uri="test_asset") - asset_alias_model = AssetAliasModel(name="test_asset_alias") - session.add_all([asset_model, asset_alias_model]) - session.commit() - - asset_alias = AssetAlias("test_asset_alias") - table = AssetTriggeredTimetable(asset_alias) - assert table.summary == "Unresolved AssetAlias" - - asset_alias_model.assets.append(asset_model) - session.add(asset_alias_model) - session.commit() - - table = AssetTriggeredTimetable(asset_alias) - assert table.summary == "Asset" - - @pytest.mark.db_test class TestAssetConditionWithTimetable: @pytest.fixture(autouse=True)