From a8cbcd125e0e0437baec8fcafc0f94cd8ba6e615 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sun, 12 Jan 2025 14:03:44 +0530 Subject: [PATCH] Remove code for deprecation of Context keys We removed all the deprecated keys in https://github.com/apache/airflow/pull/43902 so we no longer need this code. In preparation of https://github.com/apache/airflow/pull/45583, I want to simplify this code. We can always revert/re-add this later when we need to deprecate a key. --- airflow/serialization/serialized_objects.py | 2 +- airflow/utils/context.py | 139 +----------------- airflow/utils/context.pyi | 1 - airflow/utils/operator_helpers.py | 16 +- .../providers/standard/operators/python.py | 6 +- .../tests/standard/operators/test_python.py | 5 +- 6 files changed, 19 insertions(+), 150 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 0926f3245e0db..41a80ed5fc359 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -781,7 +781,7 @@ def serialize( return cls._encode(var.to_json(), type_=DAT.DAG_CALLBACK_REQUEST) elif var.__class__ == Context: d = {} - for k, v in var._context.items(): + for k, v in var.items(): obj = cls.serialize(v, strict=strict) d[str(k)] = obj return cls._encode(d, type_=DAT.TASK_CONTEXT) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 45e487361e9a0..10cd44585019a 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -20,17 +20,10 @@ from __future__ import annotations import contextlib -import copy -import functools -import warnings from collections.abc import ( Container, - ItemsView, Iterator, - KeysView, Mapping, - MutableMapping, - ValuesView, ) from typing import ( TYPE_CHECKING, @@ -40,7 +33,6 @@ ) import attrs -import lazy_object_proxy from sqlalchemy import and_, select from airflow.exceptions import RemovedInAirflow3Warning @@ -367,97 +359,12 @@ class AirflowContextDeprecationWarning(RemovedInAirflow3Warning): """Warn for usage of deprecated context variables in a task.""" -def _create_deprecation_warning(key: str, replacements: list[str]) -> RemovedInAirflow3Warning: - message = f"Accessing {key!r} from the template is deprecated and will be removed in a future version." - if not replacements: - return AirflowContextDeprecationWarning(message) - display_except_last = ", ".join(repr(r) for r in replacements[:-1]) - if display_except_last: - message += f" Please use {display_except_last} or {replacements[-1]!r} instead." - else: - message += f" Please use {replacements[-1]!r} instead." - return AirflowContextDeprecationWarning(message) - - -class Context(MutableMapping[str, Any]): - """ - Jinja2 template context for task rendering. - - This is a mapping (dict-like) class that can lazily emit warnings when - (and only when) deprecated context keys are accessed. - """ - - _DEPRECATION_REPLACEMENTS: dict[str, list[str]] = {} - - def __init__(self, context: MutableMapping[str, Any] | None = None, **kwargs: Any) -> None: - self._context: MutableMapping[str, Any] = context or {} - if kwargs: - self._context.update(kwargs) - self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy() - - def __repr__(self) -> str: - return repr(self._context) +class Context(dict[str, Any]): + """Jinja2 template context for task rendering.""" def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]: - """ - Pickle the context as a dict. - - We are intentionally going through ``__getitem__`` in this function, - instead of using ``items()``, to trigger deprecation warnings. - """ - items = [(key, self[key]) for key in self._context] - return dict, (items,) - - def __copy__(self) -> Context: - new = type(self)(copy.copy(self._context)) - new._deprecation_replacements = self._deprecation_replacements.copy() - return new - - def __getitem__(self, key: str) -> Any: - with contextlib.suppress(KeyError): - warnings.warn( - _create_deprecation_warning(key, self._deprecation_replacements[key]), - stacklevel=2, - ) - with contextlib.suppress(KeyError): - return self._context[key] - raise KeyError(key) - - def __setitem__(self, key: str, value: Any) -> None: - self._deprecation_replacements.pop(key, None) - self._context[key] = value - - def __delitem__(self, key: str) -> None: - self._deprecation_replacements.pop(key, None) - del self._context[key] - - def __contains__(self, key: object) -> bool: - return key in self._context - - def __iter__(self) -> Iterator[str]: - return iter(self._context) - - def __len__(self) -> int: - return len(self._context) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Context): - return NotImplemented - return self._context == other._context - - def __ne__(self, other: Any) -> bool: - if not isinstance(other, Context): - return NotImplemented - return self._context != other._context - - def keys(self) -> KeysView[str]: - return self._context.keys() - - def items(self): - return ItemsView(self._context) - - def values(self): - return ValuesView(self._context) + """Pickle the context as a dict.""" + return dict, (list(self.items()),) def context_merge(context: Mapping[str, Any], *args: Any, **kwargs: Any) -> None: @@ -505,46 +412,10 @@ def context_copy_partial(source: Mapping[str, Any], keys: Container[str]) -> Con :meta private: """ - new = Context({k: v for k, v in source._context.items() if k in keys}) - new._deprecation_replacements = source._deprecation_replacements.copy() + new = Context({k: v for k, v in source.items() if k in keys}) return new -def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: - """ - Create a mapping that wraps deprecated entries in a lazy object proxy. - - This further delays deprecation warning to until when the entry is actually - used, instead of when it's accessed in the context. The result is useful for - passing into a callable with ``**kwargs``, which would unpack the mapping - too eagerly otherwise. - - This is implemented as a free function because the ``Context`` type is - "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom - functions. - - :meta private: - """ - if not isinstance(source, Context): - # Sometimes we are passed a plain dict (usually in tests, or in User's - # custom operators) -- be lenient about what we accept so we don't - # break anything for users. - return source - - def _deprecated_proxy_factory(k: str, v: Any) -> Any: - replacements = source._deprecation_replacements[k] - warnings.warn(_create_deprecation_warning(k, replacements), stacklevel=2) - return v - - def _create_value(k: str, v: Any) -> Any: - if k not in source._deprecation_replacements: - return v - factory = functools.partial(_deprecated_proxy_factory, k, v) - return lazy_object_proxy.Proxy(factory) - - return {k: _create_value(k, v) for k, v in source._context.items()} - - def context_get_outlet_events(context: Context) -> OutletEventAccessors: try: return context["outlet_events"] diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index b08833623c420..1a19dc322ba45 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -143,5 +143,4 @@ def context_merge(context: Context, additions: Iterable[tuple[str, Any]], **kwar def context_merge(context: Context, **kwargs: Any) -> None: ... def context_update_for_unmapped(context: Mapping[str, Any], task: BaseOperator) -> None: ... def context_copy_partial(source: Context, keys: Container[str]) -> Context: ... -def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ... def context_get_outlet_events(context: Context) -> OutletEventAccessors: ... diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index 93bc9e53daf38..cb822aa1cc77b 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -26,7 +26,6 @@ from airflow import settings from airflow.sdk.definitions.asset.metadata import Metadata from airflow.typing_compat import ParamSpec -from airflow.utils.context import Context, lazy_mapping_from_context from airflow.utils.types import NOTSET if TYPE_CHECKING: @@ -151,9 +150,8 @@ class KeywordParameters: content and use it somewhere else without needing ``lazy-object-proxy``. """ - def __init__(self, kwargs: Mapping[str, Any], *, wildcard: bool) -> None: + def __init__(self, kwargs: Mapping[str, Any]) -> None: self._kwargs = kwargs - self._wildcard = wildcard @classmethod def determine( @@ -181,20 +179,14 @@ def determine( if has_wildcard_kwargs: # If the callable has a **kwargs argument, it's ready to accept all the kwargs. - return cls(kwargs, wildcard=True) + return cls(kwargs) # If the callable has no **kwargs argument, it only wants the arguments it requested. - kwargs = {key: kwargs[key] for key in signature.parameters if key in kwargs} - return cls(kwargs, wildcard=False) + filtered_kwargs = {key: kwargs[key] for key in signature.parameters if key in kwargs} + return cls(filtered_kwargs) def unpacking(self) -> Mapping[str, Any]: """Dump the kwargs mapping to unpack with ``**`` in a function call.""" - if self._wildcard and isinstance(self._kwargs, Context): # type: ignore[misc] - return lazy_mapping_from_context(self._kwargs) - return self._kwargs - - def serializing(self) -> Mapping[str, Any]: - """Dump the kwargs mapping for serialization purposes.""" return self._kwargs diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 1207d349d10b1..35bc488860604 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -581,7 +581,11 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): return self._read_result(output_path) def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: - return KeywordParameters.determine(self.python_callable, self.op_args, context).serializing() + keyword_params = KeywordParameters.determine(self.python_callable, self.op_args, context) + if AIRFLOW_V_3_0_PLUS: + return keyword_params.unpacking() + else: + return keyword_params.serializing() # type: ignore[attr-defined] class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py index e0cbf9e3c2d15..3899c89fae9d8 100644 --- a/providers/tests/standard/operators/test_python.py +++ b/providers/tests/standard/operators/test_python.py @@ -1939,7 +1939,10 @@ def get_all_the_context(**context): current_context = get_current_context() with warnings.catch_warnings(): warnings.simplefilter("ignore", AirflowContextDeprecationWarning) - assert context == current_context._context + if AIRFLOW_V_3_0_PLUS: + assert context == current_context + else: + assert current_context._context @pytest.fixture