diff --git a/crates/amalthea/src/kernel.rs b/crates/amalthea/src/kernel.rs index c1bb89d9f..8e9e06e28 100644 --- a/crates/amalthea/src/kernel.rs +++ b/crates/amalthea/src/kernel.rs @@ -366,34 +366,70 @@ impl Kernel { stdin_inbound_tx: Sender, stdin_outbound_rx: Receiver, ) { - let outbound_notif_poll_item = outbound_notif_socket.socket.as_poll_item(zmq::POLLIN); - let stdin_poll_item = stdin_socket.socket.as_poll_item(zmq::POLLIN); - - let mut poll_items = vec![ - outbound_notif_socket.socket.as_poll_item(zmq::POLLIN), - stdin_socket.socket.as_poll_item(zmq::POLLIN), - ]; + // Create poll items necessary to call `zmq_poll()` + let mut poll_items = { + let outbound_notif_poll_item = outbound_notif_socket.socket.as_poll_item(zmq::POLLIN); + let stdin_poll_item = stdin_socket.socket.as_poll_item(zmq::POLLIN); + vec![outbound_notif_poll_item, stdin_poll_item] + }; - let has_outbound = || -> bool { - if outbound_notif_poll_item.is_readable() { + // This function checks that an outgoing message is ready to be + // read on an Amalthea channel. Returns the index of the hot + // channel (currently unused as we only have one channel at the + // moment). + let has_outbound = || -> Option { + if let Ok(n) = outbound_notif_socket.socket.poll(zmq::POLLIN, 0) { + if n == 0 { + return None; + } // Consume notification - let mut msg = zmq::Message::new(); - unwrap!(outbound_notif_socket.recv(&mut msg), Err(err) => { - log::warn!("Could not consume outbound notification socket: {}", err) + let bytes = unwrap!(outbound_notif_socket.socket.recv_bytes(0), Err(err) => { + log::warn!("Could not consume outbound notification socket: {}", err); + return None; + }); + + // Get index of hot channel + let index = unwrap!(bytes.try_into(), Err(_) => { + log::error!("Could not extract index from outbound notification"); + return None; }); - true + Some(usize::from_be_bytes(index)) } else { - false + None } }; - let forward_outbound = || -> anyhow::Result<()> { + // This function checks that a 0MQ message from the frontend is ready + let has_inbound = || -> bool { + match stdin_socket.socket.poll(zmq::POLLIN, 0) { + Ok(n) if n > 0 => true, + _ => false, + } + }; + + // Forward channel message from Amalthea to the frontend via the + // corresponding 0MQ socket. Should consume exactly 1 channel. The + // `_i` argument indicates which channel to consume. It's currently + // unused as we only manage one channel at the moment but that + // could change. + let forward_outbound = |_i: usize| -> anyhow::Result<()> { let msg = stdin_outbound_rx.recv()?; msg.send(&stdin_socket)?; + + // Send back a notification once the channel message is + // consumed. This way we keep the forwarding and notifier + // threads synchronised. + unwrap!( + outbound_notif_socket.send(zmq::Message::new()), + Err(err) => error!("While notifying back notifier thread: {}", err) + ); + Ok(()) }; + // Forward 0MQ message from the frontend to the corresponding + // Amalthea channel. let forward_inbound = || -> anyhow::Result<()> { let msg = Message::read_from_socket(&stdin_socket)?; stdin_inbound_tx.send(msg)?; @@ -410,9 +446,9 @@ impl Kernel { ); while n > 0 { - if has_outbound() { + if let Some(index) = has_outbound() { unwrap!( - forward_outbound(), + forward_outbound(index), Err(err) => error!("While forwarding outbound message: {}", err) ); @@ -420,7 +456,7 @@ impl Kernel { continue; } - if stdin_poll_item.is_readable() { + if has_inbound() { unwrap!( forward_inbound(), Err(err) => error!("While forwarding inbound message: {}", err) @@ -442,9 +478,21 @@ impl Kernel { } loop { - sel.ready(); + let i: usize = sel.ready(); + let i_bytes = i.to_be_bytes(); + + unwrap!( + notif_socket.send(zmq::Message::from(&i_bytes[..])), + Err(err) => error!("Couldn't notify 0MQ thread: {}", err) + ); + + // To keep things synchronised, wait to be notified that the + // channel message has been consumed before continuing the loop. unwrap!( - notif_socket.send(zmq::Message::new()), + { + let mut msg = zmq::Message::new(); + notif_socket.recv(&mut msg) + }, Err(err) => error!("Couldn't notify 0MQ thread: {}", err) ); }