From 14f0c8481555f8603ea4bb27a20bf813e70265c1 Mon Sep 17 00:00:00 2001 From: Sylvain Corlay Date: Sat, 12 Dec 2020 10:52:15 +0100 Subject: [PATCH] Nudge kernel with info request until we receive IOPub messages --- jupyter_server/services/kernels/handlers.py | 147 ++++++++++++++++++-- 1 file changed, 136 insertions(+), 11 deletions(-) diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index e581c30471..0b2d102b94 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -10,7 +10,7 @@ import logging from textwrap import dedent -from tornado import web +from tornado import web, gen from tornado.concurrent import Future from tornado.ioloop import IOLoop @@ -122,11 +122,122 @@ def __repr__(self): def create_stream(self): km = self.kernel_manager identity = self.session.bsession - for channel in ('shell', 'control', 'iopub', 'stdin'): + for channel in ('iopub', 'shell', 'control', 'stdin'): meth = getattr(km, 'connect_' + channel) self.channels[channel] = stream = meth(self.kernel_id, identity=identity) stream.channel = channel + def nudge(self): + """Nudge the zmq connections with kernel_info_requests + Returns a Future that will resolve when we have received + a shell reply and at least one iopub message, + ensuring that zmq subscriptions are established, + sockets are fully connected, and kernel is responsive. + Keeps retrying kernel_info_request until these are both received. + """ + kernel = self.kernel_manager.get_kernel(self.kernel_id) + + # Do not nudge busy kernels as kernel info requests sent to shell are + # queued behind execution requests. + # nudging in this case would cause a potentially very long wait + # before connections are opened, + # plus it is *very* unlikely that a busy kernel will not finish + # establishing its zmq subscriptions before processing the next request. + if getattr(kernel, "execution_state") == "busy": + self.log.debug("Nudge: not nudging busy kernel %s", self.kernel_id) + f = Future() + f.set_result(None) + return f + + # Use a transient shell channel to prevent leaking + # shell responses to the front-end. + shell_channel = kernel.connect_shell() + # The IOPub used by the client, whose subscriptions we are verifying. + iopub_channel = self.channels["iopub"] + + info_future = Future() + iopub_future = Future() + both_done = gen.multi([info_future, iopub_future]) + + def finish(f=None): + """Ensure all futures are resolved + which in turn triggers cleanup + """ + for f in (info_future, iopub_future): + if not f.done(): + f.set_result(None) + + def cleanup(f=None): + """Common cleanup""" + loop.remove_timeout(nudge_handle) + iopub_channel.stop_on_recv() + if not shell_channel.closed(): + shell_channel.close() + + # trigger cleanup when both message futures are resolved + both_done.add_done_callback(cleanup) + + def on_shell_reply(msg): + self.log.debug("Nudge: shell info reply received: %s", self.kernel_id) + if not info_future.done(): + self.log.debug("Nudge: resolving shell future: %s", self.kernel_id) + info_future.set_result(None) + + def on_iopub(msg): + self.log.debug("Nudge: IOPub received: %s", self.kernel_id) + if not iopub_future.done(): + iopub_channel.stop_on_recv() + self.log.debug("Nudge: resolving iopub future: %s", self.kernel_id) + iopub_future.set_result(None) + + iopub_channel.on_recv(on_iopub) + shell_channel.on_recv(on_shell_reply) + loop = IOLoop.current() + + # Nudge the kernel with kernel info requests until we get an IOPub message + def nudge(count): + count += 1 + + # NOTE: this close check appears to never be True during on_open, + # even when the peer has closed the connection + if self.ws_connection is None or self.ws_connection.is_closing(): + self.log.debug( + "Nudge: cancelling on closed websocket: %s", self.kernel_id + ) + finish() + return + + # check for stopped kernel + if self.kernel_id not in self.kernel_manager: + self.log.debug( + "Nudge: cancelling on stopped kernel: %s", self.kernel_id + ) + finish() + return + + # check for closed zmq socket + if shell_channel.closed(): + self.log.debug( + "Nudge: cancelling on closed zmq socket: %s", self.kernel_id + ) + finish() + return + + if not both_done.done(): + log = self.log.warning if count % 10 == 0 else self.log.debug + log("Nudge: attempt %s on kernel %s" % (count, self.kernel_id)) + self.session.send(shell_channel, "kernel_info_request") + nonlocal nudge_handle + nudge_handle = loop.call_later(0.5, nudge, count) + + nudge_handle = loop.call_later(0, nudge, count=0) + + # resolve with a timeout if we get no response + future = gen.with_timeout(loop.time() + self.kernel_info_timeout, both_done) + # ensure we have no dangling resources or unresolved Futures in case of timeout + future.add_done_callback(finish) + return future + def request_kernel_info(self): """send a request for kernel_info""" km = self.kernel_manager @@ -249,7 +360,7 @@ async def _register_session(self): await stale_handler.close() self._open_sessions[self.session_key] = self - def open(self, kernel_id): + async def open(self, kernel_id): super(ZMQChannelsHandler, self).open() km = self.kernel_manager km.notify_connect(kernel_id) @@ -259,15 +370,23 @@ def open(self, kernel_id): if buffer_info and buffer_info['session_key'] == self.session_key: self.log.info("Restoring connection for %s", self.session_key) self.channels = buffer_info['channels'] - replay_buffer = buffer_info['buffer'] - if replay_buffer: - self.log.info("Replaying %s buffered messages", len(replay_buffer)) - for channel, msg_list in replay_buffer: - stream = self.channels[channel] - self._on_zmq_reply(stream, msg_list) + + connected = self.nudge() + + def replay(value): + replay_buffer = buffer_info['buffer'] + if replay_buffer: + self.log.info("Replaying %s buffered messages", len(replay_buffer)) + for channel, msg_list in replay_buffer: + stream = self.channels[channel] + self._on_zmq_reply(stream, msg_list) + + + connected.add_done_callback(replay) else: try: self.create_stream() + connected = self.nudge() except web.HTTPError as e: self.log.error("Error opening stream: %s", e) # WebSockets don't response to traditional error codes so we @@ -281,8 +400,14 @@ def open(self, kernel_id): km.add_restart_callback(self.kernel_id, self.on_kernel_restarted) km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead') - for channel, stream in self.channels.items(): - stream.on_recv_stream(self._on_zmq_reply) + def subscribe(value): + for channel, stream in self.channels.items(): + stream.on_recv_stream(self._on_zmq_reply) + + connected.add_done_callback(subscribe) + + return connected + def on_message(self, msg): if not self.channels: