Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZMQ socket monitoring refactor #52

Merged
merged 7 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use `managed_service_fixtures` for Redis tests
- `WebsocketManager` backend uses vanilla `logging` instead of `structlog`, remove need for `structlog` dependency once `managed-service-fixtures` also drops it
- `JupyterBackend` introduce a short sleep in its poll loop while investigating 100% CPU usage
- `JupyterBackend` zmq polling changed fairly significantly to avoid missing messages while reconnecting socket after a max message size disconnect

## [0.2.2] - 2022-07-28
### Changed
Expand Down
221 changes: 147 additions & 74 deletions sending/backends/jupyter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
from queue import Empty
from typing import Optional, Union
import collections
from typing import Dict, List, Optional, Union

import jupyter_client.session
import zmq
from jupyter_client import AsyncKernelClient
from jupyter_client.channels import ZMQSocketChannel
from zmq import NOBLOCK, Event, Socket, SocketOption, pyzmq_version, zmq_version
from zmq.asyncio import Context
from zmq.utils.monitor import recv_monitor_message

Expand All @@ -18,21 +19,26 @@ def __init__(
connection_info: dict,
*,
max_message_size: int = None,
sleep_between_polls: float = 0.005,
):
super().__init__()
self.connection_info = connection_info
self._monitor_sockets_for_topic: dict[str, Socket] = {}
# If max_message_size is set, we'll disconnect (and reconnect immediately) to zmq
# channels that try to send a message greater than that size. It prevents applications
# from OOM crashing reading in large outputs or other messages
self.max_message_size = max_message_size
self.sleep_between_polls = sleep_between_polls
# Tasks that ultimately watch the zmq channels for messages. Keep track of these
# for cleanup (unsubscribe_from_topic, shutdown)
self.channel_tasks: Dict[str, List[asyncio.Task]] = collections.defaultdict(list)

async def initialize(
self, *, queue_size=0, inbound_workers=1, outbound_workers=1, poll_workers=1
):
logger.debug(f"Initializing Jupyter Kernel Manager: {zmq_version()=}, {pyzmq_version()=}")
logger.debug(
f"Initializing Jupyter Kernel Manager: {zmq.zmq_version()=}, {zmq.pyzmq_version()=}"
)
self._context = Context()
if self.max_message_size:
self.set_context_option(SocketOption.MAXMSGSIZE, self.max_message_size)
self.set_context_option(zmq.SocketOption.MAXMSGSIZE, self.max_message_size)

self._client = AsyncKernelClient(context=self._context)
self._client.load_connection_info(self.connection_info)
Expand All @@ -45,32 +51,18 @@ async def initialize(
)

async def shutdown(self, now=False):
    """
    Stop watching every zmq channel, shut down the base backend, then destroy the zmq context.

    ``now`` is passed through unchanged to the parent class ``shutdown``.
    """
    # This backend has no real poll worker; cancelling the channel watching tasks is the
    # equivalent of shutting down the poll worker in other Sending backend implementations.
    for watcher_tasks in self.channel_tasks.values():
        for watcher in watcher_tasks:
            watcher.cancel()
    await super().shutdown(now)
    # https://github.com/zeromq/pyzmq/issues/1003
    self._context.destroy(linger=0)

def set_context_option(self, option: int, val: Union[int, bytes]):
    """
    Forward a socket option and its value to the zmq context via ``setsockopt``.

    ``initialize`` uses this to apply ``MAXMSGSIZE`` when ``max_message_size`` is set.
    """
    zmq_context = self._context
    zmq_context.setsockopt(option, val)

async def _create_topic_subscription(self, topic_name: str):
if hasattr(self._client, f"{topic_name}_channel"):
channel_obj = getattr(self._client, f"{topic_name}_channel")
channel_obj.start()

monitor_socket = channel_obj.socket.get_monitor_socket()
self._monitor_sockets_for_topic[topic_name] = monitor_socket

async def _cleanup_topic_subscription(self, topic_name: str):
if hasattr(self._client, f"{topic_name}_channel"):
channel_obj = getattr(self._client, f"{topic_name}_channel")
channel_obj.socket.disable_monitor()
channel_obj.close()

# Reset the underlying channel object so jupyter_client will recreate it
# if we subscribe to this again.
setattr(self._client, f"_{topic_name}_channel", None)
del self._monitor_sockets_for_topic[topic_name]

def send(
self,
topic_name: str,
Expand All @@ -80,64 +72,145 @@ def send(
header: Optional[dict] = None,
metadata: Optional[dict] = None,
):
msg = self._client.session.msg(msg_type, content, parent, header, metadata)
self.outbound_queue.put_nowait(QueuedMessage(topic_name, msg, None))
"""
Put a message onto the outbound queue which will be picked up by the outbound
worker and sent over zmq to the Kernel. Most messages will get sent over the shell
channel, although some may go over control as well.

Example:
mgr.send("shell", "execute_request", {"code": "print('hello')", "silent": False})
"""
# format the message into a Jupyter specced dictionary then drop into outbound queue
# to get sent over the wire when outbound worker calls ._publish
jupyter_session: jupyter_client.session.Session = self._client.session
jupyter_msg: dict = jupyter_session.msg(msg_type, content, parent, header, metadata)
self.outbound_queue.put_nowait(QueuedMessage(topic_name, jupyter_msg, None))

async def _publish(self, message: QueuedMessage):
"""
When the outbound worker observes a message on the outbound queue, it will call this
method to actually send the message over the wire.
"""
topic_name = message.topic
if topic_name not in self.subscribed_topics:
await self._create_topic_subscription(topic_name)
if hasattr(self._client, f"{topic_name}_channel"):
channel_obj = getattr(self._client, f"{topic_name}_channel")
channel_obj: ZMQSocketChannel = getattr(self._client, f"{topic_name}_channel")
channel_obj.send(message.contents)

def _cycle_socket(self, topic):
channel_obj = getattr(self._client, f"{topic}_channel")
channel_obj.socket.disable_monitor()
channel_obj.close()
connect_fn = getattr(self._client, f"connect_{topic}")
channel_obj.socket = connect_fn()
monitor_socket = channel_obj.socket.get_monitor_socket()
self._monitor_sockets_for_topic[topic] = monitor_socket

# Normally in Sending backends there is the concept of a poll_worker which calls into _poll
# as part of a custom _poll_loop implementation. The poll_worker is what reads data over the
# wire (redis, socket, websocket, etc. zmq in the case of Jupyter/ipykernel). However the way
# this backend is written, reading data from zmq after subscribe_to_topic is called is handled
# by _watch_channel task (and its child tasks). poll_worker and these _poll methods do nothing.
async def _poll(self):
for topic_name in self.subscribed_topics:
channel_obj: ZMQSocketChannel = getattr(self._client, f"{topic_name}_channel")

while True:
try:
msg = await channel_obj.get_msg(timeout=0)
self.schedule_for_delivery(topic_name, msg)
except Empty:
break

topics_to_cycle = []
for topic, socket in self._monitor_sockets_for_topic.items():
while await socket.poll(0):
msg = await recv_monitor_message(socket, flags=NOBLOCK)
logger.debug(f"ZMQ event: {topic=} {msg['event']=} {msg['value']=}")
if msg["event"] == Event.DISCONNECTED:
self._emit_system_event(topic, SystemEvents.FORCED_DISCONNECT)
topics_to_cycle.append(topic)

for topic in topics_to_cycle:
# If the ZMQ socket is disconnected, try cycling it
# This is helpful in situations where ZMQ disconnects peers
# when it violates some constraint such as the max message size.
logger.info(f"ZMQ disconnected for topic '{topic}', cycling socket")
self._cycle_socket(topic)
pass

async def _poll_loop(self):
pass

async def _create_topic_subscription(self, topic_name: str):
    """
    Begin observing a zmq channel after a call to mgr.subscribe_to_topic('iopub').

    Spawns a ``_watch_channel`` task for the topic and records it in ``channel_tasks``
    so it can be cancelled later.
    """
    watcher_coro = self._watch_channel(topic_name)
    watcher_task = asyncio.create_task(watcher_coro)
    self.channel_tasks[topic_name].append(watcher_task)

async def _cleanup_topic_subscription(self, topic_name: str):
    """
    Clean up channel observing tasks after a call to mgr.unsubscribe_from_topic('iopub')

    Cancels every task tracked for ``topic_name``, empties the tracking list, and resets
    the private channel attribute on the jupyter client. Logs a warning and does nothing
    else if the topic has no entry in ``channel_tasks``.
    """
    if topic_name not in self.channel_tasks:
        logger.warning(
            f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict"
        )
        return

    tracked_tasks = self.channel_tasks[topic_name]
    # Iterate over a snapshot. Removing entries from the list while iterating it skips
    # every other element, which left some tasks (e.g. the monitor task) un-cancelled
    # and still tracked after an unsubscribe.
    for task in list(tracked_tasks):
        task.cancel()
    tracked_tasks.clear()
    # Yield to the event loop once so the cancelled tasks get a chance to run.
    await asyncio.sleep(0)
    # Reset the channel object on our jupyter_client
    setattr(self._client, f"_{topic_name}_channel", None)

async def _watch_channel(self, topic_name: str):
    """
    When a user subscribes to a topic (mgr.subscribe_to_topic('iopub')), this function starts
    two child tasks:
     1. Pull messages off the zmq channel and trigger any registered callbacks
     2. Watch the monitor socket for disconnect events and reconnect / restart tasks

    If a disconnect is observed, the two tasks are both cancelled and restarted.
    Unsubscribing from a topic cancels this task and the child tasks.

    The child tasks are always cancelled and dropped from ``channel_tasks`` when this
    coroutine leaves an iteration, whether through a disconnect, a child raising, or this
    task itself being cancelled. Any exception raised by a finished child is re-raised.
    """
    channel_name = f"{topic_name}_channel"

    # context_hook (primarily for adding structlog contextvars) is normally called in base
    # _poll_loop, so that it's applied to every read the poll worker does. For the Jupyter
    # implementation, we don't use _poll_loop, all reads from zmq start with tasks here,
    # and any contextvars set in this method will be picked up by tasks created here.
    if self.context_hook:
        await self.context_hook()
    while True:
        # The channel properties (e.g. self._client.iopub_channel) will connect the socket
        # if self._client._iopub_channel is None. Channel objects have a monitor object
        # to observe lifecycle of the socket such as handshake / disconnect
        channel_obj: ZMQSocketChannel = getattr(self._client, channel_name)
        monitor_socket = channel_obj.socket.get_monitor_socket()

        message_task = asyncio.create_task(
            self._watch_for_channel_messages(topic_name, channel_obj)
        )
        monitor_task = asyncio.create_task(self._watch_for_disconnect(monitor_socket))
        child_tasks = (monitor_task, message_task)

        # Also track the children at the class level so _cleanup_topic_subscription and
        # shutdown can cancel them directly.
        self.channel_tasks[topic_name].extend(child_tasks)

        try:
            # Await the monitor and message tasks. Message task should run forever.
            # If the monitor task returns then it means the socket was disconnected,
            # presumably from receiving a message larger than max message size.
            done, _pending = await asyncio.wait(
                child_tasks,
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                error = task.exception()
                if error:
                    raise error
        finally:
            # Cancelling this watcher does not cancel its children automatically, so clean
            # them up on every exit path (disconnect, child failure, cancellation).
            # Cancelling an already finished task is a no-op.
            tracked = self.channel_tasks[topic_name]
            for task in child_tasks:
                task.cancel()
                if task in tracked:
                    tracked.remove(task)

        logger.info(f"Cycling topic {topic_name} after disconnect")

        # Emit an event so that callbacks registered to pickup the disconnect can do things
        # like send user-facing messages that an output stream was too big and won't be
        # displayed
        self._emit_system_event(topic_name, SystemEvents.FORCED_DISCONNECT)
        channel_obj.close()

        # Setting jupyter_client._iopub_channel to None will cause the next reference to
        # the jupyter_client.iopub_channel @property to reconnect the socket.
        # (see top of this while loop!)
        setattr(self._client, f"_{channel_name}", None)

async def _watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel):
    """
    Loop forever: await the next message from the given jupyter_client channel and pass it
    to ``schedule_for_delivery`` under ``topic_name``, which hands it to the inbound worker
    queue so registered callbacks fire by predicate / topic.

    Only ends when cancelled or when ``get_msg`` / ``schedule_for_delivery`` raises.
    """
    while True:
        incoming: dict = await channel_obj.get_msg()
        self.schedule_for_delivery(topic_name, incoming)

async def _watch_for_disconnect(self, monitor_socket: zmq.Socket):
"""
Override base Manager _poll_loop to switch the final asyncio.sleep from 0 to
something more than that (definable at init or after instantiation, default 0.005).
While observing JupyterManager in real world, containers are using 100% CPU.
Possibly due to this loop being too aggressive?
An awaitable task that ends when a particular socket has a disconnect event. Used in
conjunction with watch_for_channel_messages to cycle a socket when it's disconnected.
"""
while True:
try:
await self._poll()
except Exception:
logger.exception("Uncaught exception encountered while polling backend")
finally:
await asyncio.sleep(self.sleep_between_polls)
msg: dict = await recv_monitor_message(monitor_socket)
event: zmq.Event = msg["event"]
if event == zmq.EVENT_DISCONNECTED:
return
27 changes: 14 additions & 13 deletions tests/test_jupyter_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ async def run_until_seen(self, msg_types: List[str], timeout: float):
to see 'status', 'execute_reply', and 'status' again.
"""
deadline = time.time() + timeout
msg_types = msg_types[:]
while msg_types:
to_observe = msg_types[:]
while to_observe:
max_wait = deadline - time.time()
await asyncio.wait_for(self.next_event.wait(), timeout=max_wait)
if self.last_seen_event["msg_type"] in msg_types:
msg_types.remove(self.last_seen_event["msg_type"])
try:
await asyncio.wait_for(self.next_event.wait(), timeout=max_wait)
except asyncio.TimeoutError:
raise Exception(
f"Did not see the expected messages in time.\nTimeout: {timeout}\nExpected messages: {msg_types}\nUnobserved: {to_observe}" # noqa: E501
)

if self.last_seen_event["msg_type"] in to_observe:
to_observe.remove(self.last_seen_event["msg_type"])
self.next_event.clear()


Expand Down Expand Up @@ -115,7 +121,6 @@ async def test_jupyter_backend(self, mocker, ipykernel):
iopub_cb.assert_not_called()
shell_cb.assert_called_once()

@pytest.mark.xfail(reason="Cycling sockets is buggy in the current implementation")
async def test_reconnection(self, mocker, ipykernel):
"""
Test that if a message over the zmq channel is too large, we won't receive it
Expand Down Expand Up @@ -143,18 +148,14 @@ async def test_reconnection(self, mocker, ipykernel):
# status going to busy, execute_input, then a disconnect event where we would normally
# see a stream. The iopub channel should cycle, and hopefully catch the status going
# idle. We'll also see execute_reply on shell channel.
# (removed one status and execute_input from expected list because ci/cd seems to miss
# it often. Not sure why, runs fine locally)
mgr.send(
"shell",
"execute_request",
{"code": "print('x' * 2**13)", "silent": False},
)
try:
await monitor.run_until_seen(
msg_types=["status", "execute_input", "execute_reply", "status"], timeout=3
)
except asyncio.TimeoutError:
await mgr.shutdown()
raise Exception("Did not see the expected messages after cycling the iopub channel")
await monitor.run_until_seen(msg_types=["execute_reply", "status"], timeout=3)

disconnect_event.assert_called()

Expand Down