Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify dataset serialization code #38089

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def __or__(self, other: BaseDatasetEventInput) -> DatasetAny:
def __and__(self, other: BaseDatasetEventInput) -> DatasetAll:
return DatasetAll(self, other)

def as_expression(self) -> dict[str, Any]:
raise NotImplementedError

def evaluate(self, statuses: dict[str, bool]) -> bool:
raise NotImplementedError

Expand Down Expand Up @@ -126,6 +129,11 @@ def __eq__(self, other: Any) -> bool:
def __hash__(self) -> int:
return hash(self.uri)

def as_expression(self) -> dict[str, Any]:
if self.extra is None:
return {"uri": self.uri}
return {"uri": self.uri, "extra": self.extra}

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
yield self.uri, self

Expand All @@ -141,6 +149,9 @@ class _DatasetBooleanCondition(BaseDatasetEventInput):
def __init__(self, *objects: BaseDatasetEventInput) -> None:
self.objects = objects

def as_expression(self) -> dict[str, Any]:
return {"objects": [o.as_expression() for o in self.objects]}

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)

Expand Down
19 changes: 4 additions & 15 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3034,16 +3034,6 @@ def bulk_sync_to_db(
)
return cls.bulk_write_to_db(dags=dags, session=session)

def simplify_dataset_expression(self, dataset_expression) -> dict | None:
"""Simplifies a nested dataset expression into a 'any' or 'all' format with URIs."""
if dataset_expression is None:
return None
if dataset_expression.get("__type") == "dataset":
return dataset_expression["__var"]["uri"]

new_key = "any" if dataset_expression["__type"] == "dataset_any" else "all"
return {new_key: [self.simplify_dataset_expression(item) for item in dataset_expression["__var"]]}

@classmethod
@provide_session
def bulk_write_to_db(
Expand All @@ -3063,8 +3053,6 @@ def bulk_write_to_db(
if not dags:
return

from airflow.serialization.serialized_objects import BaseSerialization # Avoid circular import.

log.info("Sync %s DAGs", len(dags))
dag_by_ids = {dag.dag_id: dag for dag in dags}

Expand Down Expand Up @@ -3129,9 +3117,10 @@ def bulk_write_to_db(
)
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
orm_dag.dataset_expression = dag.simplify_dataset_expression(
BaseSerialization.serialize(dag.dataset_triggers)
)
if (dataset_triggers := dag.dataset_triggers) is None:
orm_dag.dataset_expression = None
else:
orm_dag.dataset_expression = dataset_triggers.as_expression()

orm_dag.processor_subdir = processor_subdir

Expand Down