diff --git a/airflow/cli/commands/remote_commands/dag_command.py b/airflow/cli/commands/remote_commands/dag_command.py index 1e922029abff25..acd94ff93e0784 100644 --- a/airflow/cli/commands/remote_commands/dag_command.py +++ b/airflow/cli/commands/remote_commands/dag_command.py @@ -37,9 +37,9 @@ from airflow.jobs.job import Job from airflow.models import DagBag, DagModel, DagRun, TaskInstance from airflow.models.serialized_dag import SerializedDagModel +from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils import cli as cli_utils, timezone from airflow.utils.cli import get_dag, process_subdir, suppress_logs_and_warning -from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.dot_renderer import render_dag, render_dag_dependencies from airflow.utils.helpers import ask_yesno from airflow.utils.providers_configuration_loader import providers_configuration_loaded diff --git a/airflow/task/standard_task_runner.py b/airflow/task/standard_task_runner.py index b4d05fc4753890..000e4c27501153 100644 --- a/airflow/task/standard_task_runner.py +++ b/airflow/task/standard_task_runner.py @@ -33,10 +33,10 @@ from airflow.configuration import conf from airflow.exceptions import AirflowConfigException from airflow.models.taskinstance import TaskReturnCode +from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.settings import CAN_FORK from airflow.stats import Stats from airflow.utils.configuration import tmp_configuration_copy -from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.platform import IS_WINDOWS, getuser diff --git a/docs/apache-airflow/howto/dynamic-dag-generation.rst b/docs/apache-airflow/howto/dynamic-dag-generation.rst index 9aa988f28bdb10..f3b3e25f0c3a6e 100644 --- a/docs/apache-airflow/howto/dynamic-dag-generation.rst +++ b/docs/apache-airflow/howto/dynamic-dag-generation.rst @@ -207,7 +207,7 @@ of the context are set to ``None``. :emphasize-lines: 4,8,9 from airflow.models.dag import DAG - from airflow.utils.dag_parsing_context import get_parsing_context + from airflow.sdk import get_parsing_context current_dag_id = get_parsing_context().dag_id diff --git a/providers/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/src/airflow/providers/celery/executors/celery_executor_utils.py index 65f8dfbe5b85ce..fda7bd10e0ea6f 100644 --- a/providers/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -44,8 +44,8 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor +from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.stats import Stats -from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.providers_configuration_loader import providers_configuration_loaded diff --git a/scripts/cov/core_coverage.py b/scripts/cov/core_coverage.py index 3f64167d3df756..0366fdbed200a5 100644 --- a/scripts/cov/core_coverage.py +++ b/scripts/cov/core_coverage.py @@ -104,7 +104,6 @@ "airflow/utils/code_utils.py", "airflow/utils/context.py", "airflow/utils/dag_cycle_tester.py", - "airflow/utils/dag_parsing_context.py", "airflow/utils/dates.py", "airflow/utils/db.py", "airflow/utils/db_cleanup.py", diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index e50b475b006a37..1bd0358a63c7e8 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -27,6 +27,7 @@ "dag", "Connection", "get_current_context", + "get_parsing_context", "__version__", ] @@ -35,7 +36,7 @@ if TYPE_CHECKING: from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.connection import Connection - from airflow.sdk.definitions.context import get_current_context + from airflow.sdk.definitions.context import get_current_context, get_parsing_context from airflow.sdk.definitions.dag import DAG, dag from airflow.sdk.definitions.edges import EdgeModifier, Label from airflow.sdk.definitions.taskgroup import TaskGroup @@ -50,6 +51,7 @@ "Connection": ".definitions.connection", "Variable": ".definitions.variable", "get_current_context": ".definitions.context", + "get_parsing_context": ".definitions.context", } diff --git a/airflow/utils/dag_parsing_context.py b/task_sdk/src/airflow/sdk/definitions/_internal/dag_parsing_context.py similarity index 65% rename from airflow/utils/dag_parsing_context.py rename to task_sdk/src/airflow/sdk/definitions/_internal/dag_parsing_context.py index 27d2f8cab31c98..cad4f991859235 100644 --- a/airflow/utils/dag_parsing_context.py +++ b/task_sdk/src/airflow/sdk/definitions/_internal/dag_parsing_context.py @@ -18,23 +18,8 @@ import os from contextlib import contextmanager -from typing import NamedTuple - -class AirflowParsingContext(NamedTuple): - """ - Context of parsing for the DAG. - - If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to - execute. You can use these for optimizing dynamically generated DAG files. - """ - - dag_id: str | None - task_id: str | None - - -_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID" -_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID" +from airflow.sdk.definitions.context import _AIRFLOW_PARSING_CONTEXT_DAG_ID, _AIRFLOW_PARSING_CONTEXT_TASK_ID @contextmanager @@ -50,11 +35,3 @@ def _airflow_parsing_context_manager(dag_id: str | None = None, task_id: str | N os.environ[_AIRFLOW_PARSING_CONTEXT_TASK_ID] = old_task_id if old_dag_id is not None: os.environ[_AIRFLOW_PARSING_CONTEXT_DAG_ID] = old_dag_id - - -def get_parsing_context() -> AirflowParsingContext: - """Return the current (DAG) parsing context info.""" - return AirflowParsingContext( - dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID), - task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID), - ) diff --git a/task_sdk/src/airflow/sdk/definitions/context.py b/task_sdk/src/airflow/sdk/definitions/context.py index a6bbc88ef86f53..46a92ec2beb149 100644 --- a/task_sdk/src/airflow/sdk/definitions/context.py +++ b/task_sdk/src/airflow/sdk/definitions/context.py @@ -17,7 +17,8 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypedDict +import os +from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict if TYPE_CHECKING: # TODO: Should we use pendulum.DateTime instead of datetime like AF 2.x? @@ -105,3 +106,27 @@ def my_task(): from airflow.sdk.definitions._internal.contextmanager import _get_current_context return _get_current_context() + + +class AirflowParsingContext(NamedTuple): + """ + Context of parsing for the DAG. + + If these values are not None, they will contain the specific DAG and Task ID that Airflow is requesting to + execute. You can use these for optimizing dynamically generated DAG files. + """ + + dag_id: str | None + task_id: str | None + + +_AIRFLOW_PARSING_CONTEXT_DAG_ID = "_AIRFLOW_PARSING_CONTEXT_DAG_ID" +_AIRFLOW_PARSING_CONTEXT_TASK_ID = "_AIRFLOW_PARSING_CONTEXT_TASK_ID" + + +def get_parsing_context() -> AirflowParsingContext: + """Return the current (DAG) parsing context info.""" + return AirflowParsingContext( + dag_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_DAG_ID), + task_id=os.environ.get(_AIRFLOW_PARSING_CONTEXT_TASK_ID), + ) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 02ed37514886f7..a483e9993cdf67 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -33,6 +33,7 @@ from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState, TIRunContext +from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( DeferTask, @@ -406,8 +407,8 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: setproctitle(f"airflow worker -- {msg.ti.id}") log = structlog.get_logger(logger_name="task") - # TODO: set the "magic loop" context vars for parsing - ti = parse(msg) + with _airflow_parsing_context_manager(dag_id="msg.ti.dag_id", task_id=msg.ti.task_id): + ti = parse(msg) log.debug("DAG file parsed", file=msg.dag_rel_path) else: raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") diff --git a/task_sdk/tests/dags/dag_parsing_context.py b/task_sdk/tests/dags/dag_parsing_context.py new file mode 100644 index 00000000000000..f7a9dd42290a6a --- /dev/null +++ b/task_sdk/tests/dags/dag_parsing_context.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.sdk import DAG, BaseOperator, get_parsing_context + +DAG_ID = "dag_parsing_context_test" + +current_dag_id = get_parsing_context().dag_id + +with DAG( + DAG_ID, + start_date=datetime(2024, 2, 21), + schedule=None, +) as the_dag: + BaseOperator(task_id="visible_task") + + if current_dag_id == DAG_ID: + # this task will be invisible if the DAG ID is not properly set in the parsing context. + BaseOperator(task_id="conditional_task") diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index e062c0ef33d81f..f716317ad24fcf 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -542,6 +542,46 @@ def execute(self, context): ) +def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch, test_dags_dir): + """ + Test that the DAG parsing context is correctly set during the startup process. + + This test verifies that the DAG and task IDs are correctly set in the parsing context + when a DAG is started up. + """ + dag_id = "dag_parsing_context_test" + task_id = "conditional_task" + + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), + dag_rel_path="dag_parsing_context.py", + bundle_info=BundleInfo(name="my-bundle", version=None), + requests_fd=0, + ti_context=make_ti_context(dag_id=dag_id, run_id="c"), + ) + + mock_supervisor_comms.get_message.return_value = what + + # Set the environment variable for DAG bundles + # We use the DAG defined in `task_sdk/tests/dags/dag_parsing_context.py` for this test! + dag_bundle_val = json.dumps( + [ + { + "name": "my-bundle", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"local_folder": str(test_dags_dir), "refresh_interval": 1}, + } + ] + ) + + monkeypatch.setenv("AIRFLOW__DAG_BUNDLES__BACKENDS", dag_bundle_val) + ti, _ = startup() + + # Presence of `conditional_task` below means DAG ID is properly set in the parsing context! + # Check the dag file for the actual logic! + assert ti.task.dag.task_dict.keys() == {"visible_task", "conditional_task"} + + class TestRuntimeTaskInstance: def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context): """Test get_template_context without ti_context_from_server.""" diff --git a/tests/dags/test_dag_parsing_context.py b/tests/dags/test_dag_parsing_context.py index 3a72cbbc191787..acd6a712ee118f 100644 --- a/tests/dags/test_dag_parsing_context.py +++ b/tests/dags/test_dag_parsing_context.py @@ -20,7 +20,7 @@ from airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator -from airflow.utils.dag_parsing_context import get_parsing_context +from airflow.sdk.definitions.context import get_parsing_context DAG_ID = "test_dag_parsing_context" diff --git a/tests/dags/test_parsing_context.py b/tests/dags/test_parsing_context.py index c901dbc7062d0d..4aae5caaf2d814 100644 --- a/tests/dags/test_parsing_context.py +++ b/tests/dags/test_parsing_context.py @@ -18,19 +18,16 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING from airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator -from airflow.utils.dag_parsing_context import ( +from airflow.sdk.definitions.context import ( _AIRFLOW_PARSING_CONTEXT_DAG_ID, _AIRFLOW_PARSING_CONTEXT_TASK_ID, + Context, ) from airflow.utils.timezone import datetime -if TYPE_CHECKING: - from airflow.sdk.definitions.context import Context - class DagWithParsingContext(EmptyOperator): def execute(self, context: Context):