Skip to content

Commit

Permalink
Simplify dataset serialization code (apache#38089)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Mar 14, 2024
1 parent 0e2f2bc commit c14241b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
11 changes: 11 additions & 0 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def __or__(self, other: BaseDatasetEventInput) -> DatasetAny:
def __and__(self, other: BaseDatasetEventInput) -> DatasetAll:
    """Implement ``self & other``.

    :param other: The right-hand operand of ``&``.
    :return: A ``DatasetAll`` constructed from ``self`` and ``other``, in
        that order.
    """
    combined = DatasetAll(self, other)
    return combined

def as_expression(self) -> dict[str, Any]:
    """Return a dictionary form of this object.

    This implementation is a placeholder and always raises.
    NOTE(review): presumably concrete subclasses override it -- confirm.

    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()

def evaluate(self, statuses: dict[str, bool]) -> bool:
    """Evaluate this object against ``statuses``.

    This implementation is a placeholder and always raises.
    NOTE(review): presumably concrete subclasses override it -- confirm.

    :param statuses: Mapping from string keys to booleans; unused here.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()

Expand Down Expand Up @@ -126,6 +129,11 @@ def __eq__(self, other: Any) -> bool:
def __hash__(self) -> int:
    """Hash the object by its ``uri`` attribute only."""
    uri_hash = hash(self.uri)
    return uri_hash

def as_expression(self) -> dict[str, Any]:
    """Return this object's ``uri`` (and ``extra``, when set) as a dictionary.

    :return: ``{"uri": <uri>}`` when ``self.extra`` is ``None``; otherwise
        ``{"uri": <uri>, "extra": <extra>}``.
    """
    expression: dict[str, Any] = {"uri": self.uri}
    # Only ``None`` suppresses the key; other falsy values (e.g. ``{}``) are kept.
    if self.extra is not None:
        expression["extra"] = self.extra
    return expression

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
    """Yield exactly one ``(uri, object)`` pair: ``self.uri`` and ``self``."""
    pair = (self.uri, self)
    yield pair

Expand All @@ -141,6 +149,9 @@ class _DatasetBooleanCondition(BaseDatasetEventInput):
def __init__(self, *objects: BaseDatasetEventInput) -> None:
    """Keep the positional arguments on the instance.

    :param objects: Any number of inputs; they are stored unchanged, as the
        tuple Python builds for ``*objects``, on ``self.objects``.
    """
    # No copying or validation happens here -- the tuple is stored as received.
    self.objects = objects

def as_expression(self) -> dict[str, Any]:
    """Return the expressions of all members, in order, under one key.

    Each element of ``self.objects`` is asked for its own ``as_expression()``
    and the results are collected into a list.

    NOTE(review): the returned dictionary only has the key ``"objects"`` and
    carries no marker of how the members are combined -- confirm that
    consumers of this value do not need that information.

    :return: ``{"objects": [<member expression>, ...]}``.
    """
    members = []
    for child in self.objects:
        members.append(child.as_expression())
    return {"objects": members}

def evaluate(self, statuses: dict[str, bool]) -> bool:
    """Aggregate the members' evaluations with ``self.agg_func``.

    :param statuses: Passed through unchanged (by keyword) to every member's
        ``evaluate``.
    :return: Whatever ``self.agg_func`` returns for the member results.
    """
    # Kept as a lazy generator so an aggregator that stops early does not
    # evaluate the remaining members.
    member_results = (member.evaluate(statuses=statuses) for member in self.objects)
    return self.agg_func(member_results)

Expand Down
19 changes: 4 additions & 15 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3034,16 +3034,6 @@ def bulk_sync_to_db(
)
return cls.bulk_write_to_db(dags=dags, session=session)

def simplify_dataset_expression(self, dataset_expression) -> dict | str | None:
    """Simplify a nested dataset expression into an 'any'/'all' format with URIs.

    The input is a dictionary with a ``"__type"`` tag and a ``"__var"``
    payload, or ``None``.

    :param dataset_expression: The tagged expression, or ``None``.
    :return:
        * ``None`` when ``dataset_expression`` is ``None``;
        * the URI string (``__var["uri"]``) when the tag is ``"dataset"``;
        * otherwise ``{"any": [...]}`` when the tag is ``"dataset_any"`` and
          ``{"all": [...]}`` for every other tag, where the list holds the
          simplified form of each item in ``__var``.
    :raises KeyError: If ``"__type"`` (or a required ``"__var"`` entry) is missing.
    """
    if dataset_expression is None:
        return None

    expression_type = dataset_expression["__type"]
    payload = dataset_expression["__var"]

    # Leaf node: a single dataset collapses to its bare URI string.
    if expression_type == "dataset":
        return payload["uri"]

    # Any tag other than "dataset_any" is treated as an "all" condition.
    new_key = "any" if expression_type == "dataset_any" else "all"
    return {new_key: [self.simplify_dataset_expression(item) for item in payload]}

@classmethod
@provide_session
def bulk_write_to_db(
Expand All @@ -3063,8 +3053,6 @@ def bulk_write_to_db(
if not dags:
return

from airflow.serialization.serialized_objects import BaseSerialization # Avoid circular import.

log.info("Sync %s DAGs", len(dags))
dag_by_ids = {dag.dag_id: dag for dag in dags}

Expand Down Expand Up @@ -3129,9 +3117,10 @@ def bulk_write_to_db(
)
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
orm_dag.dataset_expression = dag.simplify_dataset_expression(
BaseSerialization.serialize(dag.dataset_triggers)
)
if (dataset_triggers := dag.dataset_triggers) is None:
orm_dag.dataset_expression = None
else:
orm_dag.dataset_expression = dataset_triggers.as_expression()

orm_dag.processor_subdir = processor_subdir

Expand Down

0 comments on commit c14241b

Please sign in to comment.