diff --git a/Cargo.lock b/Cargo.lock index d8729dfe8..df11212ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,6 +31,7 @@ name = "amalthea" version = "0.1.0" dependencies = [ "amalthea-macros", + "anyhow", "async-trait", "chrono", "crossbeam", diff --git a/crates/amalthea/Cargo.toml b/crates/amalthea/Cargo.toml index 646991ceb..68c87526d 100644 --- a/crates/amalthea/Cargo.toml +++ b/crates/amalthea/Cargo.toml @@ -25,6 +25,7 @@ zmq = "0.10.0" strum = "0.24" strum_macros = "0.24" crossbeam = { version = "0.8.2", features = ["crossbeam-channel"] } +anyhow = "1.0.71" [dev-dependencies] rand = "0.8.5" diff --git a/crates/amalthea/src/error.rs b/crates/amalthea/src/error.rs index 1547f0656..c52bb16c1 100644 --- a/crates/amalthea/src/error.rs +++ b/crates/amalthea/src/error.rs @@ -44,6 +44,8 @@ pub enum Error { InvalidCommMessage(String, String, String), } +impl std::error::Error for Error {} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/crates/amalthea/src/kernel.rs b/crates/amalthea/src/kernel.rs index 81067272b..558decd6d 100644 --- a/crates/amalthea/src/kernel.rs +++ b/crates/amalthea/src/kernel.rs @@ -9,10 +9,14 @@ use std::sync::Arc; use std::sync::Mutex; use crossbeam::channel::bounded; +use crossbeam::channel::unbounded; use crossbeam::channel::Receiver; +use crossbeam::channel::Select; use crossbeam::channel::Sender; +use log::error; use log::info; use stdext::spawn; +use stdext::unwrap; use crate::comm::comm_manager::CommManager; use crate::comm::event::CommChanged; @@ -33,6 +37,8 @@ use crate::socket::stdin::Stdin; use crate::stream_capture::StreamCapture; use crate::wire::header::JupyterHeader; use crate::wire::input_request::ShellInputRequest; +use crate::wire::jupyter_message::Message; +use crate::wire::jupyter_message::OutboundMessage; /// A Kernel represents a unique Jupyter kernel session and is the host for all /// execution and messaging threads. @@ -116,6 +122,10 @@ impl Kernel { ) -> Result<(), Error> { let ctx = zmq::Context::new(); + // Channels for communication of outbound messages between the + // socket threads and the 0MQ forwarding thread + let (outbound_tx, outbound_rx) = unbounded(); + // Create the comm manager thread let iopub_tx = self.create_iopub_tx(); let comm_manager_rx = self.comm_manager_rx.clone(); @@ -191,8 +201,19 @@ impl Kernel { )?; let shell_clone = shell_handler.clone(); let msg_context = self.msg_context.clone(); + + let (stdin_inbound_tx, stdin_inbound_rx) = unbounded(); + let stdin_session = stdin_socket.session.clone(); + spawn!(format!("{}-stdin", self.name), move || { - Self::stdin_thread(stdin_socket, shell_clone, msg_context, input_request_rx) + Self::stdin_thread( + stdin_inbound_rx, + outbound_tx, + shell_clone, + msg_context, + input_request_rx, + stdin_session, + ) }); // Create the thread that handles stdout and stderr, if requested @@ -213,6 +234,45 @@ impl Kernel { self.connection.endpoint(self.connection.control_port), )?; + // Internal sockets for notifying the 0MQ forwarding + // thread that new outbound messages are available + let outbound_notif_socket_tx = Socket::new_pair( + self.session.clone(), + ctx.clone(), + String::from("OutboundNotifierTx"), + None, + String::from("inproc://outbound_notif"), + true, + )?; + let outbound_notif_socket_rx = Socket::new_pair( + self.session.clone(), + ctx.clone(), + String::from("OutboundNotifierRx"), + None, + String::from("inproc://outbound_notif"), + false, + )?; + + let outbound_rx_clone = outbound_rx.clone(); + + // Forwarding thread that bridges 0MQ sockets and Amalthea + // channels. Currently only used by StdIn. + spawn!(format!("{}-zmq-forwarding", self.name), move || { + Self::zmq_forwarding_thread( + outbound_notif_socket_rx, + stdin_socket, + stdin_inbound_tx, + outbound_rx_clone, + ) + }); + + // The notifier thread watches Amalthea channels of outgoing + // messages for readiness. When a channel is hot, it notifies the + // forwarding thread through a 0MQ socket. + spawn!(format!("{}-zmq-notifier", self.name), move || { + Self::zmq_notifier_thread(outbound_notif_socket_tx, outbound_rx) + }); + // 0MQ sockets are now initialised. We can start the kernel runtime // with relative multithreading safety. See // https://github.com/rstudio/positron/issues/720 @@ -285,16 +345,148 @@ impl Kernel { /// Starts the stdin thread. fn stdin_thread( - socket: Socket, + inbound_rx: Receiver, + outbound_tx: Sender, shell_handler: Arc>, msg_context: Arc>>, input_request_rx: Receiver, + session: Session, ) -> Result<(), Error> { - let stdin = Stdin::new(socket, shell_handler, msg_context); + let stdin = Stdin::new(inbound_rx, outbound_tx, shell_handler, msg_context, session); stdin.listen(input_request_rx); Ok(()) } + /// Starts the thread that forwards 0MQ messages to Amalthea channels + /// and vice versa. + fn zmq_forwarding_thread( + outbound_notif_socket: Socket, + stdin_socket: Socket, + stdin_inbound_tx: Sender, + outbound_rx: Receiver, + ) { + // This function checks for notifications that an outgoing message + // is ready to be read on an Amalthea channel. It returns + // immediately whether a message is ready or not. + let has_outbound = || -> bool { + if let Ok(n) = outbound_notif_socket.socket.poll(zmq::POLLIN, 0) { + if n == 0 { + return false; + } + // Consume notification + let _ = unwrap!(outbound_notif_socket.socket.recv_bytes(0), Err(err) => { + log::error!("Could not consume outbound notification socket: {}", err); + return false; + }); + + true + } else { + false + } + }; + + // This function checks that a 0MQ message from the frontend is ready. + let has_inbound = || -> bool { + match stdin_socket.socket.poll(zmq::POLLIN, 0) { + Ok(n) if n > 0 => true, + _ => false, + } + }; + + // Forwards channel message from Amalthea to the frontend via the + // corresponding 0MQ socket. Should consume exactly 1 message and + // notify back the notifier thread to keep the mechanism synchronised. + let forward_outbound = || -> anyhow::Result<()> { + // Consume message and forward it + let outbound_msg = outbound_rx.recv()?; + match outbound_msg { + OutboundMessage::StdIn(msg) => msg.send(&stdin_socket)?, + }; + + // Notify back + outbound_notif_socket.send(zmq::Message::new())?; + + Ok(()) + }; + + // Forwards 0MQ message from the frontend to the corresponding + // Amalthea channel. + let forward_inbound = || -> anyhow::Result<()> { + let msg = Message::read_from_socket(&stdin_socket)?; + stdin_inbound_tx.send(msg)?; + Ok(()) + }; + + // Create poll items necessary to call `zmq_poll()` + let mut poll_items = { + let outbound_notif_poll_item = outbound_notif_socket.socket.as_poll_item(zmq::POLLIN); + let stdin_poll_item = stdin_socket.socket.as_poll_item(zmq::POLLIN); + vec![outbound_notif_poll_item, stdin_poll_item] + }; + + loop { + let n = unwrap!( + zmq::poll(&mut poll_items, -1), + Err(err) => { + error!("While polling 0MQ items: {}", err); + 0 + } + ); + + for _ in 0..n { + if has_outbound() { + unwrap!( + forward_outbound(), + Err(err) => error!("While forwarding outbound message: {}", err) + ); + continue; + } + + if has_inbound() { + unwrap!( + forward_inbound(), + Err(err) => error!("While forwarding inbound message: {}", err) + ); + continue; + } + + log::error!("Could not find readable message"); + } + } + } + + /// Starts the thread that notifies the forwarding thread that new + /// outgoing messages have arrived from Amalthea. + fn zmq_notifier_thread(notif_socket: Socket, outbound_rx: Receiver) { + let mut sel = Select::new(); + sel.recv(&outbound_rx); + + loop { + let _ = sel.ready(); + + unwrap!( + notif_socket.send(zmq::Message::new()), + Err(err) => { + error!("Couldn't notify 0MQ thread: {}", err); + continue; + } + ); + + // To keep things synchronised, wait to be notified that the + // channel message has been consumed before continuing the loop. + unwrap!( + { + let mut msg = zmq::Message::new(); + notif_socket.recv(&mut msg) + }, + Err(err) => { + error!("Couldn't received acknowledgement from 0MQ thread: {}", err); + continue; + } + ); + } + } + /// Starts the output capture thread. fn output_capture_thread(iopub_tx: Sender) -> Result<(), Error> { let output_capture = StreamCapture::new(iopub_tx); diff --git a/crates/amalthea/src/socket/socket.rs b/crates/amalthea/src/socket/socket.rs index 03fec3423..c412101a2 100644 --- a/crates/amalthea/src/socket/socket.rs +++ b/crates/amalthea/src/socket/socket.rs @@ -22,7 +22,7 @@ pub struct Socket { pub name: String, /// A ZeroMQ socket over which signed messages are to be sent/received - socket: zmq::Socket, + pub socket: zmq::Socket, } impl Socket { @@ -35,18 +35,7 @@ impl Socket { identity: Option<&[u8]>, endpoint: String, ) -> Result { - // Create the underlying ZeroMQ socket - let socket = match ctx.socket(kind) { - Ok(s) => s, - Err(err) => return Err(Error::CreateSocketFailed(name, err)), - }; - - // Set the socket's identity, if supplied - if let Some(identity) = identity { - if let Err(err) = socket.set_identity(identity) { - return Err(Error::CreateSocketFailed(name, err)); - } - } + let socket = Self::new_raw(ctx, name.clone(), kind, identity)?; // One side of a socket must `bind()` to its endpoint, and the other // side must `connect()` to the same endpoint. The `bind()` side @@ -88,6 +77,57 @@ impl Socket { }) } + pub fn new_pair( + session: Session, + ctx: zmq::Context, + name: String, + identity: Option<&[u8]>, + endpoint: String, + bind: bool, + ) -> Result { + let socket = Self::new_raw(ctx, name.clone(), zmq::PAIR, identity)?; + + if bind { + trace!("Binding to ZeroMQ '{}' socket at {}", name, endpoint); + if let Err(err) = socket.bind(&endpoint) { + return Err(Error::SocketBindError(name, endpoint, err)); + } + } else { + trace!("Connecting to ZeroMQ '{}' socket at {}", name, endpoint); + if let Err(err) = socket.connect(&endpoint) { + return Err(Error::SocketConnectError(name, endpoint, err)); + } + } + + Ok(Self { + socket, + session, + name, + }) + } + + fn new_raw( + ctx: zmq::Context, + name: String, + kind: zmq::SocketType, + identity: Option<&[u8]>, + ) -> Result { + // Create the underlying ZeroMQ socket + let socket = match ctx.socket(kind) { + Ok(s) => s, + Err(err) => return Err(Error::CreateSocketFailed(name, err)), + }; + + // Set the socket's identity, if supplied + if let Some(identity) = identity { + if let Err(err) = socket.set_identity(identity) { + return Err(Error::CreateSocketFailed(name, err)); + } + } + + Ok(socket) + } + /// Receive a message from the socket. /// /// **Note**: This will block until a message is delivered on the socket. diff --git a/crates/amalthea/src/socket/stdin.rs b/crates/amalthea/src/socket/stdin.rs index 182b90a05..5fa8160f9 100644 --- a/crates/amalthea/src/socket/stdin.rs +++ b/crates/amalthea/src/socket/stdin.rs @@ -9,21 +9,26 @@ use std::sync::Arc; use std::sync::Mutex; use crossbeam::channel::Receiver; +use crossbeam::channel::Sender; use futures::executor::block_on; use log::trace; use log::warn; use crate::language::shell_handler::ShellHandler; -use crate::socket::socket::Socket; +use crate::session::Session; use crate::wire::header::JupyterHeader; use crate::wire::input_request::ShellInputRequest; use crate::wire::jupyter_message::JupyterMessage; use crate::wire::jupyter_message::Message; +use crate::wire::jupyter_message::OutboundMessage; use crate::wire::originator::Originator; pub struct Stdin { - /// The ZeroMQ stdin socket - socket: Socket, + /// Receiver connected to the StdIn's ZeroMQ socket + inbound_rx: Receiver, + + /// Sender connected to the StdIn's ZeroMQ socket + outbound_tx: Sender, /// Language-provided shell handler object handler: Arc>, @@ -31,6 +36,9 @@ pub struct Stdin { // IOPub message context. Updated from StdIn on input replies so that new // output gets attached to the correct input element in the console. msg_context: Arc>>, + + // 0MQ session, needed to create `JupyterMessage` objects + session: Session, } impl Stdin { @@ -40,14 +48,18 @@ impl Stdin { /// * `handler` - The language's shell handler /// * `msg_context` - The IOPub message context pub fn new( - socket: Socket, + inbound_rx: Receiver, + outbound_tx: Sender, handler: Arc>, msg_context: Arc>>, + session: Session, ) -> Self { Self { - socket, + inbound_rx, + outbound_tx, handler, msg_context, + session, } } @@ -65,24 +77,20 @@ impl Stdin { } // Deliver the message to the front end - let msg = JupyterMessage::create_with_identity( + let msg = Message::InputRequest(JupyterMessage::create_with_identity( req.originator, req.request, - &self.socket.session, - ); - if let Err(err) = msg.send(&self.socket) { - warn!("Failed to send message to front end: {}", err); + &self.session, + )); + + if let Err(_) = self.outbound_tx.send(OutboundMessage::StdIn(msg)) { + warn!("Failed to send message to front end"); } trace!("Sent input request to front end, waiting for input reply..."); - // Attempt to read the front end's reply message from the ZeroMQ socket. - // - // TODO: This will block until the front end sends an input request, - // which could be a while and perhaps never if the user cancels the - // operation, never provides input, etc. We should probably have a - // timeout here, or some way to cancel the read if another input - // request arrives. - let message = match Message::read_from_socket(&self.socket) { + // Wait for the front end's reply message from the ZeroMQ socket. + // TODO: Wait for interrupts via another channel. + let message = match self.inbound_rx.recv() { Ok(m) => m, Err(err) => { warn!("Could not read message from stdin socket: {}", err); diff --git a/crates/amalthea/src/wire/jupyter_message.rs b/crates/amalthea/src/wire/jupyter_message.rs index ef5d5a8dd..00e3490ec 100644 --- a/crates/amalthea/src/wire/jupyter_message.rs +++ b/crates/amalthea/src/wire/jupyter_message.rs @@ -101,6 +101,11 @@ pub enum Message { ClientEvent(JupyterMessage), } +/// Associates a `Message` to a 0MQ socket +pub enum OutboundMessage { + StdIn(Message), +} + /// Represents status returned from kernel inside messages. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")]