From 2a94a96e43a3830aa2d0fe90f3e21ac6d06bd788 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Mar 2023 09:33:34 -0600 Subject: [PATCH] Introduce lifespan state (#1818) Co-authored-by: Marcelo Trylesinski --- tests/protocols/test_http.py | 47 +++++++++++++++- tests/protocols/test_websocket.py | 55 +++++++++++++++++++ tests/test_auto_detection.py | 6 +- tests/test_lifespan.py | 25 +++++++++ uvicorn/lifespan/off.py | 3 + uvicorn/lifespan/on.py | 6 +- uvicorn/protocols/http/h11_impl.py | 20 ++++++- uvicorn/protocols/http/httptools_impl.py | 21 ++++++- uvicorn/protocols/websockets/auto.py | 2 +- .../protocols/websockets/websockets_impl.py | 15 ++++- uvicorn/protocols/websockets/wsproto_impl.py | 5 +- uvicorn/server.py | 14 +++-- 12 files changed, 202 insertions(+), 17 deletions(-) diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index c041b69d0..7ac1fb70c 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -2,12 +2,15 @@ import socket import threading import time +from typing import Optional, Union import pytest from tests.response import Response from uvicorn import Server from uvicorn.config import WS_PROTOCOLS, Config +from uvicorn.lifespan.off import LifespanOff +from uvicorn.lifespan.on import LifespanOn from uvicorn.main import ServerState from uvicorn.protocols.http.h11_impl import H11Protocol @@ -184,12 +187,23 @@ def add_done_callback(self, callback): pass -def get_connected_protocol(app, protocol_cls, **kwargs): +def get_connected_protocol( + app, + protocol_cls, + lifespan: Optional[Union[LifespanOff, LifespanOn]] = None, + **kwargs, +): loop = MockLoop() transport = MockTransport() config = Config(app=app, **kwargs) + lifespan = lifespan or LifespanOff(config) server_state = ServerState() - protocol = protocol_cls(config=config, server_state=server_state, _loop=loop) + protocol = protocol_cls( + config=config, + server_state=server_state, + app_state=lifespan.state.copy(), + _loop=loop, + ) protocol.connection_made(transport) return protocol @@ -980,3 +994,32 @@ async def app(scope, receive, send): protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"x-test-header: test value" in protocol.transport.buffer + + +@pytest.mark.anyio +@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) +async def test_lifespan_state(protocol_cls): + expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}] + + async def app(scope, receive, send): + expected_state = expected_states.pop(0) + assert scope["state"] == expected_state + # modifications to keys are not preserved + scope["state"]["a"] = 456 + # unless of course the value itself is mutated + scope["state"]["b"].append(2) + return await Response("Hi!")(scope, receive, send) + + lifespan = LifespanOn(config=Config(app=app)) + # skip over actually running the lifespan, that is tested + # in the lifespan tests + lifespan.state.update({"a": 123, "b": [1]}) + + for _ in range(2): + protocol = get_connected_protocol(app, protocol_cls, lifespan=lifespan) + protocol.data_received(SIMPLE_GET_REQUEST) + await protocol.loop.run_one() + assert b"HTTP/1.1 200 OK" in protocol.transport.buffer + assert b"Hi!" in protocol.transport.buffer + + assert not expected_states # consumed diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index bb5c47c83..6bc56dfa4 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,5 +1,6 @@ import asyncio import typing +from copy import deepcopy import httpx import pytest @@ -1087,3 +1088,57 @@ async def open_connection(url): async with run_server(config): headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"] + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_lifespan_state(ws_protocol_cls, http_protocol_cls, unused_tcp_port: int): + expected_states = [ + {"a": 123, "b": [1]}, + {"a": 123, "b": [1, 2]}, + ] + + actual_states = [] + + async def lifespan_app(scope, receive, send): + message = await receive() + assert message["type"] == "lifespan.startup" + scope["state"]["a"] = 123 + scope["state"]["b"] = [1] + await send({"type": "lifespan.startup.complete"}) + message = await receive() + assert message["type"] == "lifespan.shutdown" + await send({"type": "lifespan.shutdown.complete"}) + + class App(WebSocketResponse): + async def websocket_connect(self, message): + actual_states.append(deepcopy(self.scope["state"])) + self.scope["state"]["a"] = 456 + self.scope["state"]["b"].append(2) + await self.send({"type": "websocket.accept"}) + + async def open_connection(url): + async with websockets.connect(url) as websocket: + return websocket.open + + async def app_wrapper(scope, receive, send): + if scope["type"] == "lifespan": + return await lifespan_app(scope, receive, send) + else: + return await App(scope, receive, send) + + config = Config( + app=app_wrapper, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="on", + port=unused_tcp_port, + ) + async with run_server(config): + is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") + assert is_open + is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") + assert is_open + + assert expected_states == actual_states diff --git a/tests/test_auto_detection.py b/tests/test_auto_detection.py index 9d596df05..2cd2f70c4 100644 --- a/tests/test_auto_detection.py +++ b/tests/test_auto_detection.py @@ -45,7 +45,7 @@ def test_loop_auto(): async def test_http_auto(): config = Config(app=app) server_state = ServerState() - protocol = AutoHTTPProtocol(config=config, server_state=server_state) + protocol = AutoHTTPProtocol(config=config, server_state=server_state, app_state={}) expected_http = "H11Protocol" if httptools is None else "HttpToolsProtocol" assert type(protocol).__name__ == expected_http @@ -54,6 +54,8 @@ async def test_http_auto(): async def test_websocket_auto(): config = Config(app=app) server_state = ServerState() - protocol = AutoWebSocketsProtocol(config=config, server_state=server_state) + protocol = AutoWebSocketsProtocol( + config=config, server_state=server_state, app_state={} + ) expected_websockets = "WSProtocol" if websockets is None else "WebSocketProtocol" assert type(protocol).__name__ == expected_websockets diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index d49a9dcf9..a9cb73e3a 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -166,6 +166,7 @@ async def asgi3app(scope, receive, send): assert scope == { "type": "lifespan", "asgi": {"version": "3.0", "spec_version": "2.0"}, + "state": {}, } async def test(): @@ -188,6 +189,7 @@ def asgi2app(scope): assert scope == { "type": "lifespan", "asgi": {"version": "2.0", "spec_version": "2.0"}, + "state": {}, } async def asgi(receive, send): @@ -245,3 +247,26 @@ async def test(): assert "the lifespan event failed" in error_messages.pop(0) assert "Application shutdown failed. Exiting." in error_messages.pop(0) loop.close() + + +def test_lifespan_state(): + async def app(scope, receive, send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + scope["state"]["foo"] = 123 + message = await receive() + assert message["type"] == "lifespan.shutdown" + await send({"type": "lifespan.shutdown.complete"}) + + async def test(): + config = Config(app=app, lifespan="on") + lifespan = LifespanOn(config) + + await lifespan.startup() + assert lifespan.state == {"foo": 123} + await lifespan.shutdown() + + loop = asyncio.new_event_loop() + loop.run_until_complete(test()) + loop.close() diff --git a/uvicorn/lifespan/off.py b/uvicorn/lifespan/off.py index 7ec961b5f..e1516f16a 100644 --- a/uvicorn/lifespan/off.py +++ b/uvicorn/lifespan/off.py @@ -1,9 +1,12 @@ +from typing import Any, Dict + from uvicorn import Config class LifespanOff: def __init__(self, config: Config) -> None: self.should_exit = False + self.state: Dict[str, Any] = {} async def startup(self) -> None: pass diff --git a/uvicorn/lifespan/on.py b/uvicorn/lifespan/on.py index 0c650aab1..37e935f01 100644 --- a/uvicorn/lifespan/on.py +++ b/uvicorn/lifespan/on.py @@ -1,7 +1,7 @@ import asyncio import logging from asyncio import Queue -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, Dict, Union from uvicorn import Config @@ -42,6 +42,7 @@ def __init__(self, config: Config) -> None: self.startup_failed = False self.shutdown_failed = False self.should_exit = False + self.state: Dict[str, Any] = {} async def startup(self) -> None: self.logger.info("Waiting for application startup.") @@ -79,9 +80,10 @@ async def shutdown(self) -> None: async def main(self) -> None: try: app = self.config.loaded_app - scope: LifespanScope = { + scope: LifespanScope = { # type: ignore[typeddict-item] "type": "lifespan", "asgi": {"version": self.config.asgi_version, "spec_version": "2.0"}, + "state": self.state, } await app(scope, self.receive, self.send) except BaseException as exc: diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index c1fb46c93..7a4c260f2 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -2,7 +2,17 @@ import http import logging import sys -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) from urllib.parse import unquote import h11 @@ -42,6 +52,7 @@ HTTPScope, ) + H11Event = Union[ h11.Request, h11.InformationalResponse, @@ -69,6 +80,7 @@ def __init__( self, config: Config, server_state: ServerState, + app_state: Dict[str, Any], _loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: if not config.loaded: @@ -89,6 +101,7 @@ def __init__( self.ws_protocol_class = config.ws_protocol_class self.root_path = config.root_path self.limit_concurrency = config.limit_concurrency + self.app_state = app_state # Timeouts self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None @@ -229,6 +242,7 @@ def handle_events(self) -> None: "raw_path": raw_path, "query_string": query_string, "headers": self.headers, + "state": self.app_state, } upgrade = self._get_upgrade() @@ -290,7 +304,9 @@ def handle_websocket_upgrade(self, event: H11Event) -> None: output += [name, b": ", value, b"\r\n"] output.append(b"\r\n") protocol = self.ws_protocol_class( # type: ignore[call-arg, misc] - config=self.config, server_state=self.server_state + config=self.config, + server_state=self.server_state, + app_state=self.app_state, ) protocol.connection_made(self.transport) protocol.data_received(b"".join(output)) diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 734e8945d..d7e6c33f0 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -6,7 +6,18 @@ import urllib from asyncio.events import TimerHandle from collections import deque -from typing import TYPE_CHECKING, Callable, Deque, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) import httptools @@ -44,6 +55,7 @@ HTTPScope, ) + HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]') HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]") @@ -66,6 +78,7 @@ def __init__( self, config: Config, server_state: ServerState, + app_state: Dict[str, Any], _loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: if not config.loaded: @@ -81,6 +94,7 @@ def __init__( self.ws_protocol_class = config.ws_protocol_class self.root_path = config.root_path self.limit_concurrency = config.limit_concurrency + self.app_state = app_state # Timeouts self.timeout_keep_alive_task: Optional[TimerHandle] = None @@ -201,7 +215,9 @@ def handle_websocket_upgrade(self) -> None: output += [name, b": ", value, b"\r\n"] output.append(b"\r\n") protocol = self.ws_protocol_class( # type: ignore[call-arg, misc] - config=self.config, server_state=self.server_state + config=self.config, + server_state=self.server_state, + app_state=self.app_state, ) protocol.connection_made(self.transport) protocol.data_received(b"".join(output)) @@ -237,6 +253,7 @@ def on_message_begin(self) -> None: "scheme": self.scheme, "root_path": self.root_path, "headers": self.headers, + "state": self.app_state, } # Parser callbacks diff --git a/uvicorn/protocols/websockets/auto.py b/uvicorn/protocols/websockets/auto.py index 0dfba3bdb..368b98242 100644 --- a/uvicorn/protocols/websockets/auto.py +++ b/uvicorn/protocols/websockets/auto.py @@ -1,7 +1,7 @@ import asyncio import typing -AutoWebSocketsProtocol: typing.Optional[typing.Type[asyncio.Protocol]] +AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]] try: import websockets # noqa except ImportError: # pragma: no cover diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 297203ec6..3d01fb282 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -2,7 +2,17 @@ import http import logging import sys -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) from urllib.parse import unquote import websockets @@ -61,6 +71,7 @@ def __init__( self, config: Config, server_state: ServerState, + app_state: Dict[str, Any], _loop: Optional[asyncio.AbstractEventLoop] = None, ): if not config.loaded: @@ -70,6 +81,7 @@ def __init__( self.app = config.loaded_app self.loop = _loop or asyncio.get_event_loop() self.root_path = config.root_path + self.app_state = app_state # Shared server state self.connections = server_state.connections @@ -190,6 +202,7 @@ async def process_request( "query_string": query_string.encode("ascii"), "headers": asgi_headers, "subprotocols": subprotocols, + "state": self.app_state, } task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 1d76f3a88..030236394 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -49,6 +49,7 @@ def __init__( self, config: Config, server_state: ServerState, + app_state: typing.Dict[str, typing.Any], _loop: typing.Optional[asyncio.AbstractEventLoop] = None, ) -> None: if not config.loaded: @@ -59,6 +60,7 @@ def __init__( self.loop = _loop or asyncio.get_event_loop() self.logger = logging.getLogger("uvicorn.error") self.root_path = config.root_path + self.app_state = app_state # Shared server state self.connections = server_state.connections @@ -170,7 +172,7 @@ def handle_connect(self, event: events.Request) -> None: headers = [(b"host", event.host.encode())] headers += [(key.lower(), value) for key, value in event.extra_headers] raw_path, _, query_string = event.target.partition("?") - self.scope: "WebSocketScope" = { + self.scope: "WebSocketScope" = { # type: ignore[typeddict-item] "type": "websocket", "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, "http_version": "1.1", @@ -184,6 +186,7 @@ def handle_connect(self, event: events.Request) -> None: "headers": headers, "subprotocols": event.subprotocols, "extensions": None, + "state": self.app_state, } self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) diff --git a/uvicorn/server.py b/uvicorn/server.py index a3fb31b2b..2d6c02ff0 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -1,5 +1,4 @@ import asyncio -import functools import logging import os import platform @@ -92,9 +91,16 @@ async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None: config = self.config - create_protocol = functools.partial( - config.http_protocol_class, config=config, server_state=self.server_state - ) + def create_protocol( + _loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> asyncio.Protocol: + return config.http_protocol_class( # type: ignore[call-arg] + config=config, + server_state=self.server_state, + app_state=self.lifespan.state.copy(), + _loop=_loop, + ) + loop = asyncio.get_running_loop() listeners: Sequence[socket.SocketType]