Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing trace log for websocket protocols #1083

Merged
merged 3 commits into from
Jun 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/middleware/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import httpx
import pytest
import websockets

from tests.utils import run_server
from uvicorn import Config
Expand Down Expand Up @@ -45,6 +46,56 @@ async def test_trace_logging(caplog):
assert "ASGI [2] Completed" in messages.pop(0)


@pytest.mark.asyncio
@pytest.mark.parametrize("http_protocol", [("h11"), ("httptools")])
async def test_trace_logging_on_http_protocol(http_protocol, caplog):
config = Config(app=app, log_level="trace", http=http_protocol)
with caplog_for_logger(caplog, "uvicorn.error"):
async with run_server(config):
async with httpx.AsyncClient() as client:
response = await client.get("http://127.0.0.1:8000")
assert response.status_code == 204
messages = [
record.message
for record in caplog.records
if record.name == "uvicorn.error"
]
assert any(" - HTTP connection made" in message for message in messages)
assert any(" - HTTP connection lost" in message for message in messages)


@pytest.mark.asyncio
@pytest.mark.parametrize("ws_protocol", [("websockets"), ("wsproto")])
async def test_trace_logging_on_ws_protocol(ws_protocol, caplog):
async def websocket_app(scope, receive, send):
assert scope["type"] == "websocket"
while True:
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
elif message["type"] == "websocket.disconnect":
break

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.open

config = Config(app=websocket_app, log_level="trace", ws=ws_protocol)
with caplog_for_logger(caplog, "uvicorn.error"):
async with run_server(config):
is_open = await open_connection("ws://127.0.0.1:8000")
assert is_open
messages = [
record.message
for record in caplog.records
if record.name == "uvicorn.error"
]
print(messages)
assert any(" - Upgrading to WebSocket" in message for message in messages)
assert any(" - WebSocket connection made" in message for message in messages)
assert any(" - WebSocket connection lost" in message for message in messages)


@pytest.mark.asyncio
@pytest.mark.parametrize("use_colors", [(True), (False), (None)])
async def test_access_logging(use_colors, caplog):
Expand Down
8 changes: 6 additions & 2 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def connection_made(self, transport):

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection made", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

def connection_lost(self, exc):
self.connections.discard(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection lost", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)

if self.cycle and not self.cycle.response_complete:
self.cycle.disconnected = True
Expand Down Expand Up @@ -256,6 +256,10 @@ def handle_upgrade(self, event):
self.transport.close()
return

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)

self.connections.discard(self)
output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"]
for name, value in self.headers:
Expand Down
8 changes: 6 additions & 2 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def connection_made(self, transport):

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection made", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

def connection_lost(self, exc):
self.connections.discard(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sConnection lost", prefix)
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)

if self.cycle and not self.cycle.response_complete:
self.cycle.disconnected = True
Expand Down Expand Up @@ -168,6 +168,10 @@ def handle_upgrade(self):
self.transport.close()
return

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)

self.connections.discard(self)
method = self.scope["method"].encode()
output = [method, b" ", self.url, b" HTTP/1.1\r\n"]
Expand Down
11 changes: 11 additions & 0 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import websockets
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory

from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import get_local_addr, get_remote_addr, is_ssl


Expand Down Expand Up @@ -74,10 +75,20 @@ def connection_made(self, transport):
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

super().connection_made(transport)

def connection_lost(self, exc):
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.handshake_completed_event.set()
super().connection_lost(exc)
if self.on_connection_lost is not None:
Expand Down
10 changes: 10 additions & 0 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from wsproto.extensions import PerMessageDeflate
from wsproto.utilities import RemoteProtocolError

from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import get_local_addr, get_remote_addr, is_ssl

# Check wsproto version. We've build against 0.13. We don't know about 0.14 yet.
Expand Down Expand Up @@ -65,10 +66,19 @@ def connection_made(self, transport):
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

def connection_lost(self, exc):
if exc is not None:
self.queue.put_nowait({"type": "websocket.disconnect"})
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

if self.on_connection_lost is not None:
self.on_connection_lost()
if exc is None:
Expand Down