diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 705cc797ed11b..87227f9d8c41a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2748,9 +2748,9 @@ def _register_asset_changes_int( ) elif isinstance(obj, AssetAlias): for asset_alias_event in events[obj].asset_alias_events: - asset_alias_name = asset_alias_event["source_alias_name"] - asset_uri = asset_alias_event["dest_asset_uri"] - frozen_extra = frozenset(asset_alias_event["extra"].items()) + asset_alias_name = asset_alias_event.source_alias_name + asset_uri = asset_alias_event.dest_asset_uri + frozen_extra = frozenset(asset_alias_event.extra.items()) asset_alias_names[(asset_uri, frozen_extra)].add(asset_alias_name) asset_models: dict[str, AssetModel] = { diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 754f0830f6206..3f07f387c4961 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -85,6 +85,7 @@ from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.code_utils import get_python_source from airflow.utils.context import ( + AssetAliasEvent, ConnectionAccessor, Context, OutletEventAccessor, @@ -305,7 +306,7 @@ def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]: raw_key = var.raw_key return { "extra": var.extra, - "asset_alias_events": var.asset_alias_events, + "asset_alias_events": [attrs.asdict(cast(attrs.AttrsInstance, e)) for e in var.asset_alias_events], "raw_key": BaseSerialization.serialize(raw_key), } @@ -316,7 +317,7 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor: outlet_event_accessor = OutletEventAccessor( extra=var["extra"], raw_key=BaseSerialization.deserialize(var["raw_key"]), - asset_alias_events=asset_alias_events, + asset_alias_events=[AssetAliasEvent(**e) for e in asset_alias_events], ) return outlet_event_accessor diff --git 
a/airflow/utils/context.py b/airflow/utils/context.py index 72521fadaa00b..8a8178eda473a 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -36,12 +36,7 @@ from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, fetch_active_assets_by_name -from airflow.sdk.definitions.asset import ( - Asset, - AssetAlias, - AssetAliasEvent, - AssetRef, -) +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef from airflow.sdk.definitions.asset.metadata import extract_event_key from airflow.utils.db import LazySelectSequence from airflow.utils.types import NOTSET @@ -145,6 +140,19 @@ def get(self, key: str, default_conn: Any = None) -> Any: return default_conn +@attrs.define() +class AssetAliasEvent: + """ + Representation of asset event to be triggered by an asset alias. + + :meta private: + """ + + source_alias_name: str + dest_asset_uri: str + extra: dict[str, Any] + + @attrs.define() class OutletEventAccessor: """ @@ -173,9 +181,7 @@ def add(self, asset: Asset | str, extra: dict[str, Any] | None = None) -> None: else: return - event = AssetAliasEvent( - source_alias_name=asset_alias_name, dest_asset_uri=asset_uri, extra=extra or {} - ) + event = AssetAliasEvent(asset_alias_name, asset_uri, extra=extra or {}) self.asset_alias_events.append(event) diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index 15511b375bd76..19b6500809931 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -39,7 +39,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic 
from airflow.typing_compat import TypedDict @@ -57,6 +57,12 @@ class VariableAccessor: class ConnectionAccessor: def get(self, key: str, default_conn: Any = None) -> Any: ... +class AssetAliasEvent: + source_alias_name: str + dest_asset_uri: str + extra: dict[str, Any] + def __init__(self, source_alias_name: str, dest_asset_uri: str, extra: dict[str, Any]) -> None: ... + class OutletEventAccessor: def __init__( self, diff --git a/providers/src/airflow/providers/common/compat/assets/__init__.py b/providers/src/airflow/providers/common/compat/assets/__init__.py index 049c4ac40b997..0ab6cd7a4beee 100644 --- a/providers/src/airflow/providers/common/compat/assets/__init__.py +++ b/providers/src/airflow/providers/common/compat/assets/__init__.py @@ -29,24 +29,12 @@ if TYPE_CHECKING: from airflow.auth.managers.models.resource_details import AssetDetails from airflow.models.asset import expand_alias_to_assets - from airflow.sdk.definitions.asset import ( - Asset, - AssetAlias, - AssetAliasEvent, - AssetAll, - AssetAny, - ) + from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny else: if AIRFLOW_V_3_0_PLUS: from airflow.auth.managers.models.resource_details import AssetDetails from airflow.models.asset import expand_alias_to_assets - from airflow.sdk.definitions.asset import ( - Asset, - AssetAlias, - AssetAliasEvent, - AssetAll, - AssetAny, - ) + from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny else: # dataset is renamed to asset since Airflow 3.0 from airflow.datasets import Dataset as Asset @@ -63,7 +51,6 @@ if AIRFLOW_V_2_10_PLUS: from airflow.datasets import ( DatasetAlias as AssetAlias, - DatasetAliasEvent as AssetAliasEvent, expand_alias_to_datasets as expand_alias_to_assets, ) @@ -71,7 +58,6 @@ __all__ = [ "Asset", "AssetAlias", - "AssetAliasEvent", "AssetAll", "AssetAny", "AssetDetails", diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py 
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 787757637a647..7ea61e905f509 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -27,7 +27,6 @@ import attrs from airflow.serialization.dag_dependency import DagDependency -from airflow.typing_compat import TypedDict if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -454,14 +453,6 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat ) -class AssetAliasEvent(TypedDict): - """A represeation of asset event to be triggered by an asset alias.""" - - source_alias_name: str - dest_asset_uri: str - extra: dict[str, Any] - - class _AssetBooleanCondition(BaseAsset): """Base class for asset boolean logic.""" @@ -476,7 +467,7 @@ def evaluate(self, statuses: dict[str, bool]) -> bool: return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: - seen = set() # We want to keep the first instance. + seen: set[AssetUniqueKey] = set() # We want to keep the first instance. for o in self.objects: for k, v in o.iter_assets(): if k in seen: @@ -486,8 +477,13 @@ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: """Filter asset aliases in the condition.""" + seen: set[str] = set() # We want to keep the first instance. 
for o in self.objects: - yield from o.iter_asset_aliases() + for k, v in o.iter_asset_aliases(): + if k in seen: + continue + yield k, v + seen.add(k) def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 3e8e844528822..5c53304e7af32 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -49,7 +49,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic @@ -60,7 +60,7 @@ from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.context import OutletEventAccessor, OutletEventAccessors +from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 0e7309075b38c..538b74119204b 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -21,8 +21,8 @@ import pytest from airflow.models.asset import AssetAliasModel, AssetModel -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent -from airflow.utils.context import OutletEventAccessor, OutletEventAccessors +from 
airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors class TestOutletEventAccessor: