diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index d21be54de..94b625a36 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,3 +1,5 @@ +import asyncio + import httpx import pytest @@ -539,3 +541,67 @@ async def send_text(url): with pytest.raises(websockets.ConnectionClosedError) as e: data = await send_text("ws://127.0.0.1:8000") assert e.value.code == expected_result + + +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_reject_connection(ws_protocol_cls, http_protocol_cls): + async def app(scope, receive, send): + assert scope["type"] == "websocket" + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + # Reject the connection. + await send({"type": "websocket.close"}) + # -- At this point websockets' recv() is unusable. -- + + # This doesn't raise `TypeError`: + # See https://github.com/encode/uvicorn/issues/244 + message = await receive() + assert message["type"] == "websocket.disconnect" + + async def websocket_session(url): + try: + async with websockets.connect(url): + pass + except Exception: + pass + + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") + async with run_server(config): + await websocket_session("ws://127.0.0.1:8000") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_can_read_messages_in_buffer_after_close( + ws_protocol_cls, http_protocol_cls +): + frames = [] + + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + # Ensure server doesn't start reading frames from read buffer until + # after client has sent close frame, but server is still able to + # read these frames + await asyncio.sleep(0.2) + + async def websocket_receive(self, message): + frames.append(message.get("bytes")) + + async def send_text(url): + async with websockets.connect(url) as websocket: + await websocket.send(b"abc") + await websocket.send(b"abc") + await websocket.send(b"abc") + + config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") + async with run_server(config): + await send_text("ws://127.0.0.1:8000") + + assert frames == [b"abc", b"abc", b"abc"] diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index d829f49a6..b69e31535 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -282,10 +282,17 @@ async def asgi_receive(self): return {"type": "websocket.connect"} await self.handshake_completed_event.wait() + + if self.closed_event.is_set(): + # If client disconnected, use WebSocketServerProtocol.close_code property. + # If the handshake failed or the app closed before handshake completion, + # use 1006 Abnormal Closure. + return {"type": "websocket.disconnect", "code": self.close_code or 1006} + try: - await self.ensure_open() data = await self.recv() except websockets.ConnectionClosed as exc: + self.closed_event.set() return {"type": "websocket.disconnect", "code": exc.code} msg = {"type": "websocket.receive"}