diff --git a/airflow/__init__.py b/airflow/__init__.py index c19d298092454..b2f58b5bc6dcb 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -22,6 +22,7 @@ import os import sys import warnings +from typing import TYPE_CHECKING if os.environ.get("_AIRFLOW_PATCH_GEVENT"): # If you are using gevents and start airflow webserver, you might want to run gevent monkeypatching @@ -81,6 +82,13 @@ # Deprecated lazy imports "AirflowException": (".exceptions", "AirflowException", True), } +if TYPE_CHECKING: + # These objects are imported by PEP-562, however, static analyzers and IDE's + # have no idea about typing of these objects. + # Add it under TYPE_CHECKING block should help with it. + from airflow.models.dag import DAG + from airflow.models.dataset import Dataset + from airflow.models.xcom_arg import XComArg def __getattr__(name: str): @@ -119,9 +127,9 @@ def __getattr__(name: str): if not settings.LAZY_LOAD_PROVIDERS: - from airflow import providers_manager + from airflow.providers_manager import ProvidersManager - manager = providers_manager.ProvidersManager() + manager = ProvidersManager() manager.initialize_providers_list() manager.initialize_providers_hooks() manager.initialize_providers_extra_links() @@ -129,14 +137,3 @@ def __getattr__(name: str): from airflow import plugins_manager plugins_manager.ensure_plugins_loaded() - - -# This is never executed, but tricks static analyzers (PyDev, PyCharm,) -# into knowing the types of these symbols, and what -# they contain. -STATICA_HACK = True -globals()["kcah_acitats"[::-1].upper()] = False -if STATICA_HACK: # pragma: no cover - from airflow.models.dag import DAG - from airflow.models.dataset import Dataset - from airflow.models.xcom_arg import XComArg diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index cf78673c057cc..244a1c4710727 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -31,13 +31,14 @@ from dataclasses import dataclass from functools import wraps from time import perf_counter -from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, NoReturn, TypeVar from packaging.utils import canonicalize_name from airflow.exceptions import AirflowOptionalProviderFeatureException from airflow.hooks.filesystem import FSHook from airflow.hooks.package_index import PackageIndexHook +from airflow.typing_compat import ParamSpec from airflow.utils import yaml from airflow.utils.entry_points import entry_points_with_dist from airflow.utils.log.logging_mixin import LoggingMixin @@ -51,6 +52,9 @@ else: from importlib_resources import files as resource_files +PS = ParamSpec("PS") +RT = TypeVar("RT") + MIN_PROVIDER_VERSIONS = { "apache-airflow-providers-celery": "2.1.0", } @@ -261,11 +265,6 @@ class ConnectionFormWidgetInfo(NamedTuple): is_sensitive: bool -T = TypeVar("T", bound=Callable) - -logger = logging.getLogger(__name__) - - def log_debug_import_from_sources(class_name, e, provider_package): """Log debug imports from sources.""" log.debug( @@ -362,7 +361,7 @@ def _correctness_check(provider_package: str, class_name: str, provider_info: Pr # We want to have better control over initialization of parameters and be able to debug and test it # So we add our own decorator -def provider_info_cache(cache_name: str) -> Callable[[T], T]: +def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, NoReturn]], Callable[PS, None]]: """ Decorate and cache provider info. @@ -370,23 +369,26 @@ def provider_info_cache(cache_name: str) -> Callable[[T], T]: :param cache_name: Name of the cache """ - def provider_info_cache_decorator(func: T): + def provider_info_cache_decorator(func: Callable[PS, NoReturn]) -> Callable[PS, None]: @wraps(func) - def wrapped_function(*args, **kwargs): + def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None: providers_manager_instance = args[0] + if TYPE_CHECKING: + assert isinstance(providers_manager_instance, ProvidersManager) + if cache_name in providers_manager_instance._initialized_cache: return start_time = perf_counter() - logger.debug("Initializing Providers Manager[%s]", cache_name) + log.debug("Initializing Providers Manager[%s]", cache_name) func(*args, **kwargs) providers_manager_instance._initialized_cache[cache_name] = True - logger.debug( + log.debug( "Initialization of Providers Manager[%s] took %.2f seconds", cache_name, perf_counter() - start_time, ) - return cast(T, wrapped_function) + return wrapped_function return provider_info_cache_decorator diff --git a/pyproject.toml b/pyproject.toml index 39a0441488392..d47675c2650cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -343,7 +343,7 @@ required-imports = ["from __future__ import annotations"] combine-as-imports = true [tool.ruff.lint.per-file-ignores] -"airflow/__init__.py" = ["F401"] +"airflow/__init__.py" = ["F401", "TCH004"] "airflow/models/__init__.py" = ["F401", "TCH004"] "airflow/models/sqla_models.py" = ["F401"]