Skip to content

Commit

Permalink
comments, docstrings, reordering code to be more intuitive to new rea…
Browse files Browse the repository at this point in the history
…ders
  • Loading branch information
Kafonek committed Oct 3, 2022
1 parent 3377bfc commit 3e6ce75
Showing 1 changed file with 114 additions and 70 deletions.
184 changes: 114 additions & 70 deletions sending/backends/jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from jupyter_client import AsyncKernelClient
from jupyter_client.channels import ZMQSocketChannel
import jupyter_client.session
import zmq

from zmq.asyncio import Context
Expand All @@ -22,7 +23,12 @@ def __init__(
):
super().__init__()
self.connection_info = connection_info
# If max_message_size is set, we'll disconnect (and reconnect immediately) to zmq
# channels that try to send a message greater than that size. It prevents applications
# from OOM crashing reading in large outputs or other messages
self.max_message_size = max_message_size
# Tasks that ultiamtely watch the zmq channels for messages. Keep track of these
# for cleanup (unsubscribe_from_topic, shutdown)
self.channel_tasks: Dict[str, List[asyncio.Task]] = collections.defaultdict(list)

async def initialize(
Expand All @@ -46,6 +52,8 @@ async def initialize(
)

async def shutdown(self, now=False):
# Cancelling channel watching tasks here is equivalent to shutting down poll worker
# in other Sending backend implementations.
for topic_name, task_list in self.channel_tasks.items():
for task in task_list:
task.cancel()
Expand All @@ -56,32 +64,84 @@ async def shutdown(self, now=False):
def set_context_option(self, option: int, val: Union[int, bytes]):
self._context.setsockopt(option, val)

async def watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel):
def send(
self,
topic_name: str,
msg_type: str,
content: Optional[dict],
parent: Optional[dict] = None,
header: Optional[dict] = None,
metadata: Optional[dict] = None,
):
"""
Read in any messages on a specific jupyter_client channel and drop them into the inbound
worker queue which will trigger registered callback functions by predicate / topic
Put a message onto the outbound queue which will be picked up by the outbound
worker and sent over zmq to the Kernel. Most messages will get sent over the shell
channel, although some may go over control as wel.
Example:
mgr.send("shell", "execute_request", {"code": "print('hello')", "silent": False})
"""
while True:
msg: dict = await channel_obj.get_msg()
self.schedule_for_delivery(topic_name, msg)
# format the message into a Jupyter specced dictionary then drop into outbound queue
# to get sent over the wire when outbound worker calls ._publish
jupyter_session: jupyter_client.session.Session = self._client.session
jupyter_msg: dict = jupyter_session.msg(msg_type, content, parent, header, metadata)
self.outbound_queue.put_nowait(QueuedMessage(topic_name, jupyter_msg, None))

async def watch_for_disconnect(self, monitor_socket: zmq.Socket):
async def _publish(self, message: QueuedMessage):
"""
An awaitable task that ends when a particular socket has a disconnect event. Used in
conjunction with watch_for_channel_messages to cycle a socket when it's disconnected.
When the outbound worker observes a message on the outbound queue, it will call this
method to actually send the message over the wire.
"""
while True:
msg: dict = await recv_monitor_message(monitor_socket)
event: zmq.Event = msg["event"]
if event == zmq.EVENT_DISCONNECTED:
return
topic_name = message.topic
if topic_name not in self.subscribed_topics:
await self._create_topic_subscription(topic_name)
if hasattr(self._client, f"{topic_name}_channel"):
channel_obj: ZMQSocketChannel = getattr(self._client, f"{topic_name}_channel")
channel_obj.send(message.contents)

# Normally in Sending backends there is the concept of a poll_worker which calls into _poll
# as part of a custom _poll_loop implementation. The poll_worker is what reads data over the
# wire (redis, socket, websocket, etc. zmq in the case of Jupyter/ipykernel). However the way
# this backend is written, reading data from zmq after subscribe_to_topic is called is handled
# by _watch_channel task (and its child tasks). poll_worker and these _poll methods do nothing.
async def _poll(self):
pass

async def _poll_loop(self):
pass

async def watch_channel(self, topic_name: str):
async def _create_topic_subscription(self, topic_name: str):
"""
Start observing messages on a zmq channel after a call to mgr.subscribe_to_topic('iopub')
"""
task = asyncio.create_task(self._watch_channel(topic_name))
self.channel_tasks[topic_name].append(task)

async def _cleanup_topic_subscription(self, topic_name: str):
"""
Clean up channel observing tasks after a call to mgr.unsubscribe_from_topic('iopub')
"""
if topic_name in self.channel_tasks:
for task in self.channel_tasks[topic_name]:
task.cancel()
self.channel_tasks[topic_name].remove(task)
await asyncio.sleep(0)
# Reset the channel object on our jupyter_client
setattr(self._client, f"_{topic_name}_channel", None)
else:
logger.warning(
f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict"
)

async def _watch_channel(self, topic_name: str):
"""
When a user subscribes to a topic (mgr.subscribe_to_topic('iopub')), this function starts
two tasks:
two child tasks:
1. Pull messages off the zmq channel and trigger any registered callbacks
2. Watch the monitor socket for disconnect events and reconnect / restart tasks
If a disconnect is observed, the two tasks are both cancelled and restarted.
Unsubscribing from a topic cancels this task and the child tasks.
"""
channel_name = f"{topic_name}_channel"

Expand All @@ -93,21 +153,25 @@ async def watch_channel(self, topic_name: str):
await self.context_hook()
while True:
# The channel properties (e.g. self._client.iopub_channel) will connect the socket
# if self._client._iopub_channel is None.
# if self._client._iopub_channel is None. Channel objects have a monitor object
# to observe lifecycle of the socket such as handshake / disconnect
channel_obj: ZMQSocketChannel = getattr(self._client, channel_name)
monitor_socket = channel_obj.socket.get_monitor_socket()
monitor_task = asyncio.create_task(self.watch_for_disconnect(monitor_socket))
message_task = asyncio.create_task(
self.watch_for_channel_messages(topic_name, channel_obj)
self._watch_for_channel_messages(topic_name, channel_obj)
)

# add tasks to self.channel_tasks so we can cleanup during topic unsubscribe / shutdown
monitor_socket = channel_obj.socket.get_monitor_socket()
monitor_task = asyncio.create_task(self._watch_for_disconnect(monitor_socket))

# If the _watch_channel task gets cancelled from a .unsubscribe_from_topic call,
# the two child tasks won't automatically be cancelled. Store these up at the class
# level so that _cleanup_topic_subscription can cancel them.
self.channel_tasks[topic_name].append(monitor_task)
self.channel_tasks[topic_name].append(message_task)

# Run the monitor and message tasks. Message task should run forever.
# If the monitor task returns then it means the socket was disconnected
# (max message size) and we need to cycle it.
# Await the monitor and message tasks. Message task should run forever.
# If the monitor task returns then it means the socket was disconnected,
# presumably from receiving a message larger than max message size.
done, pending = await asyncio.wait(
[monitor_task, message_task],
return_when=asyncio.FIRST_COMPLETED,
Expand All @@ -118,56 +182,36 @@ async def watch_channel(self, topic_name: str):
if task.exception():
raise task.exception()

logger.info(f"Cycling topic {topic_name} after disconnect")
self.channel_tasks[topic_name].remove(monitor_task)
self.channel_tasks[topic_name].remove(message_task)
logger.info(f"Cycling topic {topic_name} after disconnect")

# Emit an event so that callbacks registered to pickup the disconnect can do things like
# send user-facing messages that an output stream was too big and won't be displayed
self._emit_system_event(topic_name, SystemEvents.FORCED_DISCONNECT)
channel_obj.close()
setattr(self._client, f"_{channel_name}", None)

async def _create_topic_subscription(self, topic_name: str):
task = asyncio.create_task(self.watch_channel(topic_name))
self.channel_tasks[topic_name].append(task)

async def _cleanup_topic_subscription(self, topic_name: str):
if topic_name in self.channel_tasks:
for task in self.channel_tasks[topic_name]:
task.cancel()
self.channel_tasks[topic_name].remove(task)
await asyncio.sleep(0)
# Reset the channel object on our jupyter_client
setattr(self._client, f"_{topic_name}_channel", None)
else:
logger.warning(
f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict"
)

def send(
self,
topic_name: str,
msg_type: str,
content: Optional[dict],
parent: Optional[dict] = None,
header: Optional[dict] = None,
metadata: Optional[dict] = None,
):
msg = self._client.session.msg(msg_type, content, parent, header, metadata)
self.outbound_queue.put_nowait(QueuedMessage(topic_name, msg, None))

async def _publish(self, message: QueuedMessage):
topic_name = message.topic
if topic_name not in self.subscribed_topics:
await self._create_topic_subscription(topic_name)
if hasattr(self._client, f"{topic_name}_channel"):
channel_obj = getattr(self._client, f"{topic_name}_channel")
channel_obj.send(message.contents)
# Setting jupyter_client._iopub_channel to None will cause the next reference to
# the jupyter_client.iopub_channel @property to reconnect the socket.
# (see top of this while loop!)
setattr(self._client, f"_{channel_name}", None)

# _poll and _poll_loop are designed to be used to define how a Sending backend
# will read incoming data over the wire (socket, websocket, etc). In this implementation
# when we subscribe to a topic, it starts a watch_channel task which handles reading
# data over the right jupyter_client / zmq channel. So _poll and _poll_loop aren't used.
async def _poll(self):
pass
async def _watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel):
"""
Read in any messages on a specific jupyter_client channel and drop them into the inbound
worker queue which will trigger registered callback functions by predicate / topic
"""
while True:
msg: dict = await channel_obj.get_msg()
self.schedule_for_delivery(topic_name, msg)

async def _poll_loop(self):
pass
async def _watch_for_disconnect(self, monitor_socket: zmq.Socket):
"""
An awaitable task that ends when a particular socket has a disconnect event. Used in
conjunction with watch_for_channel_messages to cycle a socket when it's disconnected.
"""
while True:
msg: dict = await recv_monitor_message(monitor_socket)
event: zmq.Event = msg["event"]
if event == zmq.EVENT_DISCONNECTED:
return

0 comments on commit 3e6ce75

Please sign in to comment.