Skip to content

Commit

Permalink
Add more hooks to support structlog contextvars (#36)
Browse files Browse the repository at this point in the history
* Add more hooks to support structlog contextvars

* flake8

* update comment

* remove response headers, deprecated in favor of more robust connect hooks

* comments

* changelog
  • Loading branch information
Matt Kafonek authored Aug 25, 2022
1 parent 51ddbec commit 292e2d0
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 57 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- CI/CD files and noxfile syntax
- `WebsocketManager` Backend
- New extra install `-E websockets`, additionally a convenience `-E all` option
- `WebsocketManager` saves response headers on connect
- `context_hook` in Base Manager that can be used to bind structlog contextvars for all workers (inbound, outbound, poll)
- `connect_hook` and `disconnect_hook` for Websocket manager

### Changed
- Use `managed_service_fixtures` for Redis tests
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pytest-cov = "^3.0.0"
uvicorn = "^0.18.2"
fastapi = "^0.79.0"
httpx = "^0.23.0"
structlog = "^22.1.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
41 changes: 33 additions & 8 deletions sending/backends/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ def __init__(self, ws_url: str):
super().__init__()
self.ws_url = ws_url

# Save / overwrite response headers on each connection
self.response_headers = {}

# An unauth_ws and authed_ws pair of Futures are created so that
# sub-classes can easily implement a pattern where messages are only
# sent to the server after the session has been authenticated.
self.unauth_ws = asyncio.Future()
self.authed_ws = asyncio.Future()
# Can use await mgr.connected.wait() to block until the websocket is connected
# in tests or in situations where you want connect_hook / context_hook to have
# information available to it from the websocket response (e.g. RTU session id)
self.connected = asyncio.Event()

# When an outbound worker is ready to send a message over the wire, it
# calls ._publish which will await the unauth_ws or authed_ws Future.
Expand All @@ -63,13 +64,28 @@ def __init__(self, ws_url: str):
self.max_reconnections = None

# Optional hooks that can be defined in a subclass or attached to an instance.
# First auth_hook then init_hook are called immediately after websocket connection
# They're separated into two hooks in case a subclass needs to authenticate, and
# also queue up initial messages that use .send() (which will wait until auth is accepted)
# - connect_hook is called first when websocket is established, useful to
# set contextvars or store state before init / auth hooks are called
#
# - auth_hook is called next, and also effects how .send() works.
# If auth_hook is defined then .send() won't actually transmit data over the wire
# until on_auth callback has been triggered.
# You want to define an auth_hook if the websocket server expects a first message
# to be some kind of authentication
#
# - init_hook is called next after auth_hook, useful to kick off messages after
# auth_hook, or if authentication is not part of the websocket server flow.
#
# - disconnect_hook is called when the websocket connection is lost
#
if not hasattr(self, "auth_hook"):
self.auth_hook: Optional[Callable] = None
if not hasattr(self, "init_hook"):
self.init_hook: Optional[Callable] = None
if not hasattr(self, "connect_hook"):
self.connect_hook: Optional[Callable] = None
if not hasattr(self, "disconnect_hook"):
self.disconnect_hook: Optional[Callable] = None

self.register_callback(self.record_last_seen_message)

Expand Down Expand Up @@ -147,9 +163,14 @@ async def _poll_loop(self):
"""
# Automatic reconnect https://websockets.readthedocs.io/en/stable/reference/client.html
async for websocket in websockets.connect(self.ws_url):
self.response_headers = dict(websocket.response_headers)
self.unauth_ws.set_result(websocket)
if self.connect_hook:
fn = ensure_async(self.connect_hook)
await fn(self)
if self.context_hook:
await self.context_hook()
self.connected.set()
try:
self.unauth_ws.set_result(websocket)
# Call the auth and init hooks (casting to async if necessary), passing in 'self'
if self.auth_hook:
fn = ensure_async(self.auth_hook)
Expand All @@ -175,6 +196,10 @@ async def _poll_loop(self):
logger.warning("Hit max reconnection attempts, not reconnecting")
return await self.shutdown()
logger.info("Websocket server disconnected, resetting Futures and reconnecting")
if self.disconnect_hook:
fn = ensure_async(self.disconnect_hook)
await fn(self)
self.connected.clear()
self.unauth_ws = asyncio.Future()
self.authed_ws = asyncio.Future()
self.reconnections += 1
Expand Down
23 changes: 19 additions & 4 deletions sending/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import enum
from collections import defaultdict, namedtuple
from functools import partial, wraps
from typing import Callable, Coroutine, Dict, List, Set
from typing import Callable, Dict, List, Optional, Set
from uuid import UUID, uuid4

from .logging import logger
Expand Down Expand Up @@ -53,9 +53,17 @@ def __init__(self):

# Allow these hooks to be defined within the class or attached to an instance
if not hasattr(self, "inbound_message_hook"):
self.inbound_message_hook: Coroutine = None
# Called by _inbound_worker when picking up a message from inbound queue
# Primarily used for deserializing messages from the wire
self.inbound_message_hook: Optional[Callable] = None
if not hasattr(self, "outbound_message_hook"):
self.outbound_message_hook: Coroutine = None
# Called by _outbound_worker before pushing a message to _publish
# Primarily used for serializing messages going out over the wire
self.outbound_message_hook: Optional[Callable] = None
if not hasattr(self, "context_hook"):
# Called at .initialize() and then within the while True loop for
# each worker. Should be used to set structlog.contextvars.bind_contextvars.
self.context_hook: Optional[Callable] = None

async def initialize(
self,
Expand Down Expand Up @@ -87,6 +95,8 @@ def echo(msg: str):
QueuedMessage(topic="test-topic, contents="echo test", session_id=None)
)
"""
if self.context_hook:
await self.context_hook()
self.outbound_queue = asyncio.Queue(queue_size)
self.inbound_queue = asyncio.Queue(queue_size)

Expand Down Expand Up @@ -253,6 +263,8 @@ def _detach_callback(self, cb_id: UUID, _session_id: UUID):
async def _outbound_worker(self):
while True:
message = await self.outbound_queue.get()
if self.context_hook:
await self.context_hook()
try:
if self.outbound_message_hook is not None:
coro = ensure_async(self.outbound_message_hook)
Expand All @@ -276,7 +288,8 @@ async def _publish(self, message: QueuedMessage):
async def _inbound_worker(self):
while True:
message = await self.inbound_queue.get()

if self.context_hook:
await self.context_hook()
try:
if self.inbound_message_hook is not None and message.topic is not SYSTEM_TOPIC:
coro = ensure_async(self.inbound_message_hook)
Expand Down Expand Up @@ -319,6 +332,8 @@ async def _delegate_to_callback(self, message: QueuedMessage, callback_id: UUID)

async def _poll_loop(self):
while True:
if self.context_hook:
await self.context_hook()
try:
await self._poll()
except Exception:
Expand Down
116 changes: 75 additions & 41 deletions tests/test_websocket_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import httpx
import pytest
import structlog
from managed_service_fixtures import AppDetails, AppManager

from sending.backends.websocket import WebsocketManager
Expand Down Expand Up @@ -73,22 +74,6 @@ async def test_basic_send(manager: WebsocketManager):
assert reply == {"type": "unauthed_echo_reply", "text": "Hello plain_manager"}


async def test_response_header(manager: WebsocketManager):
"""
The test websocket server will return a handful of response headers as part of the
websocket connection, such as {'upgrade': 'websocket', 'connection': 'Upgrade'}, etc.
It also includes a custom header {'foo': 'bar'}. In real world situations that response
header might be something like {'rtu-session-id': '<uuid>'}. Test that we save off
the response headers to the Manager instance on connect.
"""
await manager.initialize()
# Give it a moment to connect
await asyncio.sleep(0.01)
assert (
manager.response_headers.get("foo") == "bar"
), f"Missing foo: bar from {dict(manager.response_headers)}"


async def test_message_hooks(json_manager: WebsocketManager):
"""
Test that the inbound and outbound message hooks serialize/deserialize json
Expand All @@ -113,31 +98,6 @@ async def init_hook(mgr: WebsocketManager):
assert reply == {"type": "unauthed_echo_reply", "text": "Hello init_hook"}


@pytest.mark.xfail(reason="I don't know why this test doesn't work. Nick HALP")
async def test_bad_auth_hook(json_manager: WebsocketManager):
"""
Test that if someone adds an auth_hook but forgets to attach
a callback which will call .on_auth, that the ._publish method will
time out instead of awaiting .authed_ws forever
"""

async def auth_hook(mgr: WebsocketManager):
# auth_hook can't use mgr.send because that is goign to wait for authed_ws,
# which doesn't get set until auth reply!
# So send over the unauth_ws Future.
# Also note that the outbound_message_hook isn't applied!
ws = await mgr.unauth_ws
msg = json.dumps({"type": "auth_request", "token": str(uuid.UUID(int=1))})
await ws.send(msg)

json_manager.auth_hook = auth_hook
json_manager.publish_timeout = 1
await json_manager.initialize()
with pytest.raises(asyncio.TimeoutError):
await json_manager.send({"type": "authed_echo_request", "text": "Hello auth"})
await asyncio.sleep(2)


async def test_auth_hook(json_manager: WebsocketManager):
"""
Test that an auth_hook is called immediately after websocket connection,
Expand Down Expand Up @@ -237,6 +197,80 @@ async def auth_hook(self, mgr):
await mgr.shutdown()


# Two fixtures below used by test_structlog_contextvars_worker_hook
# Pattern pulled from https://www.structlog.org/en/stable/testing.html
@pytest.fixture(name="log_output")
def fixture_log_output():
return structlog.testing.LogCapture()


@pytest.fixture
def fixture_configure_structlog(log_output):
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.CallsiteParameterAdder(
{structlog.processors.CallsiteParameter.FUNC_NAME}
),
log_output,
]
)


@pytest.mark.usefixtures("fixture_configure_structlog")
async def test_structlog_contextvars_worker_hook(websocket_server: AppDetails, log_output):
"""
Test that we can bind contextvars within the context_hook method and that any callbacks
or outbound publishing methods will include those in logs.
"""

class Sub(WebsocketManager):
def __init__(self, ws_url):
super().__init__(ws_url)
self.session_id = None
self.register_callback(self.log_received)

async def context_hook(self):
structlog.contextvars.bind_contextvars(session_id=self.session_id)

async def connect_hook(self, mgr):
ws = await self.unauth_ws
self.session_id = ws.response_headers.get("session_id")

async def inbound_message_hook(self, raw_contents: str):
return json.loads(raw_contents)

async def outbound_message_hook(self, msg: dict):
return json.dumps(msg)

async def _publish(self, message: QueuedMessage):
await super()._publish(message)
structlog.get_logger().info(f"Publishing {message.contents}")

async def log_received(self, message: dict):
structlog.get_logger().info(f"Received {message}")

mgr = Sub(ws_url=websocket_server.ws_base + "/ws")
await mgr.initialize()
# Wait until we're connected before sending a message, otherwise the outbound worker
# will drop into .send / ._publish before we have a session_id set
await mgr.connected.wait()
mgr.send({"type": "unauthed_echo_request", "text": "Hello 1"})
# move forward in time until we get the next message from the webserver
await mgr.next_event.wait()
publish_log = log_output.entries[0]
assert publish_log["event"] == 'Publishing {"type": "unauthed_echo_request", "text": "Hello 1"}'
assert publish_log["session_id"]
assert publish_log["func_name"] == "_publish"

receive_log = log_output.entries[1]
assert receive_log["event"] == "Received {'type': 'unauthed_echo_reply', 'text': 'Hello 1'}"
assert receive_log["session_id"]
assert receive_log["func_name"] == "log_received"

await mgr.shutdown()


async def test_disable_polling(mocker):
"""
Test that registered callbacks (record_last_seen_message) are still called
Expand Down
2 changes: 1 addition & 1 deletion tests/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def singleton(cls):
dependency = singleton

async def connect(self, session: WebsocketSession):
await session.ws.accept(headers=[(b"foo", b"bar")])
await session.ws.accept(headers=[(b"session_id", str(uuid.uuid4()).encode())])
self.sessions.append(session)

async def disconnect(self, session: WebsocketSession):
Expand Down

0 comments on commit 292e2d0

Please sign in to comment.