diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2 index 4199a47130fb9..2ff417985e887 100644 --- a/airflow/utils/python_virtualenv_script.jinja2 +++ b/airflow/utils/python_virtualenv_script.jinja2 @@ -16,6 +16,7 @@ specific language governing permissions and limitations under the License. -#} +from __future__ import annotations import {{ pickling_library }} import sys diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py index b91bcaae36be0..0f7ab6918dd2c 100644 --- a/tests/decorators/test_python_virtualenv.py +++ b/tests/decorators/test_python_virtualenv.py @@ -21,12 +21,14 @@ import sys from importlib.util import find_spec from subprocess import CalledProcessError +from typing import Any import pytest from airflow.decorators import setup, task, teardown from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState pytestmark = pytest.mark.db_test @@ -37,6 +39,8 @@ CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") +_Invalid = Any + class TestPythonVirtualenvDecorator: @CLOUDPICKLE_MARKER @@ -350,3 +354,29 @@ def f(): assert teardown_task.is_teardown assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_invalid_annotation(self, dag_maker): + import uuid + + unique_id = uuid.uuid4().hex + value = {"unique_id": unique_id} + + # Functions that throw an error + # if `from __future__ import annotations` is missing + @task.virtualenv(multiple_outputs=False, do_xcom_push=True) + def in_venv(value: dict[str, _Invalid]) -> _Invalid: + assert isinstance(value, dict) + return value["unique_id"] + + with dag_maker(): + ret = in_venv(value) + + dr = dag_maker.create_dagrun() + ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) + ti = dr.get_task_instances()[0] + + assert ti.state == TaskInstanceState.SUCCESS + + xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value") + assert isinstance(xcom, str) + assert xcom == unique_id