diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 57a8c7af6..00e09696c 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -377,14 +377,24 @@ async def connect(url): @pytest.mark.asyncio @pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS) -async def test_app_close(protocol_cls): +@pytest.mark.parametrize("code", [None, 1000, 1001]) +@pytest.mark.parametrize("reason", [None, "test"]) +async def test_app_close(protocol_cls, code, reason): async def app(scope, receive, send): while True: message = await receive() if message["type"] == "websocket.connect": await send({"type": "websocket.accept"}) elif message["type"] == "websocket.receive": - await send({"type": "websocket.close"}) + reply = {"type": "websocket.close"} + + if code is not None: + reply["code"] = code + + if reason is not None: + reply["reason"] = reason + + await send(reply) elif message["type"] == "websocket.disconnect": break @@ -398,7 +408,8 @@ async def websocket_session(url): async with run_server(config): with pytest.raises(websockets.exceptions.ConnectionClosed) as exc_info: await websocket_session("ws://127.0.0.1:8000") - assert exc_info.value.code == 1000 + assert exc_info.value.code == (code or 1000) + assert exc_info.value.reason == (reason or "") @pytest.mark.asyncio diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 6057fec4a..4afd72f10 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -224,7 +224,8 @@ async def asgi_send(self, message): elif message_type == "websocket.close": code = message.get("code", 1000) - await self.close(code) + reason = message.get("reason", "") + await self.close(code, reason) self.closed_event.set() else: diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 456af1370..6d15f8342 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -279,8 +279,11 @@ async def send(self, message): elif message_type == "websocket.close": self.close_sent = True code = message.get("code", 1000) + reason = message.get("reason", "") self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) - output = self.conn.send(wsproto.events.CloseConnection(code=code)) + output = self.conn.send( + wsproto.events.CloseConnection(code=code, reason=reason) + ) if not self.transport.is_closing(): self.transport.write(output) self.transport.close()