From 162440c906c0c37e0e2219b61c4c8280285db17b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Rubin?= Date: Mon, 18 Dec 2023 16:20:24 +0100 Subject: [PATCH] feat: Improve Middleware type annotations. Use ParamSpec to provide concrete type annotations for middleware's parameters. --- starlette/applications.py | 17 +++++++++++++++-- starlette/middleware/__init__.py | 22 +++++++++++++++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 554a25e65..3f6d55ea5 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import typing import warnings @@ -14,7 +15,14 @@ from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send from starlette.websockets import WebSocket +if sys.version_info >= (3, 10): # pragma: no cover + from typing import Concatenate, ParamSpec +else: # pragma: no cover + from typing_extensions import Concatenate, ParamSpec + + AppType = typing.TypeVar("AppType", bound="Starlette") +P = ParamSpec("P") class Starlette: @@ -124,10 +132,15 @@ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: self.router.host(host, app=app, name=name) # pragma: no cover - def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: + def add_middleware( + self, + middleware_class: typing.Callable[Concatenate[ASGIApp, P], typing.Any], + *args: P.args, + **options: P.kwargs, + ) -> None: if self.middleware_stack is not None: # pragma: no cover raise RuntimeError("Cannot add middleware after an application has started") - self.user_middleware.insert(0, Middleware(middleware_class, **options)) + self.user_middleware.insert(0, Middleware(middleware_class, *args, **options)) def add_exception_handler( self, diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 05bd57f04..864ff5504 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -1,12 +1,28 @@ -import typing +import sys +from typing import Any, Callable, Iterator + +from starlette.types import ASGIApp + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import Concatenate, ParamSpec +else: # pragma: no cover + from typing_extensions import Concatenate, ParamSpec + + +P = ParamSpec("P") class Middleware: - def __init__(self, cls: type, **options: typing.Any) -> None: + def __init__( + self, + cls: Callable[Concatenate[ASGIApp, P], Any], + *args: P.args, + **options: P.kwargs, + ) -> None: self.cls = cls self.options = options - def __iter__(self) -> typing.Iterator[typing.Any]: + def __iter__(self) -> Iterator[Any]: as_tuple = (self.cls, self.options) return iter(as_tuple)