Skip to content

Commit

Permalink
Ensure that you can create a second DAG whilst another one is already…
Browse files Browse the repository at this point in the history
… "active" (#44484)

Why would you want to do this? Who knows, maybe you are calling a dag factory
from inside a `with DAG` block. Either way, this exposed a subtle bug in
`TaskGroup.create_root()`.

This is the other half of the fix for the flakey tests fixed in #44480, and
after much digging with @kaxil and @potiuk we've finally worked out why it was
flakey:

It was the "Non-DB" test job that were faling sometimes, and those tests use
xdist to parallelize the tests. Couple that with the fact that
`get_serialized_fields()` caches the answer on the class object, the test
would only fail when nothing else in the current test process had previously
called `DAG.get_serialized_fields()`.

And to make this less likely to occur in future, the __serialized_fields is
moved to being created eagerly at parse time, no more lazy loaded cache!
  • Loading branch information
ashb authored Nov 29, 2024
1 parent 5dc14b5 commit 13e5464
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
58 changes: 29 additions & 29 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class DAG:
:param dag_display_name: The display name of the DAG which appears on the UI.
"""

__serialized_fields: ClassVar[frozenset[str] | None] = None
__serialized_fields: ClassVar[frozenset[str]]

# Note: mypy gets very confused about the use of `@${attr}.default` for attrs without init=False -- and it
# doesn't correctly track/notice that they have default values (it gives errors about `Missing positional
Expand Down Expand Up @@ -964,34 +964,6 @@ def cli(self):
@classmethod
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
if not cls.__serialized_fields:
exclusion_list = {
"schedule_asset_references",
"schedule_asset_alias_references",
"task_outlet_asset_references",
"_old_context_manager_dags",
"safe_dag_id",
"last_loaded",
"user_defined_filters",
"user_defined_macros",
"partial",
"params",
"_log",
"task_dict",
"template_searchpath",
# "sla_miss_callback",
"on_success_callback",
"on_failure_callback",
"template_undefined",
"jinja_environment_kwargs",
# has_on_*_callback are only stored if the value is True, as the default is False
"has_on_success_callback",
"has_on_failure_callback",
"auto_register",
"fail_stop",
"schedule",
}
cls.__serialized_fields = frozenset(a.name for a in attrs.fields(cls)) - exclusion_list
return cls.__serialized_fields

def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
Expand Down Expand Up @@ -1030,6 +1002,34 @@ def _validate_owner_links(self, _, owner_links):
)


# Since we define all the attributes of the class with attrs, we can compute this statically at parse time
DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - { # type: ignore[attr-defined]
"schedule_asset_references",
"schedule_asset_alias_references",
"task_outlet_asset_references",
"_old_context_manager_dags",
"safe_dag_id",
"last_loaded",
"user_defined_filters",
"user_defined_macros",
"partial",
"params",
"_log",
"task_dict",
"template_searchpath",
# "sla_miss_callback",
"on_success_callback",
"on_failure_callback",
"template_undefined",
"jinja_environment_kwargs",
# has_on_*_callback are only stored if the value is True, as the default is False
"has_on_success_callback",
"has_on_failure_callback",
"auto_register",
"fail_stop",
"schedule",
}

if TYPE_CHECKING:
# NOTE: Please keep the list of arguments in sync with DAG.__init__.
# Only exception: dag_id here should have a default value, but not in DAG.
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
@classmethod
def create_root(cls, dag: DAG) -> TaskGroup:
"""Create a root TaskGroup with no group_id or parent."""
return cls(group_id=None, dag=dag)
return cls(group_id=None, dag=dag, parent_group=None)

@property
def node_id(self):
Expand Down
6 changes: 6 additions & 0 deletions task_sdk/tests/defintions/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,9 @@ def noop_pipeline(value): ...
# Test that if arg is not passed it raises a type error as expected.
with pytest.raises(TypeError):
noop_pipeline()

def test_create_dag_while_active_context(self):
"""Test that we can safely create a DAG whilst a DAG is activated via ``with dag1:``."""
with DAG(dag_id="simple_dag"):
DAG(dag_id="dag2")
# No asserts needed, it just needs to not fail

0 comments on commit 13e5464

Please sign in to comment.