Skip to content

Commit

Permalink
Merge pull request #58 from posit-dev/refactor/stdin-channels
Browse files Browse the repository at this point in the history
Forward 0MQ messages for StdIn over crossbeam channel
  • Loading branch information
lionel- authored Jun 30, 2023
2 parents 1d8b4e1 + 1a0cafc commit 4be01aa
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 34 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/amalthea/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ zmq = "0.10.0"
strum = "0.24"
strum_macros = "0.24"
crossbeam = { version = "0.8.2", features = ["crossbeam-channel"] }
anyhow = "1.0.71"

[dev-dependencies]
rand = "0.8.5"
Expand Down
2 changes: 2 additions & 0 deletions crates/amalthea/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub enum Error {
InvalidCommMessage(String, String, String),
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expand Down
198 changes: 195 additions & 3 deletions crates/amalthea/src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ use std::sync::Arc;
use std::sync::Mutex;

use crossbeam::channel::bounded;
use crossbeam::channel::unbounded;
use crossbeam::channel::Receiver;
use crossbeam::channel::Select;
use crossbeam::channel::Sender;
use log::error;
use log::info;
use stdext::spawn;
use stdext::unwrap;

use crate::comm::comm_manager::CommManager;
use crate::comm::event::CommChanged;
Expand All @@ -33,6 +37,8 @@ use crate::socket::stdin::Stdin;
use crate::stream_capture::StreamCapture;
use crate::wire::header::JupyterHeader;
use crate::wire::input_request::ShellInputRequest;
use crate::wire::jupyter_message::Message;
use crate::wire::jupyter_message::OutboundMessage;

/// A Kernel represents a unique Jupyter kernel session and is the host for all
/// execution and messaging threads.
Expand Down Expand Up @@ -116,6 +122,10 @@ impl Kernel {
) -> Result<(), Error> {
let ctx = zmq::Context::new();

// Channels for communication of outbound messages between the
// socket threads and the 0MQ forwarding thread
let (outbound_tx, outbound_rx) = unbounded();

// Create the comm manager thread
let iopub_tx = self.create_iopub_tx();
let comm_manager_rx = self.comm_manager_rx.clone();
Expand Down Expand Up @@ -191,8 +201,19 @@ impl Kernel {
)?;
let shell_clone = shell_handler.clone();
let msg_context = self.msg_context.clone();

let (stdin_inbound_tx, stdin_inbound_rx) = unbounded();
let stdin_session = stdin_socket.session.clone();

spawn!(format!("{}-stdin", self.name), move || {
Self::stdin_thread(stdin_socket, shell_clone, msg_context, input_request_rx)
Self::stdin_thread(
stdin_inbound_rx,
outbound_tx,
shell_clone,
msg_context,
input_request_rx,
stdin_session,
)
});

// Create the thread that handles stdout and stderr, if requested
Expand All @@ -213,6 +234,45 @@ impl Kernel {
self.connection.endpoint(self.connection.control_port),
)?;

// Internal sockets for notifying the 0MQ forwarding
// thread that new outbound messages are available
let outbound_notif_socket_tx = Socket::new_pair(
self.session.clone(),
ctx.clone(),
String::from("OutboundNotifierTx"),
None,
String::from("inproc://outbound_notif"),
true,
)?;
let outbound_notif_socket_rx = Socket::new_pair(
self.session.clone(),
ctx.clone(),
String::from("OutboundNotifierRx"),
None,
String::from("inproc://outbound_notif"),
false,
)?;

let outbound_rx_clone = outbound_rx.clone();

// Forwarding thread that bridges 0MQ sockets and Amalthea
// channels. Currently only used by StdIn.
spawn!(format!("{}-zmq-forwarding", self.name), move || {
Self::zmq_forwarding_thread(
outbound_notif_socket_rx,
stdin_socket,
stdin_inbound_tx,
outbound_rx_clone,
)
});

// The notifier thread watches Amalthea channels of outgoing
// messages for readiness. When a channel is hot, it notifies the
// forwarding thread through a 0MQ socket.
spawn!(format!("{}-zmq-notifier", self.name), move || {
Self::zmq_notifier_thread(outbound_notif_socket_tx, outbound_rx)
});

// 0MQ sockets are now initialised. We can start the kernel runtime
// with relative multithreading safety. See
// https://github.com/rstudio/positron/issues/720
Expand Down Expand Up @@ -285,16 +345,148 @@ impl Kernel {

/// Starts the stdin thread.
fn stdin_thread(
socket: Socket,
inbound_rx: Receiver<Message>,
outbound_tx: Sender<OutboundMessage>,
shell_handler: Arc<Mutex<dyn ShellHandler>>,
msg_context: Arc<Mutex<Option<JupyterHeader>>>,
input_request_rx: Receiver<ShellInputRequest>,
session: Session,
) -> Result<(), Error> {
let stdin = Stdin::new(socket, shell_handler, msg_context);
let stdin = Stdin::new(inbound_rx, outbound_tx, shell_handler, msg_context, session);
stdin.listen(input_request_rx);
Ok(())
}

/// Starts the thread that forwards 0MQ messages to Amalthea channels
/// and vice versa.
fn zmq_forwarding_thread(
outbound_notif_socket: Socket,
stdin_socket: Socket,
stdin_inbound_tx: Sender<Message>,
outbound_rx: Receiver<OutboundMessage>,
) {
// This function checks for notifications that an outgoing message
// is ready to be read on an Amalthea channel. It returns
// immediately whether a message is ready or not.
let has_outbound = || -> bool {
if let Ok(n) = outbound_notif_socket.socket.poll(zmq::POLLIN, 0) {
if n == 0 {
return false;
}
// Consume notification
let _ = unwrap!(outbound_notif_socket.socket.recv_bytes(0), Err(err) => {
log::error!("Could not consume outbound notification socket: {}", err);
return false;
});

true
} else {
false
}
};

// This function checks that a 0MQ message from the frontend is ready.
let has_inbound = || -> bool {
match stdin_socket.socket.poll(zmq::POLLIN, 0) {
Ok(n) if n > 0 => true,
_ => false,
}
};

// Forwards channel message from Amalthea to the frontend via the
// corresponding 0MQ socket. Should consume exactly 1 message and
// notify back the notifier thread to keep the mechanism synchronised.
let forward_outbound = || -> anyhow::Result<()> {
// Consume message and forward it
let outbound_msg = outbound_rx.recv()?;
match outbound_msg {
OutboundMessage::StdIn(msg) => msg.send(&stdin_socket)?,
};

// Notify back
outbound_notif_socket.send(zmq::Message::new())?;

Ok(())
};

// Forwards 0MQ message from the frontend to the corresponding
// Amalthea channel.
let forward_inbound = || -> anyhow::Result<()> {
let msg = Message::read_from_socket(&stdin_socket)?;
stdin_inbound_tx.send(msg)?;
Ok(())
};

// Create poll items necessary to call `zmq_poll()`
let mut poll_items = {
let outbound_notif_poll_item = outbound_notif_socket.socket.as_poll_item(zmq::POLLIN);
let stdin_poll_item = stdin_socket.socket.as_poll_item(zmq::POLLIN);
vec![outbound_notif_poll_item, stdin_poll_item]
};

loop {
let n = unwrap!(
zmq::poll(&mut poll_items, -1),
Err(err) => {
error!("While polling 0MQ items: {}", err);
0
}
);

for _ in 0..n {
if has_outbound() {
unwrap!(
forward_outbound(),
Err(err) => error!("While forwarding outbound message: {}", err)
);
continue;
}

if has_inbound() {
unwrap!(
forward_inbound(),
Err(err) => error!("While forwarding inbound message: {}", err)
);
continue;
}

log::error!("Could not find readable message");
}
}
}

/// Starts the thread that notifies the forwarding thread that new
/// outgoing messages have arrived from Amalthea.
fn zmq_notifier_thread(notif_socket: Socket, outbound_rx: Receiver<OutboundMessage>) {
let mut sel = Select::new();
sel.recv(&outbound_rx);

loop {
let _ = sel.ready();

unwrap!(
notif_socket.send(zmq::Message::new()),
Err(err) => {
error!("Couldn't notify 0MQ thread: {}", err);
continue;
}
);

// To keep things synchronised, wait to be notified that the
// channel message has been consumed before continuing the loop.
unwrap!(
{
let mut msg = zmq::Message::new();
notif_socket.recv(&mut msg)
},
Err(err) => {
error!("Couldn't received acknowledgement from 0MQ thread: {}", err);
continue;
}
);
}
}

/// Starts the output capture thread.
fn output_capture_thread(iopub_tx: Sender<IOPubMessage>) -> Result<(), Error> {
let output_capture = StreamCapture::new(iopub_tx);
Expand Down
66 changes: 53 additions & 13 deletions crates/amalthea/src/socket/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub struct Socket {
pub name: String,

/// A ZeroMQ socket over which signed messages are to be sent/received
socket: zmq::Socket,
pub socket: zmq::Socket,
}

impl Socket {
Expand All @@ -35,18 +35,7 @@ impl Socket {
identity: Option<&[u8]>,
endpoint: String,
) -> Result<Self, Error> {
// Create the underlying ZeroMQ socket
let socket = match ctx.socket(kind) {
Ok(s) => s,
Err(err) => return Err(Error::CreateSocketFailed(name, err)),
};

// Set the socket's identity, if supplied
if let Some(identity) = identity {
if let Err(err) = socket.set_identity(identity) {
return Err(Error::CreateSocketFailed(name, err));
}
}
let socket = Self::new_raw(ctx, name.clone(), kind, identity)?;

// One side of a socket must `bind()` to its endpoint, and the other
// side must `connect()` to the same endpoint. The `bind()` side
Expand Down Expand Up @@ -88,6 +77,57 @@ impl Socket {
})
}

pub fn new_pair(
session: Session,
ctx: zmq::Context,
name: String,
identity: Option<&[u8]>,
endpoint: String,
bind: bool,
) -> Result<Self, Error> {
let socket = Self::new_raw(ctx, name.clone(), zmq::PAIR, identity)?;

if bind {
trace!("Binding to ZeroMQ '{}' socket at {}", name, endpoint);
if let Err(err) = socket.bind(&endpoint) {
return Err(Error::SocketBindError(name, endpoint, err));
}
} else {
trace!("Connecting to ZeroMQ '{}' socket at {}", name, endpoint);
if let Err(err) = socket.connect(&endpoint) {
return Err(Error::SocketConnectError(name, endpoint, err));
}
}

Ok(Self {
socket,
session,
name,
})
}

fn new_raw(
ctx: zmq::Context,
name: String,
kind: zmq::SocketType,
identity: Option<&[u8]>,
) -> Result<zmq::Socket, Error> {
// Create the underlying ZeroMQ socket
let socket = match ctx.socket(kind) {
Ok(s) => s,
Err(err) => return Err(Error::CreateSocketFailed(name, err)),
};

// Set the socket's identity, if supplied
if let Some(identity) = identity {
if let Err(err) = socket.set_identity(identity) {
return Err(Error::CreateSocketFailed(name, err));
}
}

Ok(socket)
}

/// Receive a message from the socket.
///
/// **Note**: This will block until a message is delivered on the socket.
Expand Down
Loading

0 comments on commit 4be01aa

Please sign in to comment.