From f2fa6353f0077741b9b2846b9211eba2b59f32b5 Mon Sep 17 00:00:00 2001 From: Lionel Henry Date: Thu, 10 Oct 2024 11:47:12 +0200 Subject: [PATCH] Implement JEP 65 (#577) Closes https://github.com/posit-dev/ark/issues/569 This PR fixes a race condition regarding subscriptions to IOPub that causes clients to miss IOPub messages: - On startup a client connects to the server sockets of a kernel. - The client sends a request on Shell. - The kernel starts processing the request and emits busy on IOPub. If the client hasn't been able to fully subscribe to IOPub, messages can be lost, in particular the Busy message that encloses the request output. On the Positron side we fixed it by sending kernel-info requests in a loop until we get a Ready message on IOPub. This signals Positron that the kernel is fully connected and in the Ready state: https://github.com/posit-dev/positron/pull/2207. We haven't implemented a similar fix in our dummy clients for integration tests and we believe this is what is causing the race condition described in #569. As noted in https://github.com/posit-dev/positron/pull/2207, there is an accepted JEP proposal (JEP 65) that aims at solving this problem by switching to XPUB. https://jupyter.org/enhancement-proposals/65-jupyter-xpub/jupyter-xpub.html https://github.com/jupyter/enhancement-proposals/pull/65 The XPUB socket allows the server to get notified of all new subscriptions. A message of type `iopub_welcome` is sent to all connected clients. They should generally ignore it but clients that have just started up can use it as a cue that IOPub is correctly connected and that they won't miss any output from that point on. Approach: The subscription notification comes in as a message on the IOPub socket. This is problematic because the IOPub thread now needs to listen to its crossbeam channel and to the 0MQ socket at the same time, which isn't possible without resorting to timeout polling. 
So we use the same approach and infrastructure that we implemented in https://github.com/posit-dev/ark/pull/58 for listening to both input replies on the StdIn socket and interrupt notifications on a crossbeam channel. The forwarding thread now owns the IOPub socket and listens for subscription notifications and forwards IOPub messages coming from the kernel components. --- * Start moving IOPub messages to forwarding thread * Remove unused import * Resolve the roundabout `Message` problem The solution was to move the conversion to `JupyterMessage` up into the match, so we "know" what `T` is! * Use correct `Welcome` `MessageType` * Implement `SubscriptionMessage` support and switch to `XPUB` * The `Welcome` message doesn't come from ark * Use `amalthea::Result` * Add more comments --------- Co-authored-by: Davis Vaughan Co-authored-by: Lionel Henry --- crates/amalthea/src/comm/comm_manager.rs | 5 +- .../amalthea/src/fixtures/dummy_frontend.rs | 27 ++- crates/amalthea/src/kernel.rs | 90 +++++-- crates/amalthea/src/socket/iopub.rs | 225 ++++++++++++------ crates/amalthea/src/socket/socket.rs | 103 ++++---- crates/amalthea/src/wire/jupyter_message.rs | 58 +++-- crates/amalthea/src/wire/mod.rs | 2 + .../amalthea/src/wire/subscription_message.rs | 81 +++++++ crates/amalthea/src/wire/welcome.rs | 35 +++ 9 files changed, 466 insertions(+), 160 deletions(-) create mode 100644 crates/amalthea/src/wire/subscription_message.rs create mode 100644 crates/amalthea/src/wire/welcome.rs diff --git a/crates/amalthea/src/comm/comm_manager.rs b/crates/amalthea/src/comm/comm_manager.rs index 4a08ffb59..089f45f28 100644 --- a/crates/amalthea/src/comm/comm_manager.rs +++ b/crates/amalthea/src/comm/comm_manager.rs @@ -23,6 +23,7 @@ use crate::comm::event::CommManagerRequest; use crate::socket::comm::CommInitiator; use crate::socket::comm::CommSocket; use crate::socket::iopub::IOPubMessage; +use crate::wire::comm_close::CommClose; use crate::wire::comm_msg::CommWireMsg; use 
crate::wire::comm_open::CommOpen; use crate::wire::header::JupyterHeader; @@ -245,7 +246,9 @@ impl CommManager { } }, - CommMsg::Close => IOPubMessage::CommClose(comm_socket.comm_id.clone()), + CommMsg::Close => IOPubMessage::CommClose(CommClose { + comm_id: comm_socket.comm_id.clone(), + }), }; // Deliver the message to the frontend diff --git a/crates/amalthea/src/fixtures/dummy_frontend.rs b/crates/amalthea/src/fixtures/dummy_frontend.rs index 98f85eed1..eac700aea 100644 --- a/crates/amalthea/src/fixtures/dummy_frontend.rs +++ b/crates/amalthea/src/fixtures/dummy_frontend.rs @@ -162,10 +162,6 @@ impl DummyFrontend { ) .unwrap(); - // Subscribe to IOPub! Server is the one that sent us this port, - // so its already connected on its end. - iopub_socket.subscribe().unwrap(); - let stdin_socket = Socket::new( connection.session.clone(), connection.ctx.clone(), @@ -186,14 +182,19 @@ impl DummyFrontend { ) .unwrap(); - // TODO!: Without this sleep, `IOPub` `Busy` messages sporadically - // don't arrive when running integration tests. I believe this is a result - // of PUB sockets dropping messages while in a "mute" state (i.e. no subscriber - // connected yet). Even though we run `iopub_socket.subscribe()` to subscribe, - // it seems like we can return from this function even before our socket - // has fully subscribed, causing messages to get dropped. - // https://libzmq.readthedocs.io/en/latest/zmq_socket.html - std::thread::sleep(std::time::Duration::from_millis(500)); + // Subscribe to IOPub! Server's XPUB socket will receive a notification of + // our subscription with `subscription`, then will publish an IOPub `Welcome` + // message, sending back our `subscription`. + iopub_socket.subscribe(b"").unwrap(); + + // Immediately block until we've received the IOPub welcome message. + // This confirms that we've fully subscribed and avoids dropping any + // of the initial IOPub messages that a server may send if we start + // perform requests immediately. 
+ // https://github.com/posit-dev/ark/pull/577 + assert_matches!(Self::recv(&iopub_socket), Message::Welcome(data) => { + assert_eq!(data.content.subscription, String::from("")); + }); Self { _control_socket, @@ -347,7 +348,7 @@ impl DummyFrontend { let msg = self.recv_iopub(); // Assert its type - let piece = assert_matches!(msg, Message::StreamOutput(data) => { + let piece = assert_matches!(msg, Message::Stream(data) => { assert_eq!(data.content.name, stream); data.content.text }); diff --git a/crates/amalthea/src/kernel.rs b/crates/amalthea/src/kernel.rs index 75dbdf385..36a0cc960 100644 --- a/crates/amalthea/src/kernel.rs +++ b/crates/amalthea/src/kernel.rs @@ -41,6 +41,7 @@ use crate::wire::jupyter_message::JupyterMessage; use crate::wire::jupyter_message::Message; use crate::wire::jupyter_message::OutboundMessage; use crate::wire::jupyter_message::Status; +use crate::wire::subscription_message::SubscriptionMessage; macro_rules! report_error { ($($arg:tt)+) => (if cfg!(debug_assertions) { panic!($($arg)+) } else { log::error!($($arg)+) }) @@ -118,20 +119,25 @@ pub fn connect( ) }); - // Create the IOPub PUB/SUB socket and start a thread to broadcast to + // Create the IOPub XPUB/SUB socket and start a thread to broadcast to // the client. IOPub only broadcasts messages, so it listens to other // threads on a Receiver instead of to the client. 
let iopub_socket = Socket::new( session.clone(), ctx.clone(), String::from("IOPub"), - zmq::PUB, + zmq::XPUB, None, connection_file.endpoint(connection_file.iopub_port), )?; let iopub_port = port_finalize(&iopub_socket, connection_file.iopub_port)?; + + let (iopub_inbound_tx, iopub_inbound_rx) = unbounded(); + let iopub_session = iopub_socket.session.clone(); + let iopub_outbound_tx = outbound_tx.clone(); + spawn!(format!("{name}-iopub"), move || { - iopub_thread(iopub_socket, iopub_rx) + iopub_thread(iopub_rx, iopub_inbound_rx, iopub_outbound_tx, iopub_session) }); // Create the heartbeat socket and start a thread to listen for @@ -165,11 +171,12 @@ pub fn connect( let (stdin_inbound_tx, stdin_inbound_rx) = unbounded(); let (stdin_interrupt_tx, stdin_interrupt_rx) = bounded(1); let stdin_session = stdin_socket.session.clone(); + let stdin_outbound_tx = outbound_tx.clone(); spawn!(format!("{name}-stdin"), move || { stdin_thread( stdin_inbound_rx, - outbound_tx, + stdin_outbound_tx, stdin_request_rx, stdin_reply_tx, stdin_interrupt_rx, @@ -224,6 +231,8 @@ pub fn connect( outbound_notif_socket_rx, stdin_socket, stdin_inbound_tx, + iopub_socket, + iopub_inbound_tx, outbound_rx_clone, ) }); @@ -338,8 +347,13 @@ fn shell_thread( } /// Starts the IOPub thread. -fn iopub_thread(socket: Socket, receiver: Receiver) -> Result<(), Error> { - let mut iopub = IOPub::new(socket, receiver); +fn iopub_thread( + rx: Receiver, + inbound_rx: Receiver>, + outbound_tx: Sender, + session: Session, +) -> Result<(), Error> { + let mut iopub = IOPub::new(rx, inbound_rx, outbound_tx, session); iopub.listen(); Ok(()) } @@ -367,10 +381,32 @@ fn stdin_thread( /// Starts the thread that forwards 0MQ messages to Amalthea channels /// and vice versa. +/// +/// This is a solution to the problem of polling/selecting from 0MQ sockets and +/// crossbeam channels at the same time. Message events on crossbeam channels +/// are emitted by the notifier thread (see below) on a 0MQ socket. 
The +/// forwarding thread is then able to listen on 0MQ sockets (e.g. StdIn replies +/// and IOPub subscriptions) and the notification socket at the same time. +/// +/// Part of the problem this setup solves is that 0MQ sockets can only be owned +/// by one thread at a time. Take IOPUb as an example: we need to listen on that +/// socket for subscription events. We also need to listen for new IOPub +/// messages to send to the client, sent via Crossbeam channels. So we need at +/// least two threads listening for these two different kinds of events. But the +/// forwarding thread has to fully own the socket to be able to listen to it. So +/// it's also in charge of sending IOPub messages on that socket. When an IOPub +/// message comes in, the notifier thread wakes up the forwarding thread which +/// then pulls messages from the channel and forwards them to the IOPub socket. +/// +/// Terminology: +/// - Outbound means that a crossbeam message needs to be forwarded to a 0MQ socket. +/// - Inbound means that a 0MQ message needs to be forwarded to a crossbeam channel. fn zmq_forwarding_thread( outbound_notif_socket: Socket, stdin_socket: Socket, stdin_inbound_tx: Sender>, + iopub_socket: Socket, + iopub_inbound_tx: Sender>, outbound_rx: Receiver, ) { // This function checks for notifications that an outgoing message @@ -394,8 +430,8 @@ fn zmq_forwarding_thread( }; // This function checks that a 0MQ message from the frontend is ready. 
- let has_inbound = || -> bool { - match stdin_socket.socket.poll(zmq::POLLIN, 0) { + let has_inbound = |socket: &Socket| -> bool { + match socket.socket.poll(zmq::POLLIN, 0) { Ok(n) if n > 0 => true, _ => false, } @@ -409,6 +445,7 @@ fn zmq_forwarding_thread( let outbound_msg = outbound_rx.recv()?; match outbound_msg { OutboundMessage::StdIn(msg) => msg.send(&stdin_socket)?, + OutboundMessage::IOPub(msg) => msg.send(&iopub_socket)?, }; // Notify back @@ -419,9 +456,19 @@ fn zmq_forwarding_thread( // Forwards 0MQ message from the frontend to the corresponding // Amalthea channel. - let forward_inbound = || -> anyhow::Result<()> { - let msg = Message::read_from_socket(&stdin_socket); - stdin_inbound_tx.send(msg)?; + let forward_inbound = + |socket: &Socket, inbound_tx: &Sender>| -> anyhow::Result<()> { + let msg = Message::read_from_socket(socket); + inbound_tx.send(msg)?; + Ok(()) + }; + + // Forwards special 0MQ XPUB subscription message from the frontend to the IOPub thread. + let forward_inbound_subscription = |socket: &Socket, + inbound_tx: &Sender>| + -> anyhow::Result<()> { + let msg = SubscriptionMessage::read_from_socket(socket); + inbound_tx.send(msg)?; Ok(()) }; @@ -429,7 +476,8 @@ fn zmq_forwarding_thread( let mut poll_items = { let outbound_notif_poll_item = outbound_notif_socket.socket.as_poll_item(zmq::POLLIN); let stdin_poll_item = stdin_socket.socket.as_poll_item(zmq::POLLIN); - vec![outbound_notif_poll_item, stdin_poll_item] + let iopub_poll_item = iopub_socket.socket.as_poll_item(zmq::POLLIN); + vec![outbound_notif_poll_item, stdin_poll_item, iopub_poll_item] }; loop { @@ -450,9 +498,17 @@ fn zmq_forwarding_thread( continue; } - if has_inbound() { + if has_inbound(&stdin_socket) { + unwrap!( + forward_inbound(&stdin_socket, &stdin_inbound_tx), + Err(err) => report_error!("While forwarding inbound message: {}", err) + ); + continue; + } + + if has_inbound(&iopub_socket) { unwrap!( - forward_inbound(), + forward_inbound_subscription(&iopub_socket, 
&iopub_inbound_tx), Err(err) => report_error!("While forwarding inbound message: {}", err) ); continue; @@ -463,8 +519,10 @@ fn zmq_forwarding_thread( } } -/// Starts the thread that notifies the forwarding thread that new -/// outgoing messages have arrived from Amalthea. +/// Starts the thread that notifies the forwarding thread that new outgoing +/// messages have arrived from Amalthea channels. This wakes up the forwarding +/// thread which will then pop the message from the channel and forward them to +/// the relevant zeromq socket. fn zmq_notifier_thread(notif_socket: Socket, outbound_rx: Receiver) { let mut sel = Select::new(); sel.recv(&outbound_rx); diff --git a/crates/amalthea/src/socket/iopub.rs b/crates/amalthea/src/socket/iopub.rs index 337412498..cac630f02 100644 --- a/crates/amalthea/src/socket/iopub.rs +++ b/crates/amalthea/src/socket/iopub.rs @@ -11,11 +11,8 @@ use crossbeam::channel::tick; use crossbeam::channel::Receiver; use crossbeam::channel::Sender; use crossbeam::select; -use log::trace; -use log::warn; -use crate::error::Error; -use crate::socket::socket::Socket; +use crate::session::Session; use crate::wire::comm_close::CommClose; use crate::wire::comm_msg::CommWireMsg; use crate::wire::comm_open::CommOpen; @@ -25,19 +22,32 @@ use crate::wire::execute_input::ExecuteInput; use crate::wire::execute_result::ExecuteResult; use crate::wire::header::JupyterHeader; use crate::wire::jupyter_message::JupyterMessage; +use crate::wire::jupyter_message::Message; +use crate::wire::jupyter_message::OutboundMessage; use crate::wire::jupyter_message::ProtocolMessage; use crate::wire::status::ExecutionState; use crate::wire::status::KernelStatus; use crate::wire::stream::Stream; use crate::wire::stream::StreamOutput; +use crate::wire::subscription_message::SubscriptionKind; +use crate::wire::subscription_message::SubscriptionMessage; use crate::wire::update_display_data::UpdateDisplayData; +use crate::wire::welcome::Welcome; pub struct IOPub { - /// The 
underlying IOPub socket - socket: Socket, - /// A channel that receives IOPub messages from other threads - receiver: Receiver, + rx: Receiver, + + /// A channel that receives IOPub subscriber notifications from + /// the IOPub socket + inbound_rx: Receiver>, + + /// A channel to forward along IOPub messages to the IOPub socket + /// for delivery to the frontend + outbound_tx: Sender, + + /// ZMQ session used to create messages + session: Session, /// The current message context; attached to outgoing messages to pair /// outputs with the message that caused them. @@ -60,7 +70,7 @@ pub enum IOPubContextChannel { Control, } -/// Enumeration of all messages that can be delivered from the IOPub PUB/SUB +/// Enumeration of all messages that can be delivered from the IOPub XPUB/SUB /// socket. These messages generally are created on other threads and then sent /// via a channel to the IOPub thread. pub enum IOPubMessage { @@ -72,7 +82,7 @@ pub enum IOPubMessage { CommOpen(CommOpen), CommMsgReply(JupyterHeader, CommWireMsg), CommMsgEvent(CommWireMsg), - CommClose(String), + CommClose(CommClose), DisplayData(DisplayData), UpdateDisplayData(UpdateDisplayData), Wait(Wait), @@ -87,16 +97,23 @@ pub struct Wait { impl IOPub { /// Create a new IOPub socket wrapper. /// - /// * `socket` - The ZeroMQ socket that will deliver IOPub messages to - /// subscribed clients. - /// * `receiver` - The receiver channel that will receive IOPub + /// * `rx` - The receiver channel that will receive IOPub /// messages from other threads. - pub fn new(socket: Socket, receiver: Receiver) -> Self { + /// * `inbound_rx` - The receiver channel that will receive + /// new subscriber messages forwarded from the IOPub socket. 
+ pub fn new( + rx: Receiver, + inbound_rx: Receiver>, + outbound_tx: Sender, + session: Session, + ) -> Self { let buffer = StreamBuffer::new(Stream::Stdout); Self { - socket, - receiver, + rx, + inbound_rx, + outbound_tx, + session, shell_context: None, control_context: None, buffer, @@ -115,18 +132,30 @@ impl IOPub { loop { select! { - recv(self.receiver) -> message => { + recv(self.rx) -> message => { match message { Ok(message) => { - if let Err(error) = self.process_message(message) { - warn!("Error delivering iopub message: {error:?}") + if let Err(error) = self.process_outbound_message(message) { + log::warn!("Error delivering outbound iopub message: {error:?}") } }, Err(error) => { - warn!("Failed to receive iopub message: {error:?}"); + log::warn!("Failed to receive outbound iopub message: {error:?}"); }, } }, + recv(self.inbound_rx) -> message => { + match message.unwrap() { + Ok(message) => { + if let Err(error) = self.process_inbound_message(message) { + log::warn!("Error processing inbound iopub message: {error:?}") + } + }, + Err(error) => { + log::warn!("Failed to receive inbound iopub message: {error:?}"); + } + } + }, recv(flush_interval) -> message => { match message { Ok(_) => self.flush_stream(), @@ -138,16 +167,16 @@ impl IOPub { } /// Process an IOPub message from another thread. - fn process_message(&mut self, message: IOPubMessage) -> Result<(), Error> { + fn process_outbound_message(&mut self, message: IOPubMessage) -> crate::Result<()> { match message { - IOPubMessage::Status(context, context_channel, msg) => { + IOPubMessage::Status(context, context_channel, content) => { // When we enter the Busy state as a result of a message, we // update the context. Future messages to IOPub name this // context in the parent header sent to the client; this makes // it possible for the client to associate events/output with // their originator without requiring us to thread the values // through the stack. 
- match (&context_channel, &msg.execution_state) { + match (&context_channel, &content.execution_state) { (IOPubContextChannel::Control, ExecutionState::Busy) => { self.control_context = Some(context.clone()); }, @@ -170,74 +199,124 @@ impl IOPub { }, } - self.send_message_with_header(context, msg) + self.forward(Message::Status(self.message_with_header(context, content))) }, - IOPubMessage::ExecuteResult(msg) => { + IOPubMessage::ExecuteResult(content) => { self.flush_stream(); - self.send_message_with_context(msg, IOPubContextChannel::Shell) + self.forward(Message::ExecuteResult( + self.message_with_context(content, IOPubContextChannel::Shell), + )) }, - IOPubMessage::ExecuteError(msg) => { + IOPubMessage::ExecuteError(content) => { self.flush_stream(); - self.send_message_with_context(msg, IOPubContextChannel::Shell) + self.forward(Message::ExecuteError( + self.message_with_context(content, IOPubContextChannel::Shell), + )) + }, + IOPubMessage::ExecuteInput(content) => self.forward(Message::ExecuteInput( + self.message_with_context(content, IOPubContextChannel::Shell), + )), + IOPubMessage::Stream(content) => self.process_stream_message(content), + IOPubMessage::CommOpen(content) => { + self.forward(Message::CommOpen(self.message(content))) + }, + IOPubMessage::CommMsgEvent(content) => { + self.forward(Message::CommMsg(self.message(content))) }, - IOPubMessage::ExecuteInput(msg) => { - self.send_message_with_context(msg, IOPubContextChannel::Shell) + IOPubMessage::CommMsgReply(header, content) => { + self.forward(Message::CommMsg(self.message_with_header(header, content))) }, - IOPubMessage::Stream(msg) => self.process_stream_message(msg), - IOPubMessage::CommOpen(msg) => self.send_message(msg), - IOPubMessage::CommMsgEvent(msg) => self.send_message(msg), - IOPubMessage::CommMsgReply(header, msg) => self.send_message_with_header(header, msg), - IOPubMessage::CommClose(comm_id) => self.send_message(CommClose { comm_id }), - IOPubMessage::DisplayData(msg) => { + 
IOPubMessage::CommClose(content) => { + self.forward(Message::CommClose(self.message(content))) + }, + IOPubMessage::DisplayData(content) => { self.flush_stream(); - self.send_message_with_context(msg, IOPubContextChannel::Shell) + self.forward(Message::DisplayData( + self.message_with_context(content, IOPubContextChannel::Shell), + )) }, - IOPubMessage::UpdateDisplayData(msg) => { + IOPubMessage::UpdateDisplayData(content) => { self.flush_stream(); - self.send_message_with_context(msg, IOPubContextChannel::Shell) + self.forward(Message::UpdateDisplayData( + self.message_with_context(content, IOPubContextChannel::Shell), + )) }, - IOPubMessage::Wait(msg) => self.process_wait_request(msg), + IOPubMessage::Wait(content) => self.process_wait_request(content), } } - /// Send a message using the underlying socket with the given content. + /// As an XPUB socket, the only inbound message that IOPub receives is + /// a subscription message that notifies us when a SUB subscribes or + /// unsubscribes. + /// + /// When we get a subscription notification, we forward along an IOPub + /// `Welcome` message back to the SUB, in compliance with JEP 65. Clients + /// that don't know how to process this `Welcome` message should just ignore it. + fn process_inbound_message(&self, message: SubscriptionMessage) -> crate::Result<()> { + let subscription = message.subscription; + + match message.kind { + SubscriptionKind::Subscribe => { + log::info!( + "Received subscribe message on IOPub with subscription '{subscription}'." + ); + let content = Welcome { subscription }; + self.forward(Message::Welcome(self.message(content))) + }, + SubscriptionKind::Unsubscribe => { + log::info!( + "Received unsubscribe message on IOPub with subscription '{subscription}'." + ); + // We don't do anything on unsubscribes + return Ok(()); + }, + } + } + + /// Create a message using the underlying socket with the given content. /// No parent is assumed. 
- fn send_message(&self, content: T) -> Result<(), Error> { - self.send_message_impl(None, content) + fn message(&self, content: T) -> JupyterMessage { + self.message_create(None, content) } - /// Send a message using the underlying socket with the given content. The + /// Create a message using the underlying socket with the given content. The /// parent message is assumed to be the current context. - fn send_message_with_context( + fn message_with_context( &self, content: T, context_channel: IOPubContextChannel, - ) -> Result<(), Error> { + ) -> JupyterMessage { let context = match context_channel { IOPubContextChannel::Control => &self.control_context, IOPubContextChannel::Shell => &self.shell_context, }; - self.send_message_impl(context.clone(), content) + self.message_create(context.clone(), content) } - /// Send a message using the underlying socket with the given content and + /// Create a message using the underlying socket with the given content and /// specific header. Used when the parent message is known by the message /// sender, typically in comm message replies. 
- fn send_message_with_header( + fn message_with_header( &self, header: JupyterHeader, content: T, - ) -> Result<(), Error> { - self.send_message_impl(Some(header), content) + ) -> JupyterMessage { + self.message_create(Some(header), content) } - fn send_message_impl( + fn message_create( &self, header: Option, content: T, - ) -> Result<(), Error> { - let msg = JupyterMessage::::create(content, header, &self.socket.session); - msg.send(&self.socket) + ) -> JupyterMessage { + JupyterMessage::::create(content, header, &self.session) + } + + /// Forward a message on to the actual IOPub socket through the outbound channel + fn forward(&self, message: Message) -> crate::Result<()> { + self.outbound_tx + .send(OutboundMessage::IOPub(message)) + .map_err(|err| crate::Error::SendError(format!("{err:?}"))) } /// Flushes the active stream, sending along the message if the buffer @@ -249,9 +328,12 @@ impl IOPub { return; } - let message = self.buffer.drain(); + let content = self.buffer.drain(); + + let message = + Message::Stream(self.message_with_context(content, IOPubContextChannel::Shell)); - let Err(error) = self.send_message_with_context(message, IOPubContextChannel::Shell) else { + let Err(error) = self.forward(message) else { // Message sent successfully return; }; @@ -261,7 +343,7 @@ impl IOPub { Stream::Stderr => "stderr", }; - warn!("Error delivering iopub 'stream' message over '{name}': {error:?}"); + log::warn!("Error delivering iopub 'stream' message over '{name}': {error:?}"); } /// Processes a `Stream` message by appending it to the stream buffer @@ -271,7 +353,7 @@ impl IOPub { /// /// If this new message switches streams, then we flush the existing stream /// before switching. 
- fn process_stream_message(&mut self, message: StreamOutput) -> Result<(), Error> { + fn process_stream_message(&mut self, message: StreamOutput) -> crate::Result<()> { if message.name != self.buffer.name { // Swap streams, but flush the existing stream first self.flush_stream(); @@ -294,25 +376,24 @@ impl IOPub { /// waiting for the queue to empty, it is possible for a message on a /// different socket that is sent after waiting to still get processed by /// the frontend before the messages we cleared from the IOPub queue. - fn process_wait_request(&mut self, message: Wait) -> Result<(), Error> { + fn process_wait_request(&mut self, message: Wait) -> crate::Result<()> { message.wait_tx.send(()).unwrap(); Ok(()) } /// Emits the given kernel state to the client. fn emit_state(&self, state: ExecutionState) { - trace!("Entering kernel state: {:?}", state); - if let Err(err) = JupyterMessage::::create( - KernelStatus { - execution_state: state, - }, - None, - &self.socket.session, - ) - .send(&self.socket) - { - warn!("Could not emit kernel's state. 
{}", err) - } + log::trace!("Entering kernel state: {:?}", state); + + let content = KernelStatus { + execution_state: state, + }; + + let message = Message::Status(self.message(content)); + + if let Err(err) = self.forward(message) { + log::warn!("Could not emit kernel's state due to: {err:?}") + }; } } diff --git a/crates/amalthea/src/socket/socket.rs b/crates/amalthea/src/socket/socket.rs index 5017bb8b8..0ffb812b8 100644 --- a/crates/amalthea/src/socket/socket.rs +++ b/crates/amalthea/src/socket/socket.rs @@ -5,8 +5,6 @@ * */ -use log::trace; - use crate::error::Error; use crate::session::Session; @@ -37,19 +35,57 @@ impl Socket { ) -> Result { let socket = Self::new_raw(ctx, name.clone(), kind, identity)?; + // For the server side of IOPub, there are a few options we need to tweak + if name == "IOPub" && kind == zmq::SocketType::XPUB { + // Sets the XPUB socket to report subscription events even for + // topics that were already subscribed to. + // + // See notes in https://zguide.zeromq.org/docs/chapter5 and + // https://zguide.zeromq.org/docs/chapter6 and the discussion in + // https://lists.zeromq.org/pipermail/zeromq-dev/2012-October/018470.html + // that lead to the creation of this socket option. + socket + .set_xpub_verbose(true) + .map_err(|err| Error::CreateSocketFailed(name.clone(), err))?; + + // For IOPub in particular, which is fairly high traffic, we up the + // "high water mark" from the default of 1k -> 100k to avoid dropping + // messages if the subscriber is processing them too slowly. This has + // to be set before the call to `bind()`. It seems like we could + // alternatively set the rcvhwm on the subscriber side, since the + // "total" sndhmw seems to be the sum of the pub + sub values, but this + // is probably best to tell any subscribers out there that this is a + // high traffic channel. 
+ // https://github.com/posit-dev/amalthea/pull/129 + socket + .set_sndhwm(100000) + .map_err(|err| Error::CreateSocketFailed(name.clone(), err))?; + } + + // If this is a debug build, set `ZMQ_ROUTER_MANDATORY` on all `ROUTER` + // sockets, so that we get errors instead of silent message drops for + // unroutable messages. + #[cfg(debug_assertions)] + { + if kind == zmq::ROUTER { + if let Err(err) = socket.set_router_mandatory(true) { + return Err(Error::SocketBindError(name, endpoint, err)); + } + } + } + // One side of a socket must `bind()` to its endpoint, and the other // side must `connect()` to the same endpoint. The `bind()` side // will be the server, and the `connect()` side will be the client. match kind { - zmq::SocketType::ROUTER | zmq::SocketType::PUB | zmq::SocketType::REP => { - trace!("Binding to ZeroMQ '{}' socket at {}", name, endpoint); + zmq::SocketType::ROUTER | zmq::SocketType::XPUB | zmq::SocketType::REP => { + log::trace!("Binding to ZeroMQ '{}' socket at {}", name, endpoint); if let Err(err) = socket.bind(&endpoint) { return Err(Error::SocketBindError(name, endpoint, err)); } }, zmq::SocketType::DEALER | zmq::SocketType::SUB | zmq::SocketType::REQ => { - // Bind the socket to the requested endpoint - trace!("Connecting to ZeroMQ '{}' socket at {}", name, endpoint); + log::trace!("Connecting to ZeroMQ '{}' socket at {}", name, endpoint); if let Err(err) = socket.connect(&endpoint) { return Err(Error::SocketConnectError(name, endpoint, err)); } @@ -57,19 +93,7 @@ impl Socket { _ => return Err(Error::UnsupportedSocketType(kind)), } - // If this is a debug build, set `ZMQ_ROUTER_MANDATORY` on all `ROUTER` - // sockets, so that we get errors instead of silent message drops for - // unroutable messages. 
- #[cfg(debug_assertions)] - { - if kind == zmq::ROUTER { - if let Err(err) = socket.set_router_mandatory(true) { - return Err(Error::SocketBindError(name, endpoint, err)); - } - } - } - - // Create a new mutex and return + // Create a new socket and return Ok(Self { socket, session, @@ -88,12 +112,12 @@ impl Socket { let socket = Self::new_raw(ctx, name.clone(), zmq::PAIR, identity)?; if bind { - trace!("Binding to ZeroMQ '{}' socket at {}", name, endpoint); + log::trace!("Binding to ZeroMQ '{}' socket at {}", name, endpoint); if let Err(err) = socket.bind(&endpoint) { return Err(Error::SocketBindError(name, endpoint, err)); } } else { - trace!("Connecting to ZeroMQ '{}' socket at {}", name, endpoint); + log::trace!("Connecting to ZeroMQ '{}' socket at {}", name, endpoint); if let Err(err) = socket.connect(&endpoint) { return Err(Error::SocketConnectError(name, endpoint, err)); } @@ -118,21 +142,6 @@ impl Socket { Err(err) => return Err(Error::CreateSocketFailed(name, err)), }; - // For IOPub in particular, which is fairly high traffic, we up the - // "high water mark" from the default of 1k -> 100k to avoid dropping - // messages if the subscriber is processing them too slowly. This has - // to be set before the call to `bind()`. It seems like we could - // alternatively set the rcvhwm on the subscriber side, since the - // "total" sndhmw seems to be the sum of the pub + sub values, but this - // is probably best to tell any subscribers out there that this is a - // high traffic channel. - // https://github.com/posit-dev/amalthea/pull/129 - if name == "IOPub" { - if let Err(error) = socket.set_sndhwm(100000) { - return Err(Error::CreateSocketFailed(name, error)); - } - } - // Set the socket's identity, if supplied if let Some(identity) = identity { if let Err(err) = socket.set_identity(identity) { @@ -188,15 +197,29 @@ impl Socket { self.poll_incoming(0) } - /// Subscribes a SUB socket to all the published messages from a PUB socket. 
+ /// Subscribes a SUB socket to messages from an XPUB socket. + /// + /// Use `b""` to subscribe to all messages. /// /// Note that this needs to be called *after* the socket connection is /// established on both ends. - pub fn subscribe(&self) -> Result<(), Error> { + pub fn subscribe(&self, subscription: &[u8]) -> Result<(), Error> { + let socket_type = match self.socket.get_socket_type() { + Ok(socket_type) => socket_type, + Err(err) => return Err(Error::ZmqError(self.name.clone(), err)), + }; + + if socket_type != zmq::SocketType::SUB { + return Err(crate::anyhow!( + "Can't subscribe on a non-SUB socket. This socket is a {socket_type:?}." + )); + } + // Currently, all SUB sockets subscribe to all topics; in theory // frontends could subscribe selectively, but in practice all known - // Jupyter frontends subscribe to all topics. - match self.socket.set_subscribe(b"") { + // Jupyter frontends subscribe to all topics and just ignore topics + // they don't recognize. + match self.socket.set_subscribe(subscription) { Ok(_) => Ok(()), Err(err) => Err(Error::ZmqError(self.name.clone(), err)), } diff --git a/crates/amalthea/src/wire/jupyter_message.rs b/crates/amalthea/src/wire/jupyter_message.rs index 53e7b3e4c..dd71785b3 100644 --- a/crates/amalthea/src/wire/jupyter_message.rs +++ b/crates/amalthea/src/wire/jupyter_message.rs @@ -8,9 +8,12 @@ use serde::Deserialize; use serde::Serialize; +use super::display_data::DisplayData; use super::handshake_reply::HandshakeReply; use super::handshake_request::HandshakeRequest; use super::stream::StreamOutput; +use super::update_display_data::UpdateDisplayData; +use super::welcome::Welcome; use crate::comm::base_comm::JsonRpcReply; use crate::comm::ui_comm::UiFrontendRequest; use crate::error::Error; @@ -77,41 +80,54 @@ impl ProtocolMessage for T where T: MessageType + Serialize + std::fmt::Debug /// List of all known/implemented messages #[derive(Debug)] pub enum Message { + // Shell + KernelInfoReply(JupyterMessage), + 
KernelInfoRequest(JupyterMessage), CompleteReply(JupyterMessage), CompleteRequest(JupyterMessage), ExecuteReply(JupyterMessage), ExecuteReplyException(JupyterMessage), ExecuteRequest(JupyterMessage), - ExecuteResult(JupyterMessage), - ExecuteError(JupyterMessage), - ExecuteInput(JupyterMessage), - InputReply(JupyterMessage), - InputRequest(JupyterMessage), InspectReply(JupyterMessage), InspectRequest(JupyterMessage), - InterruptReply(JupyterMessage), - InterruptRequest(JupyterMessage), IsCompleteReply(JupyterMessage), IsCompleteRequest(JupyterMessage), - KernelInfoReply(JupyterMessage), - KernelInfoRequest(JupyterMessage), - ShutdownRequest(JupyterMessage), - Status(JupyterMessage), CommInfoReply(JupyterMessage), CommInfoRequest(JupyterMessage), - CommOpen(JupyterMessage), - CommMsg(JupyterMessage), CommRequest(JupyterMessage), CommReply(JupyterMessage), - CommClose(JupyterMessage), - StreamOutput(JupyterMessage), + InputReply(JupyterMessage), + InputRequest(JupyterMessage), + // Control + InterruptReply(JupyterMessage), + InterruptRequest(JupyterMessage), + ShutdownRequest(JupyterMessage), + // Registration HandshakeRequest(JupyterMessage), HandshakeReply(JupyterMessage), + // IOPub + Status(JupyterMessage), + ExecuteResult(JupyterMessage), + ExecuteError(JupyterMessage), + ExecuteInput(JupyterMessage), + Stream(JupyterMessage), + DisplayData(JupyterMessage), + UpdateDisplayData(JupyterMessage), + Welcome(JupyterMessage), + // IOPub/Shell + CommMsg(JupyterMessage), + CommOpen(JupyterMessage), + CommClose(JupyterMessage), } -/// Associates a `Message` to a 0MQ socket +/// Associates a `Message` to a 0MQ socket. +/// +/// At a high level, outbound messages originate from kernel components on a +/// crossbeam channel and are transfered to the client via a 0MQ socket owned by +/// the forwarding thread. pub enum OutboundMessage { StdIn(Message), + IOPub(Message), } /// Represents status returned from kernel inside messages. 
@@ -156,9 +172,12 @@ impl TryFrom<&Message> for WireMessage { Message::CommClose(msg) => WireMessage::try_from(msg), Message::CommRequest(msg) => WireMessage::try_from(msg), Message::CommReply(msg) => WireMessage::try_from(msg), - Message::StreamOutput(msg) => WireMessage::try_from(msg), + Message::Stream(msg) => WireMessage::try_from(msg), Message::HandshakeReply(msg) => WireMessage::try_from(msg), Message::HandshakeRequest(msg) => WireMessage::try_from(msg), + Message::DisplayData(msg) => WireMessage::try_from(msg), + Message::UpdateDisplayData(msg) => WireMessage::try_from(msg), + Message::Welcome(msg) => WireMessage::try_from(msg), } } } @@ -254,7 +273,7 @@ impl TryFrom<&WireMessage> for Message { return Ok(Message::InputRequest(JupyterMessage::try_from(msg)?)); } if kind == StreamOutput::message_type() { - return Ok(Message::StreamOutput(JupyterMessage::try_from(msg)?)); + return Ok(Message::Stream(JupyterMessage::try_from(msg)?)); } if kind == UiFrontendRequest::message_type() { return Ok(Message::CommRequest(JupyterMessage::try_from(msg)?)); @@ -268,6 +287,9 @@ impl TryFrom<&WireMessage> for Message { if kind == HandshakeReply::message_type() { return Ok(Message::HandshakeReply(JupyterMessage::try_from(msg)?)); } + if kind == Welcome::message_type() { + return Ok(Message::Welcome(JupyterMessage::try_from(msg)?)); + } return Err(Error::UnknownMessageType(kind)); } } diff --git a/crates/amalthea/src/wire/mod.rs b/crates/amalthea/src/wire/mod.rs index 7e8c48489..5c1e4b451 100644 --- a/crates/amalthea/src/wire/mod.rs +++ b/crates/amalthea/src/wire/mod.rs @@ -43,5 +43,7 @@ pub mod shutdown_reply; pub mod shutdown_request; pub mod status; pub mod stream; +pub mod subscription_message; pub mod update_display_data; +pub mod welcome; pub mod wire_message; diff --git a/crates/amalthea/src/wire/subscription_message.rs b/crates/amalthea/src/wire/subscription_message.rs new file mode 100644 index 000000000..e84b33566 --- /dev/null +++ 
b/crates/amalthea/src/wire/subscription_message.rs
@@ -0,0 +1,103 @@
+/*
+ * subscription_message.rs
+ *
+ * Copyright (C) 2024 Posit Software, PBC. All rights reserved.
+ *
+ */
+
+use serde::Deserialize;
+use serde::Serialize;
+
+use crate::error::Error;
+use crate::socket::socket::Socket;
+
+/// Represents a special `SubscriptionMessage` sent from a SUB to an XPUB
+/// upon `socket.set_subscribe(subscription)` or `socket.set_unsubscribe(subscription)`.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SubscriptionMessage {
+    /// Whether the peer subscribed or unsubscribed.
+    pub kind: SubscriptionKind,
+    /// The bytes following the kind byte, decoded as UTF-8.
+    pub subscription: String,
+}
+
+/// Direction of a `SubscriptionMessage`, decoded from its first byte.
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
+pub enum SubscriptionKind {
+    /// The first byte was `1`.
+    Subscribe,
+    /// The first byte was `0`.
+    Unsubscribe,
+}
+
+impl SubscriptionMessage {
+    /// Read a SubscriptionMessage from a ZeroMQ socket.
+    ///
+    /// Receives one multipart message from `socket` and parses it with
+    /// `from_buffers()`. An error from either step is returned to the caller.
+    pub fn read_from_socket(socket: &Socket) -> crate::Result<Self> {
+        let bufs = socket.recv_multipart()?;
+        Self::from_buffers(bufs)
+    }
+
+    /// Parse a SubscriptionMessage from an array of buffers (from a ZeroMQ message)
+    ///
+    /// Always a single frame (i.e. `bufs` should be length 1).
+    /// Either `1{subscription}` for subscription.
+    /// Or `0{subscription}` for unsubscription.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `bufs` is not exactly one frame, if that frame is empty, if its
+    /// first byte is neither `0` nor `1`, or if the remaining bytes are not
+    /// valid UTF-8.
+    fn from_buffers(bufs: Vec<Vec<u8>>) -> crate::Result<Self> {
+        // Exactly one frame is expected; bind it by pattern rather than `unwrap()`
+        let buf = match bufs.as_slice() {
+            [buf] => buf,
+            frames => {
+                let n = frames.len();
+                return Err(crate::anyhow!(
+                    "Subscription message on XPUB must be a single frame. {n} frames were received."
+                ));
+            },
+        };
+
+        // Split the leading kind byte from the `subscription` payload
+        let (flag, payload) = match buf.split_first() {
+            Some((flag, payload)) => (*flag, payload),
+            None => {
+                return Err(crate::anyhow!(
+                    "Subscription message on XPUB must be at least length 1 to determine subscribe/unsubscribe."
+                ));
+            },
+        };
+
+        // ZeroMQ documents that XPUB sockets can also receive messages without
+        // a sub/unsub prefix, so anything other than `0` or `1` is rejected
+        // instead of being read as an unsubscription
+        let kind = match flag {
+            1 => SubscriptionKind::Subscribe,
+            0 => SubscriptionKind::Unsubscribe,
+            other => {
+                return Err(crate::anyhow!(
+                    "Subscription message on XPUB must start with a 0 or 1 byte. Received {other}."
+                ));
+            },
+        };
+
+        // The rest of the message is the UTF-8 `subscription`
+        let subscription = match std::str::from_utf8(payload) {
+            Ok(subscription) => subscription.to_string(),
+            Err(err) => {
+                return Err(Error::Utf8Error(
+                    String::from("subscription"),
+                    payload.to_vec(),
+                    err,
+                ))
+            },
+        };
+
+        Ok(Self { kind, subscription })
+    }
+}
diff --git a/crates/amalthea/src/wire/welcome.rs b/crates/amalthea/src/wire/welcome.rs
new file mode 100644
index 000000000..ef76381a2
--- /dev/null
+++ b/crates/amalthea/src/wire/welcome.rs
@@ -0,0 +1,35 @@
+/*
+ * welcome.rs
+ *
+ * Copyright (C) 2023-2024 Posit Software, PBC. All rights reserved.
+ *
+ */
+
+use serde::Deserialize;
+use serde::Serialize;
+
+use crate::wire::jupyter_message::MessageType;
+
+/// An IOPub message used for handshaking by modern clients.
+/// See JEP 65: https://github.com/jupyter/enhancement-proposals/pull/65
+///
+/// Note that this IOPub `Welcome` message is the same basic idea as
+/// `ZMQ_XPUB_WELCOME_MSG`, set through `socket.set_xpub_welcome_msg()`,
+/// but the JEP committee decided not to use that.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct Welcome {
+    /// The `subscription` sent to the XPUB socket by the SUB's call
+    /// to `socket.set_subscribe(subscription)`. The IOPub XPUB socket
+    /// passes this `subscription` back to the IOPub SUB in the `Welcome`
+    /// message.
+    pub subscription: String,
+}
+
+// Message type comes from copying what xeus and jupyter_kernel_test use:
+// https://github.com/jupyter-xeus/xeus-zmq/pull/31
+// https://github.com/jupyter/jupyter_kernel_test/blob/5f2c65271b48dc95fc75a9585cb1d6db0bb55557/jupyter_kernel_test/__init__.py#L449-L450
+impl MessageType for Welcome {
+    fn message_type() -> String {
+        String::from("iopub_welcome")
+    }
+}