Skip to content

Commit

Permalink
Get rid of AssetAliasCondition (apache#44708)
Browse files Browse the repository at this point in the history
* Get rid of AssetAliasCondition

Instead of having a separate class for condition evaluation, we can just
use the main AssetAlias class directly. While it technically makes sense
to subclass AssetAny, AssetAliasCondition does not really reuse much of
its implementation, and we can just implement the missing methods
ourselves instead. Whether the class actually is an AssetAny does not
really make much of a difference.

This actually allows us to simplify quite some code (including tests) a
bit since we don't need to rewrap AssetAlias back and forth.

* Fix serialization test

* Does not need this call

* Remove resolution-dependant timetable summary
  • Loading branch information
uranusjr authored and Ohashiro committed Dec 6, 2024
1 parent 9d349c5 commit eff44cb
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 204 deletions.
5 changes: 1 addition & 4 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasCondition,
AssetAll,
AssetAny,
AssetRef,
Expand Down Expand Up @@ -1108,9 +1107,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]:
)
)
elif isinstance(obj, AssetAlias):
cond = AssetAliasCondition(name=obj.name, group=obj.group)

deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
deps.extend(obj.iter_dag_dependencies(source=task.dag_id, target=""))
return deps

@staticmethod
Expand Down
12 changes: 1 addition & 11 deletions airflow/timetables/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from collections.abc import Collection, Sequence
from typing import TYPE_CHECKING, Any

from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
from airflow.utils import timezone

Expand Down Expand Up @@ -162,20 +161,11 @@ class AssetTriggeredTimetable(_TrivialTimetable):
:meta private:
"""

UNRESOLVED_ALIAS_SUMMARY = "Unresolved AssetAlias"

description: str = "Triggered by assets"

def __init__(self, assets: BaseAsset) -> None:
super().__init__()
self.asset_condition = assets
if isinstance(self.asset_condition, AssetAlias):
self.asset_condition = AssetAliasCondition.from_asset_alias(self.asset_condition)

if not next(self.asset_condition.iter_assets(), False):
self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
else:
self._summary = "Asset"

@classmethod
def deserialize(cls, data: dict[str, Any]) -> Timetable:
Expand All @@ -185,7 +175,7 @@ def deserialize(cls, data: dict[str, Any]) -> Timetable:

@property
def summary(self) -> str:
return self._summary
return "Asset"

def serialize(self) -> dict[str, Any]:
from airflow.serialization.serialized_objects import encode_asset_condition
Expand Down
14 changes: 11 additions & 3 deletions providers/tests/openlineage/plugins/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_serialize_timetable():
Asset(name="2", uri="test://2", group="test-group"),
AssetAlias(name="example-alias", group="test-group"),
Asset(name="3", uri="test://3", group="test-group"),
AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
AssetAll(AssetAlias("another"), Asset("4")),
)
dag = MagicMock()
dag.timetable = AssetTriggeredTimetable(asset)
Expand All @@ -354,7 +354,11 @@ def test_serialize_timetable():
"name": "2",
"group": "test-group",
},
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
{
"__type": DagAttributeTypes.ASSET_ALIAS,
"name": "example-alias",
"group": "test-group",
},
{
"__type": DagAttributeTypes.ASSET,
"extra": {},
Expand All @@ -365,7 +369,11 @@ def test_serialize_timetable():
{
"__type": DagAttributeTypes.ASSET_ALL,
"objects": [
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
{
"__type": DagAttributeTypes.ASSET_ALIAS,
"name": "another",
"group": "",
},
{
"__type": DagAttributeTypes.ASSET,
"extra": {},
Expand Down
142 changes: 47 additions & 95 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,12 @@

from __future__ import annotations

import functools
import logging
import operator
import os
import urllib.parse
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
NamedTuple,
cast,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload

import attrs

Expand All @@ -51,7 +42,6 @@
"Model",
"AssetRef",
"AssetAlias",
"AssetAliasCondition",
"AssetAll",
"AssetAny",
]
Expand Down Expand Up @@ -407,24 +397,61 @@ class AssetAlias(BaseAsset):
name: str = attrs.field(validator=_validate_non_empty_identifier)
group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier)

def _resolve_assets(self) -> list[Asset]:
from airflow.models.asset import expand_alias_to_assets
from airflow.utils.session import create_session

with create_session() as session:
asset_models = expand_alias_to_assets(self.name, session)
return [m.to_public() for m in asset_models]

def as_expression(self) -> Any:
"""
Serialize the asset alias into its scheduling expression.
:meta private:
"""
return {"alias": {"name": self.name, "group": self.group}}

def evaluate(self, statuses: dict[str, bool]) -> bool:
return any(x.evaluate(statuses=statuses) for x in self._resolve_assets())

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())

def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
yield self.name, self

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
"""
Iterate an asset alias as dag dependency.
Iterate an asset alias and its resolved assets as dag dependency.
:meta private:
"""
yield DagDependency(
source=source or "asset-alias",
target=target or "asset-alias",
dependency_type="asset-alias",
dependency_id=self.name,
)
if not (resolved_assets := self._resolve_assets()):
yield DagDependency(
source=source or "asset-alias",
target=target or "asset-alias",
dependency_type="asset-alias",
dependency_id=self.name,
)
return
for asset in resolved_assets:
asset_name = asset.name
# asset
yield DagDependency(
source=f"asset-alias:{self.name}" if source else "asset",
target="asset" if source else f"asset-alias:{self.name}",
dependency_type="asset",
dependency_id=asset_name,
)
# asset alias
yield DagDependency(
source=source or f"asset:{asset_name}",
target=target or f"asset:{asset_name}",
dependency_type="asset-alias",
dependency_id=self.name,
)


class AssetAliasEvent(TypedDict):
Expand All @@ -443,11 +470,7 @@ class _AssetBooleanCondition(BaseAsset):
def __init__(self, *objects: BaseAsset) -> None:
if not all(isinstance(o, BaseAsset) for o in objects):
raise TypeError("expect asset expressions in condition")

self.objects = [
AssetAliasCondition.from_asset_alias(obj) if isinstance(obj, AssetAlias) else obj
for obj in objects
]
self.objects = objects

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
Expand Down Expand Up @@ -499,77 +522,6 @@ def as_expression(self) -> dict[str, Any]:
return {"any": [o.as_expression() for o in self.objects]}


class AssetAliasCondition(AssetAny):
"""
Use to expand AssetAlias as AssetAny of its resolved Assets.
:meta private:
"""

def __init__(self, name: str, group: str) -> None:
self.name = name
self.group = group

def __repr__(self) -> str:
return f"AssetAliasCondition({', '.join(map(str, self.objects))})"

@functools.cached_property
def objects(self) -> list[BaseAsset]: # type: ignore[override]
from airflow.models.asset import expand_alias_to_assets
from airflow.utils.session import create_session

with create_session() as session:
asset_models = expand_alias_to_assets(self.name, session)
return [m.to_public() for m in asset_models]

def as_expression(self) -> Any:
"""
Serialize the asset alias into its scheduling expression.
:meta private:
"""
return {"alias": {"name": self.name, "group": self.group}}

def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
yield self.name, AssetAlias(self.name)

def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
"""
Iterate an asset alias and its resolved assets as dag dependency.
:meta private:
"""
if self.objects:
for obj in self.objects:
asset = cast(Asset, obj)
asset_name = asset.name
# asset
yield DagDependency(
source=f"asset-alias:{self.name}" if source else "asset",
target="asset" if source else f"asset-alias:{self.name}",
dependency_type="asset",
dependency_id=asset_name,
)
# asset alias
yield DagDependency(
source=source or f"asset:{asset_name}",
target=target or f"asset:{asset_name}",
dependency_type="asset-alias",
dependency_id=self.name,
)
else:
yield DagDependency(
source=source or "asset-alias",
target=target or "asset-alias",
dependency_type="asset-alias",
dependency_id=self.name,
)

@staticmethod
def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition:
return AssetAliasCondition(name=asset_alias.name, group=asset_alias.group)


class AssetAll(_AssetBooleanCondition):
"""Use to combine assets schedule references in an "or" relationship."""

Expand Down
89 changes: 24 additions & 65 deletions task_sdk/tests/defintions/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasCondition,
AssetAll,
AssetAny,
BaseAsset,
Expand Down Expand Up @@ -487,78 +486,38 @@ def test_normalize_uri_valid_uri(mock_get_normalized_scheme):
assert asset.normalized_uri == "valid_aip60_uri"


class FakeSession:
def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
pass


FAKE_SESSION = FakeSession()


class TestAssetAliasCondition:
class TestAssetAlias:
@pytest.fixture
def asset_model(self):
def asset(self):
"""Example asset links to asset alias resolved_asset_alias_2."""
from airflow.models.asset import AssetModel

return AssetModel(
id=1,
uri="test://asset1/",
name="test_name",
group="asset",
)
return Asset(uri="test://asset1/", name="test_name", group="asset")

@pytest.fixture
def asset_alias_1(self):
"""Example asset alias links to no assets."""
from airflow.models.asset import AssetAliasModel

return AssetAliasModel(name="test_name", group="test")
asset_alias_1 = AssetAlias(name="test_name", group="test")
with mock.patch.object(asset_alias_1, "_resolve_assets", return_value=[]):
yield asset_alias_1

@pytest.fixture
def resolved_asset_alias_2(self, asset_model):
"""Example asset alias links to asset asset_alias_1."""
from airflow.models.asset import AssetAliasModel

asset_alias_2 = AssetAliasModel(name="test_name_2")
asset_alias_2.assets.append(asset_model)
return asset_alias_2

def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
for asset_alias in (asset_alias_1, resolved_asset_alias_2):
cond = AssetAliasCondition.from_asset_alias(asset_alias)
assert cond.as_expression() == {"alias": {"name": asset_alias.name, "group": asset_alias.group}}

@mock.patch("airflow.models.asset.expand_alias_to_assets")
@mock.patch("airflow.utils.session.create_session", return_value=FAKE_SESSION)
def test_evalute_empty(
self, mock_create_session, mock_expand_alias_to_assets, asset_alias_1, asset_model
):
mock_expand_alias_to_assets.return_value = []

cond = AssetAliasCondition.from_asset_alias(asset_alias_1)
assert cond.evaluate({asset_model.uri: True}) is False

assert mock_expand_alias_to_assets.mock_calls == [mock.call(asset_alias_1.name, FAKE_SESSION)]
assert mock_create_session.mock_calls == [mock.call()]

@mock.patch("airflow.models.asset.expand_alias_to_assets")
@mock.patch("airflow.utils.session.create_session", return_value=FAKE_SESSION)
def test_evalute_resolved(
self, mock_create_session, mock_expand_alias_to_assets, resolved_asset_alias_2, asset_model
):
mock_expand_alias_to_assets.return_value = [asset_model]

cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2)
assert cond.evaluate({asset_model.uri: True}) is True

assert mock_expand_alias_to_assets.mock_calls == [
mock.call(resolved_asset_alias_2.name, FAKE_SESSION),
]
assert mock_create_session.mock_calls == [mock.call()]
def resolved_asset_alias_2(self, asset):
"""Example asset alias links to asset."""
asset_alias_2 = AssetAlias(name="test_name_2")
with mock.patch.object(asset_alias_2, "_resolve_assets", return_value=[asset]):
yield asset_alias_2

@pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1", "resolved_asset_alias_2"])
def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name):
alias = request.getfixturevalue(alias_fixture_name)
assert alias.as_expression() == {"alias": {"name": alias.name, "group": alias.group}}

def test_evalute_empty(self, asset_alias_1, asset):
assert asset_alias_1.evaluate({asset.uri: True}) is False
assert asset_alias_1._resolve_assets.mock_calls == [mock.call()]

def test_evalute_resolved(self, resolved_asset_alias_2, asset):
assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True
assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call()]


class TestAssetSubclasses:
Expand Down
Loading

0 comments on commit eff44cb

Please sign in to comment.