From c5139f4e34102555dc1398d6d090bb80a803a57a Mon Sep 17 00:00:00 2001 From: Kafonek Date: Fri, 30 Sep 2022 12:18:44 -0400 Subject: [PATCH 1/7] ZMQ socket monitoring refactor --- CHANGELOG.md | 1 + sending/backends/jupyter.py | 144 ++++++++++++++++++---------------- tests/test_jupyter_backend.py | 2 - 3 files changed, 77 insertions(+), 70 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6467c3d..4a99eb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `managed_service_fixtures` for Redis tests - `WebsocketManager` backend uses vanilla `logging` instead of `structlog`, remove need for `structlog` dependency once `managed-service-fixtures` also drops it - `JupyterBackend` introduce a short sleep in its poll loop while investigating 100% CPU usage +- `JupyterBackend` zmq polling changed fairly significantly to avoid missing messages while reconnecting socket after a max message size disconnect ## [0.2.2] - 2022-07-28 ### Changed diff --git a/sending/backends/jupyter.py b/sending/backends/jupyter.py index 5016d32..7f9ba76 100644 --- a/sending/backends/jupyter.py +++ b/sending/backends/jupyter.py @@ -1,10 +1,11 @@ import asyncio -from queue import Empty -from typing import Optional, Union +import collections +from typing import List, Optional, Union, Dict from jupyter_client import AsyncKernelClient from jupyter_client.channels import ZMQSocketChannel -from zmq import NOBLOCK, Event, Socket, SocketOption, pyzmq_version, zmq_version +import zmq + from zmq.asyncio import Context from zmq.utils.monitor import recv_monitor_message @@ -18,21 +19,21 @@ def __init__( connection_info: dict, *, max_message_size: int = None, - sleep_between_polls: float = 0.005, ): super().__init__() self.connection_info = connection_info - self._monitor_sockets_for_topic: dict[str, Socket] = {} self.max_message_size = max_message_size - self.sleep_between_polls = sleep_between_polls + self.channel_tasks: Dict[str, List[asyncio.Task]] = collections.defaultdict(list) async def initialize( self, *, queue_size=0, inbound_workers=1, outbound_workers=1, poll_workers=1 ): - logger.debug(f"Initializing Jupyter Kernel Manager: {zmq_version()=}, {pyzmq_version()=}") + logger.debug( + f"Initializing Jupyter Kernel Manager: {zmq.zmq_version()=}, {zmq.pyzmq_version()=}" + ) self._context = Context() if self.max_message_size: - self.set_context_option(SocketOption.MAXMSGSIZE, self.max_message_size) + self.set_context_option(zmq.SocketOption.MAXMSGSIZE, self.max_message_size) self._client = AsyncKernelClient(context=self._context) self._client.load_connection_info(self.connection_info) @@ -45,6 +46,9 @@ async def initialize( ) async def shutdown(self, now=False): + for topic_name, task_list in self.channel_tasks.items(): + for task in task_list: + task.cancel() await super().shutdown(now) # https://github.com/zeromq/pyzmq/issues/1003 self._context.destroy(linger=0) @@ -52,24 +56,73 @@ async def shutdown(self, now=False): def set_context_option(self, option: int, val: Union[int, bytes]): self._context.setsockopt(option, val) - async def _create_topic_subscription(self, topic_name: str): - if hasattr(self._client, f"{topic_name}_channel"): - channel_obj = getattr(self._client, f"{topic_name}_channel") - channel_obj.start() + async def watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel): + while True: + msg: dict = await channel_obj.get_msg() + print(msg) + self.schedule_for_delivery(topic_name, msg) - monitor_socket = channel_obj.socket.get_monitor_socket() - self._monitor_sockets_for_topic[topic_name] = monitor_socket + async def watch_for_disconnect(self, monitor_socket: zmq.Socket): + while True: + msg: dict = await recv_monitor_message(monitor_socket) + event: zmq.Event = msg["event"] + if event == zmq.EVENT_DISCONNECTED: + return - async def _cleanup_topic_subscription(self, topic_name: str): - if hasattr(self._client, f"{topic_name}_channel"): - channel_obj = getattr(self._client, f"{topic_name}_channel") - channel_obj.socket.disable_monitor() + async def watch_channel(self, topic_name: str): + channel_name = f"{topic_name}_channel" + while True: + # The channel properties (e.g. self._client.iopub_channel) will connect the socket + # if self._client._iopub_channel is None. + channel_obj: ZMQSocketChannel = getattr(self._client, channel_name) + monitor_socket = channel_obj.socket.get_monitor_socket() + monitor_task = asyncio.create_task(self.watch_for_disconnect(monitor_socket)) + message_task = asyncio.create_task( + self.watch_for_channel_messages(topic_name, channel_obj) + ) + + # add tasks to self.channel_tasks so we can cleanup during topic unsubscribe / shutdown + self.channel_tasks[topic_name].append(monitor_task) + self.channel_tasks[topic_name].append(message_task) + + # Run the monitor and message tasks. Message task should run forever. + # If the monitor task returns then it means the socket was disconnected + # (max message size) and we need to cycle it. + done, pending = await asyncio.wait( + [monitor_task, message_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + for task in done: + if task.exception(): + raise task.exception() + + self.channel_tasks[topic_name].remove(monitor_task) + self.channel_tasks[topic_name].remove(message_task) + logger.info(f"Cycling topic {topic_name} after disconnect") + print(f"Cycling topic {topic_name} after disconnect") + self._emit_system_event(topic_name, SystemEvents.FORCED_DISCONNECT) channel_obj.close() + setattr(self._client, f"_{channel_name}", None) - # Reset the underlying channel object so jupyter_client will recreate it - # if we subscribe to this again. + async def _create_topic_subscription(self, topic_name: str): + task = asyncio.create_task(self.watch_channel(topic_name)) + self.channel_tasks[topic_name].append(task) + + async def _cleanup_topic_subscription(self, topic_name: str): + print(f"Cleaning up topic {topic_name}") + if topic_name in self.channel_tasks: + for task in self.channel_tasks[topic_name]: + task.cancel() + self.channel_tasks[topic_name].remove(task) + await asyncio.sleep(0) + # Reset the channel object on our jupyter_client setattr(self._client, f"_{topic_name}_channel", None) - del self._monitor_sockets_for_topic[topic_name] + else: + logger.warning( + f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict" + ) def send( self, @@ -91,53 +144,8 @@ async def _publish(self, message: QueuedMessage): channel_obj = getattr(self._client, f"{topic_name}_channel") channel_obj.send(message.contents) - def _cycle_socket(self, topic): - channel_obj = getattr(self._client, f"{topic}_channel") - channel_obj.socket.disable_monitor() - channel_obj.close() - connect_fn = getattr(self._client, f"connect_{topic}") - channel_obj.socket = connect_fn() - monitor_socket = channel_obj.socket.get_monitor_socket() - self._monitor_sockets_for_topic[topic] = monitor_socket - async def _poll(self): - for topic_name in self.subscribed_topics: - channel_obj: ZMQSocketChannel = getattr(self._client, f"{topic_name}_channel") - - while True: - try: - msg = await channel_obj.get_msg(timeout=0) - self.schedule_for_delivery(topic_name, msg) - except Empty: - break - - topics_to_cycle = [] - for topic, socket in self._monitor_sockets_for_topic.items(): - while await socket.poll(0): - msg = await recv_monitor_message(socket, flags=NOBLOCK) - logger.debug(f"ZMQ event: {topic=} {msg['event']=} {msg['value']=}") - if msg["event"] == Event.DISCONNECTED: - self._emit_system_event(topic, SystemEvents.FORCED_DISCONNECT) - topics_to_cycle.append(topic) - - for topic in topics_to_cycle: - # If the ZMQ socket is disconnected, try cycling it - # This is helpful in situations where ZMQ disconnects peers - # when it violates some constraint such as the max message size. - logger.info(f"ZMQ disconnected for topic '{topic}', cycling socket") - self._cycle_socket(topic) + pass async def _poll_loop(self): - """ - Override base Manager _poll_loop to switch the final asyncio.sleep from 0 to - something more than that (definable at init or after instantiation, default 0.005). - While observing JupyterManager in real world, containers are using 100% CPU. - Possibly due to this loop being too aggressive? - """ - while True: - try: - await self._poll() - except Exception: - logger.exception("Uncaught exception encountered while polling backend") - finally: - await asyncio.sleep(self.sleep_between_polls) + pass diff --git a/tests/test_jupyter_backend.py b/tests/test_jupyter_backend.py index 2e180c4..2026787 100644 --- a/tests/test_jupyter_backend.py +++ b/tests/test_jupyter_backend.py @@ -115,7 +115,6 @@ async def test_jupyter_backend(self, mocker, ipykernel): iopub_cb.assert_not_called() shell_cb.assert_called_once() - @pytest.mark.xfail(reason="Cycling sockets is buggy in the current implementation") async def test_reconnection(self, mocker, ipykernel): """ Test that if a message over the zmq channel is too large, we won't receive it @@ -155,7 +154,6 @@ async def test_reconnection(self, mocker, ipykernel): except asyncio.TimeoutError: await mgr.shutdown() raise Exception("Did not see the expected messages after cycling the iopub channel") - disconnect_event.assert_called() # Prove that after cycling the socket, normal executions work the same as always From 3377bfc66558c2c1bb6c3c0d6a84b5ed93313f45 Mon Sep 17 00:00:00 2001 From: Kafonek Date: Fri, 30 Sep 2022 13:49:23 -0400 Subject: [PATCH 2/7] comments and docstrings --- sending/backends/jupyter.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/sending/backends/jupyter.py b/sending/backends/jupyter.py index 7f9ba76..b7da301 100644 --- a/sending/backends/jupyter.py +++ b/sending/backends/jupyter.py @@ -57,12 +57,19 @@ def set_context_option(self, option: int, val: Union[int, bytes]): self._context.setsockopt(option, val) async def watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel): + """ + Read in any messages on a specific jupyter_client channel and drop them into the inbound + worker queue which will trigger registered callback functions by predicate / topic + """ while True: msg: dict = await channel_obj.get_msg() - print(msg) self.schedule_for_delivery(topic_name, msg) async def watch_for_disconnect(self, monitor_socket: zmq.Socket): + """ + An awaitable task that ends when a particular socket has a disconnect event. Used in + conjunction with watch_for_channel_messages to cycle a socket when it's disconnected. + """ while True: msg: dict = await recv_monitor_message(monitor_socket) event: zmq.Event = msg["event"] @@ -70,7 +77,20 @@ async def watch_for_disconnect(self, monitor_socket: zmq.Socket): return async def watch_channel(self, topic_name: str): + """ + When a user subscribes to a topic (mgr.subscribe_to_topic('iopub')), this function starts + two tasks: + 1. Pull messages off the zmq channel and trigger any registered callbacks + 2. Watch the monitor socket for disconnect events and reconnect / restart tasks + """ channel_name = f"{topic_name}_channel" + + # context_hook (primarily for adding structlog contextvars) is normally called in base + # _poll_loop, so that it's applied to every read the poll worker does. For the Jupyter + # implementation, we don't use _poll_loop, all reads from zmq start with tasks here, + # and any contextvars set in this method will be picked up by tasks created here. + if self.context_hook: + await self.context_hook() while True: # The channel properties (e.g. self._client.iopub_channel) will connect the socket # if self._client._iopub_channel is None. @@ -101,7 +121,6 @@ async def watch_channel(self, topic_name: str): self.channel_tasks[topic_name].remove(monitor_task) self.channel_tasks[topic_name].remove(message_task) logger.info(f"Cycling topic {topic_name} after disconnect") - print(f"Cycling topic {topic_name} after disconnect") self._emit_system_event(topic_name, SystemEvents.FORCED_DISCONNECT) channel_obj.close() setattr(self._client, f"_{channel_name}", None) @@ -111,7 +130,6 @@ async def _create_topic_subscription(self, topic_name: str): self.channel_tasks[topic_name].append(task) async def _cleanup_topic_subscription(self, topic_name: str): - print(f"Cleaning up topic {topic_name}") if topic_name in self.channel_tasks: for task in self.channel_tasks[topic_name]: task.cancel() @@ -144,6 +162,10 @@ async def _publish(self, message: QueuedMessage): channel_obj = getattr(self._client, f"{topic_name}_channel") channel_obj.send(message.contents) + # _poll and _poll_loop are designed to be used to define how a Sending backend + # will read incoming data over the wire (socket, websocket, etc). In this implementation + # when we subscribe to a topic, it starts a watch_channel task which handles reading + # data over the right jupyter_client / zmq channel. So _poll and _poll_loop aren't used. async def _poll(self): pass From 3e6ce75d8049ab2a095bb03f387c327b1cd62aba Mon Sep 17 00:00:00 2001 From: Kafonek Date: Sat, 1 Oct 2022 10:06:42 -0400 Subject: [PATCH 3/7] comments, docstrings, reordering code to be more intuitive to new readers --- sending/backends/jupyter.py | 184 ++++++++++++++++++++++-------------- 1 file changed, 114 insertions(+), 70 deletions(-) diff --git a/sending/backends/jupyter.py b/sending/backends/jupyter.py index b7da301..0f10839 100644 --- a/sending/backends/jupyter.py +++ b/sending/backends/jupyter.py @@ -4,6 +4,7 @@ from jupyter_client import AsyncKernelClient from jupyter_client.channels import ZMQSocketChannel +import jupyter_client.session import zmq from zmq.asyncio import Context @@ -22,7 +23,12 @@ def __init__( ): super().__init__() self.connection_info = connection_info + # If max_message_size is set, we'll disconnect (and reconnect immediately) to zmq + # channels that try to send a message greater than that size. It prevents applications + # from OOM crashing reading in large outputs or other messages self.max_message_size = max_message_size + # Tasks that ultiamtely watch the zmq channels for messages. Keep track of these + # for cleanup (unsubscribe_from_topic, shutdown) self.channel_tasks: Dict[str, List[asyncio.Task]] = collections.defaultdict(list) async def initialize( @@ -46,6 +52,8 @@ async def initialize( ) async def shutdown(self, now=False): + # Cancelling channel watching tasks here is equivalent to shutting down poll worker + # in other Sending backend implementations. for topic_name, task_list in self.channel_tasks.items(): for task in task_list: task.cancel() @@ -56,32 +64,84 @@ async def shutdown(self, now=False): def set_context_option(self, option: int, val: Union[int, bytes]): self._context.setsockopt(option, val) - async def watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel): + def send( + self, + topic_name: str, + msg_type: str, + content: Optional[dict], + parent: Optional[dict] = None, + header: Optional[dict] = None, + metadata: Optional[dict] = None, + ): """ - Read in any messages on a specific jupyter_client channel and drop them into the inbound - worker queue which will trigger registered callback functions by predicate / topic + Put a message onto the outbound queue which will be picked up by the outbound + worker and sent over zmq to the Kernel. Most messages will get sent over the shell + channel, although some may go over control as wel. + + Example: + mgr.send("shell", "execute_request", {"code": "print('hello')", "silent": False}) """ - while True: - msg: dict = await channel_obj.get_msg() - self.schedule_for_delivery(topic_name, msg) + # format the message into a Jupyter specced dictionary then drop into outbound queue + # to get sent over the wire when outbound worker calls ._publish + jupyter_session: jupyter_client.session.Session = self._client.session + jupyter_msg: dict = jupyter_session.msg(msg_type, content, parent, header, metadata) + self.outbound_queue.put_nowait(QueuedMessage(topic_name, jupyter_msg, None)) - async def watch_for_disconnect(self, monitor_socket: zmq.Socket): + async def _publish(self, message: QueuedMessage): """ - An awaitable task that ends when a particular socket has a disconnect event. Used in - conjunction with watch_for_channel_messages to cycle a socket when it's disconnected. + When the outbound worker observes a message on the outbound queue, it will call this + method to actually send the message over the wire. """ - while True: - msg: dict = await recv_monitor_message(monitor_socket) - event: zmq.Event = msg["event"] - if event == zmq.EVENT_DISCONNECTED: - return + topic_name = message.topic + if topic_name not in self.subscribed_topics: + await self._create_topic_subscription(topic_name) + if hasattr(self._client, f"{topic_name}_channel"): + channel_obj: ZMQSocketChannel = getattr(self._client, f"{topic_name}_channel") + channel_obj.send(message.contents) + + # Normally in Sending backends there is the concept of a poll_worker which calls into _poll + # as part of a custom _poll_loop implementation. The poll_worker is what reads data over the + # wire (redis, socket, websocket, etc. zmq in the case of Jupyter/ipykernel). However the way + # this backend is written, reading data from zmq after subscribe_to_topic is called is handled + # by _watch_channel task (and its child tasks). poll_worker and these _poll methods do nothing. + async def _poll(self): + pass + + async def _poll_loop(self): + pass - async def watch_channel(self, topic_name: str): + async def _create_topic_subscription(self, topic_name: str): + """ + Start observing messages on a zmq channel after a call to mgr.subscribe_to_topic('iopub') + """ + task = asyncio.create_task(self._watch_channel(topic_name)) + self.channel_tasks[topic_name].append(task) + + async def _cleanup_topic_subscription(self, topic_name: str): + """ + Clean up channel observing tasks after a call to mgr.unsubscribe_from_topic('iopub') + """ + if topic_name in self.channel_tasks: + for task in self.channel_tasks[topic_name]: + task.cancel() + self.channel_tasks[topic_name].remove(task) + await asyncio.sleep(0) + # Reset the channel object on our jupyter_client + setattr(self._client, f"_{topic_name}_channel", None) + else: + logger.warning( + f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict" + ) + + async def _watch_channel(self, topic_name: str): """ When a user subscribes to a topic (mgr.subscribe_to_topic('iopub')), this function starts - two tasks: + two child tasks: 1. Pull messages off the zmq channel and trigger any registered callbacks 2. Watch the monitor socket for disconnect events and reconnect / restart tasks + + If a disconnect is observed, the two tasks are both cancelled and restarted. + Unsubscribing from a topic cancels this task and the child tasks. """ channel_name = f"{topic_name}_channel" @@ -93,21 +153,25 @@ async def watch_channel(self, topic_name: str): await self.context_hook() while True: # The channel properties (e.g. self._client.iopub_channel) will connect the socket - # if self._client._iopub_channel is None. + # if self._client._iopub_channel is None. Channel objects have a monitor object + # to observe lifecycle of the socket such as handshake / disconnect channel_obj: ZMQSocketChannel = getattr(self._client, channel_name) - monitor_socket = channel_obj.socket.get_monitor_socket() - monitor_task = asyncio.create_task(self.watch_for_disconnect(monitor_socket)) message_task = asyncio.create_task( - self.watch_for_channel_messages(topic_name, channel_obj) + self._watch_for_channel_messages(topic_name, channel_obj) ) - # add tasks to self.channel_tasks so we can cleanup during topic unsubscribe / shutdown + monitor_socket = channel_obj.socket.get_monitor_socket() + monitor_task = asyncio.create_task(self._watch_for_disconnect(monitor_socket)) + + # If the _watch_channel task gets cancelled from a .unsubscribe_from_topic call, + # the two child tasks won't automatically be cancelled. Store these up at the class + # level so that _cleanup_topic_subscription can cancel them. self.channel_tasks[topic_name].append(monitor_task) self.channel_tasks[topic_name].append(message_task) - # Run the monitor and message tasks. Message task should run forever. - # If the monitor task returns then it means the socket was disconnected - # (max message size) and we need to cycle it. + # Await the monitor and message tasks. Message task should run forever. + # If the monitor task returns then it means the socket was disconnected, + # presumably from receiving a message larger than max message size. done, pending = await asyncio.wait( [monitor_task, message_task], return_when=asyncio.FIRST_COMPLETED, @@ -118,56 +182,36 @@ async def watch_channel(self, topic_name: str): if task.exception(): raise task.exception() + logger.info(f"Cycling topic {topic_name} after disconnect") self.channel_tasks[topic_name].remove(monitor_task) self.channel_tasks[topic_name].remove(message_task) - logger.info(f"Cycling topic {topic_name} after disconnect") + + # Emit an event so that callbacks registered to pickup the disconnect can do things like + # send user-facing messages that an output stream was too big and won't be displayed self._emit_system_event(topic_name, SystemEvents.FORCED_DISCONNECT) channel_obj.close() - setattr(self._client, f"_{channel_name}", None) - async def _create_topic_subscription(self, topic_name: str): - task = asyncio.create_task(self.watch_channel(topic_name)) - self.channel_tasks[topic_name].append(task) - - async def _cleanup_topic_subscription(self, topic_name: str): - if topic_name in self.channel_tasks: - for task in self.channel_tasks[topic_name]: - task.cancel() - self.channel_tasks[topic_name].remove(task) - await asyncio.sleep(0) - # Reset the channel object on our jupyter_client - setattr(self._client, f"_{topic_name}_channel", None) - else: - logger.warning( - f"Got a call to cleanup topic {topic_name} but it wasn't in the channel_tasks dict" - ) - - def send( - self, - topic_name: str, - msg_type: str, - content: Optional[dict], - parent: Optional[dict] = None, - header: Optional[dict] = None, - metadata: Optional[dict] = None, - ): - msg = self._client.session.msg(msg_type, content, parent, header, metadata) - self.outbound_queue.put_nowait(QueuedMessage(topic_name, msg, None)) - - async def _publish(self, message: QueuedMessage): - topic_name = message.topic - if topic_name not in self.subscribed_topics: - await self._create_topic_subscription(topic_name) - if hasattr(self._client, f"{topic_name}_channel"): - channel_obj = getattr(self._client, f"{topic_name}_channel") - channel_obj.send(message.contents) + # Setting jupyter_client._iopub_channel to None will cause the next reference to + # the jupyter_client.iopub_channel @property to reconnect the socket. + # (see top of this while loop!) + setattr(self._client, f"_{channel_name}", None) - # _poll and _poll_loop are designed to be used to define how a Sending backend - # will read incoming data over the wire (socket, websocket, etc). In this implementation - # when we subscribe to a topic, it starts a watch_channel task which handles reading - # data over the right jupyter_client / zmq channel. So _poll and _poll_loop aren't used. - async def _poll(self): - pass + async def _watch_for_channel_messages(self, topic_name: str, channel_obj: ZMQSocketChannel): + """ + Read in any messages on a specific jupyter_client channel and drop them into the inbound + worker queue which will trigger registered callback functions by predicate / topic + """ + while True: + msg: dict = await channel_obj.get_msg() + self.schedule_for_delivery(topic_name, msg) - async def _poll_loop(self): - pass + async def _watch_for_disconnect(self, monitor_socket: zmq.Socket): + """ + An awaitable task that ends when a particular socket has a disconnect event. Used in + conjunction with watch_for_channel_messages to cycle a socket when it's disconnected. + """ + while True: + msg: dict = await recv_monitor_message(monitor_socket) + event: zmq.Event = msg["event"] + if event == zmq.EVENT_DISCONNECTED: + return From b07d39f52a4c677e164f8db77640cde1153aba5f Mon Sep 17 00:00:00 2001 From: Kafonek Date: Mon, 3 Oct 2022 09:54:50 -0400 Subject: [PATCH 4/7] isort --- sending/backends/jupyter.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sending/backends/jupyter.py b/sending/backends/jupyter.py index 0f10839..b54006f 100644 --- a/sending/backends/jupyter.py +++ b/sending/backends/jupyter.py @@ -1,12 +1,11 @@ import asyncio import collections -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union -from jupyter_client import AsyncKernelClient -from jupyter_client.channels import ZMQSocketChannel import jupyter_client.session import zmq - +from jupyter_client import AsyncKernelClient +from jupyter_client.channels import ZMQSocketChannel from zmq.asyncio import Context from zmq.utils.monitor import recv_monitor_message From 7498213d443760fc15805974ec3516059a7b2e9b Mon Sep 17 00:00:00 2001 From: Kafonek Date: Mon, 3 Oct 2022 11:14:05 -0400 Subject: [PATCH 5/7] better error message in test failure --- tests/test_jupyter_backend.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_jupyter_backend.py b/tests/test_jupyter_backend.py index 2026787..1f18556 100644 --- a/tests/test_jupyter_backend.py +++ b/tests/test_jupyter_backend.py @@ -39,12 +39,18 @@ async def run_until_seen(self, msg_types: List[str], timeout: float): to see 'status', 'execute_reply', and 'status' again. """ deadline = time.time() + timeout - msg_types = msg_types[:] - while msg_types: + to_observe = msg_types[:] + while to_observe: max_wait = deadline - time.time() - await asyncio.wait_for(self.next_event.wait(), timeout=max_wait) - if self.last_seen_event["msg_type"] in msg_types: - msg_types.remove(self.last_seen_event["msg_type"]) + try: + await asyncio.wait_for(self.next_event.wait(), timeout=max_wait) + except asyncio.TimeoutError: + raise Exception( + f"Did not see the expected messages in time.\nTimeout: {timeout}\nExpected messages: {msg_types}\nUnobserved: {to_observe}" # noqa: E501 + ) + + if self.last_seen_event["msg_type"] in to_observe: + to_observe.remove(self.last_seen_event["msg_type"]) self.next_event.clear() From 6715e50108ce047872cf37fd7c0017535ecf61f0 Mon Sep 17 00:00:00 2001 From: Kafonek Date: Mon, 3 Oct 2022 11:45:15 -0400 Subject: [PATCH 6/7] better error message in test failure --- tests/test_jupyter_backend.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/test_jupyter_backend.py b/tests/test_jupyter_backend.py index 1f18556..7ec88aa 100644 --- a/tests/test_jupyter_backend.py +++ b/tests/test_jupyter_backend.py @@ -148,18 +148,14 @@ async def test_reconnection(self, mocker, ipykernel): # status going to busy, execute_input, then a disconnect event where we would normally # see a stream. The iopub channel should cycle, and hopefully catch the status going # idle. We'll also see execute_reply on shell channel. + # (removed execute_input from expected list because ci/cd seems to miss it often?) mgr.send( "shell", "execute_request", {"code": "print('x' * 2**13)", "silent": False}, ) - try: - await monitor.run_until_seen( - msg_types=["status", "execute_input", "execute_reply", "status"], timeout=3 - ) - except asyncio.TimeoutError: - await mgr.shutdown() - raise Exception("Did not see the expected messages after cycling the iopub channel") + await monitor.run_until_seen(msg_types=["status", "execute_reply", "status"], timeout=3) + disconnect_event.assert_called() # Prove that after cycling the socket, normal executions work the same as always From f9b95d74d84a521422f46adb421d2ce2ff89449e Mon Sep 17 00:00:00 2001 From: Kafonek Date: Mon, 3 Oct 2022 11:52:12 -0400 Subject: [PATCH 7/7] fighting ci/cd --- tests/test_jupyter_backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_jupyter_backend.py b/tests/test_jupyter_backend.py index 7ec88aa..059f109 100644 --- a/tests/test_jupyter_backend.py +++ b/tests/test_jupyter_backend.py @@ -148,13 +148,14 @@ async def test_reconnection(self, mocker, ipykernel): # status going to busy, execute_input, then a disconnect event where we would normally # see a stream. The iopub channel should cycle, and hopefully catch the status going # idle. We'll also see execute_reply on shell channel. - # (removed execute_input from expected list because ci/cd seems to miss it often?) + # (removed one status and execute_input from expected list because ci/cd seems to miss + # it often. Not sure why, runs fine locally) mgr.send( "shell", "execute_request", {"code": "print('x' * 2**13)", "silent": False}, ) - await monitor.run_until_seen(msg_types=["status", "execute_reply", "status"], timeout=3) + await monitor.run_until_seen(msg_types=["execute_reply", "status"], timeout=3) disconnect_event.assert_called()