diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 4f8d587727f14..1edcbb946dbb1 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -95,6 +95,9 @@ def __or__(self, other: BaseDatasetEventInput) -> DatasetAny:
     def __and__(self, other: BaseDatasetEventInput) -> DatasetAll:
         return DatasetAll(self, other)
 
+    def as_expression(self) -> dict[str, Any]:
+        raise NotImplementedError
+
     def evaluate(self, statuses: dict[str, bool]) -> bool:
         raise NotImplementedError
 
@@ -126,6 +129,11 @@ def __eq__(self, other: Any) -> bool:
     def __hash__(self) -> int:
         return hash(self.uri)
 
+    def as_expression(self) -> dict[str, Any]:
+        if self.extra is None:
+            return {"uri": self.uri}
+        return {"uri": self.uri, "extra": self.extra}
+
     def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
         yield self.uri, self
 
@@ -141,6 +149,9 @@ class _DatasetBooleanCondition(BaseDatasetEventInput):
     def __init__(self, *objects: BaseDatasetEventInput) -> None:
         self.objects = objects
 
+    def as_expression(self) -> dict[str, Any]:
+        return {"objects": [o.as_expression() for o in self.objects]}
+
     def evaluate(self, statuses: dict[str, bool]) -> bool:
         return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index ac38ae1eabf51..2ac247d73970e 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -3034,16 +3034,6 @@ def bulk_sync_to_db(
         )
         return cls.bulk_write_to_db(dags=dags, session=session)
 
-    def simplify_dataset_expression(self, dataset_expression) -> dict | None:
-        """Simplifies a nested dataset expression into a 'any' or 'all' format with URIs."""
-        if dataset_expression is None:
-            return None
-        if dataset_expression.get("__type") == "dataset":
-            return dataset_expression["__var"]["uri"]
-
-        new_key = "any" if dataset_expression["__type"] == "dataset_any" else "all"
-        return {new_key: [self.simplify_dataset_expression(item) for item in dataset_expression["__var"]]}
-
     @classmethod
     @provide_session
     def bulk_write_to_db(
@@ -3063,8 +3053,6 @@ def bulk_write_to_db(
         if not dags:
             return
 
-        from airflow.serialization.serialized_objects import BaseSerialization  # Avoid circular import.
-
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}
 
@@ -3129,9 +3117,10 @@ def bulk_write_to_db(
             )
             orm_dag.schedule_interval = dag.schedule_interval
             orm_dag.timetable_description = dag.timetable.description
-            orm_dag.dataset_expression = dag.simplify_dataset_expression(
-                BaseSerialization.serialize(dag.dataset_triggers)
-            )
+            if (dataset_triggers := dag.dataset_triggers) is None:
+                orm_dag.dataset_expression = None
+            else:
+                orm_dag.dataset_expression = dataset_triggers.as_expression()
             orm_dag.processor_subdir = processor_subdir