Add support for alternative RegistrationFile (#576)
Closes #563 
Joint work with Lionel.

This PR implements the alternative `RegistrationFile` approach outlined in JEP 66, which lets the client and server perform a "handshake" on startup. In particular, the server is in charge of picking the ports and binds to each one as it picks it, which avoids the race in which a port chosen ahead of time is claimed by another process before the kernel binds to it.

If the `--connection_file` argument provided to ark can parse into this structure:

```rust
pub struct RegistrationFile {
    /// The transport type to use for ZeroMQ; generally "tcp"
    pub transport: String,

    /// The signature scheme to use for messages; generally "hmac-sha256"
    pub signature_scheme: String,

    /// The IP address to bind to
    pub ip: String,

    /// The HMAC-256 signing key, or an empty string for an unauthenticated
    /// connection
    pub key: String,

    /// ZeroMQ port: Registration messages (handshake)
    pub registration_port: u16,
}
```

Then we assume the handshake method of connecting is in use. Otherwise, we parse the file into the typical `ConnectionFile` structure and assume the Client picked the ports.
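The choice between the two startup paths can be sketched with the standard library alone. This is only an illustration: `StartupMode` and `startup_mode()` are hypothetical names, and the `HashMap` stands in for the numeric fields of an already decoded file, whereas ark itself deserializes into the `RegistrationFile` and `ConnectionFile` structs.

```rust
use std::collections::HashMap;

/// Which startup path the kernel should take (hypothetical type).
#[derive(Debug, PartialEq)]
enum StartupMode {
    /// Handshake: the kernel picks the ports and reports them on this port.
    Handshake { registration_port: u16 },
    /// Classic: the client already picked all five channel ports.
    ClientPickedPorts,
}

/// Hypothetical helper. `fields` stands in for the numeric fields of the
/// decoded `--connection_file`; this is an assumption of the sketch, not
/// how ark represents the file.
fn startup_mode(fields: &HashMap<&str, u16>) -> Option<StartupMode> {
    // A registration file is recognized by its `registration_port` field.
    if let Some(port) = fields.get("registration_port") {
        return Some(StartupMode::Handshake {
            registration_port: *port,
        });
    }

    // Otherwise every channel port of a classic connection file must be present.
    const PORTS: [&str; 5] = [
        "control_port",
        "shell_port",
        "stdin_port",
        "iopub_port",
        "hb_port",
    ];
    if PORTS.iter().all(|name| fields.contains_key(name)) {
        return Some(StartupMode::ClientPickedPorts);
    }

    // Neither shape matched.
    None
}
```

Checking for the registration shape first mirrors the order described above: the handshake path is tried, and the classic connection file is the fallback.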

We expect that the Client _binds_ to a `zmq::REP` socket on `registration_port`. Ark, as the Server, will then _connect_ to this `registration_port` as a `zmq::REQ` socket.
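The direction of the connection (the Client binds, the Server connects) can be illustrated with plain TCP from the standard library. This sketch does not use ZeroMQ: `serve_one()` and `request_once()` are hypothetical helpers that imitate a single request/reply round trip, and the newline-terminated text messages are an assumption made only for the example.

```rust
use std::io::{BufRead, BufReader, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;

/// Client role (stand-in for the `zmq::REP` side): accept one connection on
/// an already bound listener, read one request line, answer with a receipt.
fn serve_one(listener: TcpListener) -> thread::JoinHandle<std::io::Result<String>> {
    thread::spawn(move || -> std::io::Result<String> {
        let (mut stream, _peer) = listener.accept()?;
        let mut reader = BufReader::new(stream.try_clone()?);
        let mut request = String::new();
        reader.read_line(&mut request)?;
        // The receipt only confirms that the request arrived.
        stream.write_all(b"ok\n")?;
        Ok(request.trim_end().to_string())
    })
}

/// Server role (stand-in for the `zmq::REQ` side): connect to the client's
/// port, send one request, then block until one reply line arrives.
fn request_once(port: u16, request: &str) -> std::io::Result<String> {
    let mut stream = TcpStream::connect(("127.0.0.1", port))?;
    stream.write_all(request.as_bytes())?;
    stream.write_all(b"\n")?;
    let mut reply = String::new();
    BufReader::new(stream).read_line(&mut reply)?;
    Ok(reply.trim_end().to_string())
}
```

Because the Client binds before the Server starts, the only port the Client has to choose up front is the registration port; every other port is reported through the request.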

Ark will pick ports, bind to them, and send this message over the registration socket:

```rust
pub struct HandshakeRequest {
    /// ZeroMQ port: Control channel (kernel interrupts)
    pub control_port: u16,

    /// ZeroMQ port: Shell channel (execution, completion)
    pub shell_port: u16,

    /// ZeroMQ port: Standard input channel (prompts)
    pub stdin_port: u16,

    /// ZeroMQ port: IOPub channel (broadcasts input/output)
    pub iopub_port: u16,

    /// ZeroMQ port: Heartbeat messages (echo)
    pub hb_port: u16,
}
```
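The "pick and bind immediately" step can be sketched with the standard library by binding to port `0`, which asks the operating system for a free port, and then reading the assigned port back from the bound socket. `PickedPorts`, `bind_any()`, and `bind_all()` are hypothetical names; ark binds ZeroMQ sockets rather than `TcpListener`s, so this shows only the idea.

```rust
use std::net::TcpListener;

/// Ports to report to the client, mirroring the fields of `HandshakeRequest`
/// (hypothetical stand-in type for this sketch).
#[derive(Debug)]
struct PickedPorts {
    control_port: u16,
    shell_port: u16,
    stdin_port: u16,
    iopub_port: u16,
    hb_port: u16,
}

/// Bind to port 0 so the OS chooses a free port. The listener is returned
/// still bound, so nobody else can claim the port in the meantime.
fn bind_any(ip: &str) -> std::io::Result<(TcpListener, u16)> {
    let listener = TcpListener::bind((ip, 0))?;
    let port = listener.local_addr()?.port();
    Ok((listener, port))
}

/// Bind all five channels, keeping every listener alive alongside its port.
fn bind_all(ip: &str) -> std::io::Result<(Vec<TcpListener>, PickedPorts)> {
    let mut listeners = Vec::with_capacity(5);
    let mut ports = [0u16; 5];
    for slot in ports.iter_mut() {
        let (listener, port) = bind_any(ip)?;
        *slot = port;
        listeners.push(listener);
    }
    let [control_port, shell_port, stdin_port, iopub_port, hb_port] = ports;
    let picked = PickedPorts {
        control_port,
        shell_port,
        stdin_port,
        iopub_port,
        hb_port,
    };
    Ok((listeners, picked))
}
```

Since each socket stays bound from the moment its port is known, there is no window between choosing a port and using it, which is the gap a separate port-picking step leaves open.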

Ark will then _immediately_ block, waiting for this `HandshakeReply`:

```rust
pub struct HandshakeReply {
    /// The execution status ("ok" or "error")
    pub status: Status,
}
```

This is just a receipt from the Client that confirms that it received the socket information.

If ark does not receive this reply within a few seconds, it will shut itself down.
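The blocking wait with a deadline can be sketched with a standard-library channel standing in for the registration socket. The names `HandshakeOutcome` and `await_reply()` are hypothetical, the local `Status` enum only mimics the one in `HandshakeReply`, and treating an `"error"` status as a rejection is an assumption of this sketch; the timeout is a parameter because the exact value is not stated here.

```rust
use std::sync::mpsc::{self, RecvTimeoutError};
use std::time::Duration;

/// Stand-in for the `status` field of `HandshakeReply`.
#[derive(Debug, PartialEq)]
enum Status {
    Ok,
    Error,
}

/// What the kernel learned from waiting on the registration channel
/// (hypothetical type for this sketch).
#[derive(Debug, PartialEq)]
enum HandshakeOutcome {
    /// The client confirmed receipt; startup can proceed.
    Confirmed,
    /// The client replied with an error status (assumed handling).
    Rejected,
    /// No reply arrived in time; the kernel should shut down.
    TimedOut,
    /// The other side went away before replying.
    Disconnected,
}

/// Block for at most `timeout` waiting for the client's receipt.
fn await_reply(replies: &mpsc::Receiver<Status>, timeout: Duration) -> HandshakeOutcome {
    match replies.recv_timeout(timeout) {
        Ok(Status::Ok) => HandshakeOutcome::Confirmed,
        Ok(Status::Error) => HandshakeOutcome::Rejected,
        Err(RecvTimeoutError::Timeout) => HandshakeOutcome::TimedOut,
        Err(RecvTimeoutError::Disconnected) => HandshakeOutcome::Disconnected,
    }
}
```

Bounding the wait matters because a client that never answers would otherwise leave a kernel process holding five bound ports indefinitely.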

Ark disconnects from the registration socket after receiving the `HandshakeReply`, and the kernel proceeds to start up.

---

* Draft registration file

* Draft registration handshake

* Expose `kernel::read_connection()` and make `kernel::connect()` take connection files again

* Ensure ip/transport/signature are always aligned

* Even better practice to ensure `endpoint`s are right

* Must run `start_kernel()` on a separate thread to be able to perform the handshake

* Call `kernel::connect()` from its own thread in amalthea client tests too

* Add a `TODO!` about the currently required sleep

* Add a `TODO!` about echo tests

* Add link to JEP 66

* Downgrade TODO

* Remove portpicker dependency

* Fix typo in `HandshakeRequest`

Co-authored-by: Davis Vaughan <davis@rstudio.com>

* Return frontend directly

* Comment on timeout

---------

Co-authored-by: Lionel Henry <lionel@posit.co>
Co-authored-by: Davis Vaughan <davis@posit.co>
3 people authored Oct 10, 2024
1 parent 3c8b072 commit 8194c2a
Showing 17 changed files with 1,018 additions and 678 deletions.
10 changes: 0 additions & 10 deletions Cargo.lock


1 change: 0 additions & 1 deletion crates/amalthea/Cargo.toml
@@ -19,7 +19,6 @@ hex = "0.4.3"
hmac = "0.12.1"
log = "0.4.17"
nix = "0.26.2"
portpicker = "0.1.1"
rand = "0.8.5"
serde = { version = "1.0.154", features = ["derive"] }
serde_json = { version = "1.0.94", features = ["preserve_order"]}
12 changes: 12 additions & 0 deletions crates/amalthea/src/error.rs
@@ -44,6 +44,7 @@ pub enum Error {
InvalidCommMessage(String, String, String),
InvalidInputRequest(String),
InvalidConsoleInput(String),
Anyhow(anyhow::Error),
}

impl std::error::Error for Error {}
@@ -197,6 +198,9 @@ impl fmt::Display for Error {
Error::InvalidConsoleInput(message) => {
write!(f, "{message}")
},
Error::Anyhow(err) => {
write!(f, "{err:?}")
},
}
}
}
@@ -206,3 +210,11 @@ impl<T: std::fmt::Debug> From<SendError<T>> for Error {
Self::SendError(format!("Could not send {:?} to channel.", err.0))
}
}

#[macro_export]
macro_rules! anyhow {
($($rest: expr),*) => {{
let message = anyhow::anyhow!($($rest, )*);
crate::error::Error::Anyhow(message)
}}
}
219 changes: 138 additions & 81 deletions crates/amalthea/src/fixtures/dummy_frontend.rs
@@ -6,13 +6,16 @@
*/

use assert_matches::assert_matches;
use rand::Rng;
use serde_json::Value;

use crate::connection_file::ConnectionFile;
use crate::registration_file::RegistrationFile;
use crate::session::Session;
use crate::socket::socket::Socket;
use crate::wire::execute_input::ExecuteInput;
use crate::wire::execute_request::ExecuteRequest;
use crate::wire::handshake_reply::HandshakeReply;
use crate::wire::input_reply::InputReply;
use crate::wire::jupyter_message::JupyterMessage;
use crate::wire::jupyter_message::Message;
@@ -22,126 +25,190 @@ use crate::wire::status::ExecutionState;
use crate::wire::stream::Stream;
use crate::wire::wire_message::WireMessage;

pub struct DummyConnection {
pub registration_socket: Socket,
pub ctx: zmq::Context,
pub session: Session,
pub key: String,
pub ip: String,
pub transport: String,
pub signature_scheme: String,
}

pub struct DummyFrontend {
pub _control_socket: Socket,
pub shell_socket: Socket,
pub iopub_socket: Socket,
pub stdin_socket: Socket,
pub heartbeat_socket: Socket,
session: Session,
key: String,
control_port: u16,
shell_port: u16,
iopub_port: u16,
stdin_port: u16,
heartbeat_port: u16,
}

pub struct ExecuteRequestOptions {
pub allow_stdin: bool,
}

impl DummyFrontend {
impl DummyConnection {
pub fn new() -> Self {
use rand::Rng;

// Create a random HMAC key for signing messages.
let key_bytes = rand::thread_rng().gen::<[u8; 16]>();
let key = hex::encode(key_bytes);

// Create a random socket identity for the shell and stdin sockets. Per
// the Jupyter specification, these must share a ZeroMQ identity.
let shell_id = rand::thread_rng().gen::<[u8; 16]>();

// Create a new kernel session from the key
let session = Session::create(key.clone()).unwrap();
let session = Session::create(&key).unwrap();

// Create a zmq context for all sockets we create in this session
let ctx = zmq::Context::new();

let control_port = portpicker::pick_unused_port().unwrap();
let control = Socket::new(
let ip = String::from("127.0.0.1");
let transport = String::from("tcp");
let signature_scheme = String::from("hmac-sha256");

// Bind to a random port using `0`
let registration_socket = Socket::new(
session.clone(),
ctx.clone(),
String::from("Registration"),
zmq::REP,
None,
Self::endpoint_from_parts(&transport, &ip, 0),
)
.unwrap();

Self {
registration_socket,
ctx,
session,
key,
ip,
transport,
signature_scheme,
}
}

/// Gets a connection file for the Amalthea kernel that will connect it to
/// this synthetic frontend. Uses a handshake through a registration
/// file to avoid race conditions related to port binding.
pub fn get_connection_files(&self) -> (ConnectionFile, RegistrationFile) {
let registration_file = RegistrationFile {
ip: self.ip.clone(),
transport: self.transport.clone(),
signature_scheme: self.signature_scheme.clone(),
key: self.key.clone(),
registration_port: crate::kernel::port_from_socket(&self.registration_socket).unwrap(),
};

let connection_file = registration_file.as_connection_file();

(connection_file, registration_file)
}

fn endpoint(&self, port: u16) -> String {
Self::endpoint_from_parts(&self.transport, &self.ip, port)
}

fn endpoint_from_parts(transport: &str, ip: &str, port: u16) -> String {
format!("{transport}://{ip}:{port}")
}
}

impl DummyFrontend {
pub fn from_connection(connection: DummyConnection) -> Self {
// Wait to receive the handshake request so we know what ports to connect on.
// Note that `recv()` times out.
let message = Self::recv(&connection.registration_socket);
let handshake = assert_matches!(message, Message::HandshakeRequest(message) => {
message.content
});

// Immediately send back a handshake reply so the kernel can start up
Self::send(
&connection.registration_socket,
&connection.session,
HandshakeReply { status: Status::Ok },
);

// Create a random socket identity for the shell and stdin sockets. Per
// the Jupyter specification, these must share a ZeroMQ identity.
let shell_id = rand::thread_rng().gen::<[u8; 16]>();

let _control_socket = Socket::new(
connection.session.clone(),
connection.ctx.clone(),
String::from("Control"),
zmq::DEALER,
None,
format!("tcp://127.0.0.1:{}", control_port),
connection.endpoint(handshake.control_port),
)
.unwrap();

let shell_port = portpicker::pick_unused_port().unwrap();
let shell = Socket::new(
session.clone(),
ctx.clone(),
let shell_socket = Socket::new(
connection.session.clone(),
connection.ctx.clone(),
String::from("Shell"),
zmq::DEALER,
Some(&shell_id),
format!("tcp://127.0.0.1:{}", shell_port),
connection.endpoint(handshake.shell_port),
)
.unwrap();

let iopub_port = portpicker::pick_unused_port().unwrap();
let iopub = Socket::new(
session.clone(),
ctx.clone(),
let iopub_socket = Socket::new(
connection.session.clone(),
connection.ctx.clone(),
String::from("IOPub"),
zmq::SUB,
None,
format!("tcp://127.0.0.1:{}", iopub_port),
connection.endpoint(handshake.iopub_port),
)
.unwrap();

let stdin_port = portpicker::pick_unused_port().unwrap();
let stdin = Socket::new(
session.clone(),
ctx.clone(),
// Subscribe to IOPub! Server is the one that sent us this port,
// so it's already connected on its end.
iopub_socket.subscribe().unwrap();

let stdin_socket = Socket::new(
connection.session.clone(),
connection.ctx.clone(),
String::from("Stdin"),
zmq::DEALER,
Some(&shell_id),
format!("tcp://127.0.0.1:{}", stdin_port),
connection.endpoint(handshake.stdin_port),
)
.unwrap();

let heartbeat_port = portpicker::pick_unused_port().unwrap();
let heartbeat = Socket::new(
session.clone(),
ctx.clone(),
let heartbeat_socket = Socket::new(
connection.session.clone(),
connection.ctx.clone(),
String::from("Heartbeat"),
zmq::REQ,
None,
format!("tcp://127.0.0.1:{}", heartbeat_port),
connection.endpoint(handshake.hb_port),
)
.unwrap();

// TODO!: Without this sleep, `IOPub` `Busy` messages sporadically
// don't arrive when running integration tests. I believe this is a result
// of PUB sockets dropping messages while in a "mute" state (i.e. no subscriber
// connected yet). Even though we run `iopub_socket.subscribe()` to subscribe,
// it seems like we can return from this function even before our socket
// has fully subscribed, causing messages to get dropped.
// https://libzmq.readthedocs.io/en/latest/zmq_socket.html
std::thread::sleep(std::time::Duration::from_millis(500));

Self {
session,
key,
control_port,
_control_socket: control,
shell_port,
shell_socket: shell,
iopub_port,
iopub_socket: iopub,
stdin_port,
stdin_socket: stdin,
heartbeat_port,
heartbeat_socket: heartbeat,
_control_socket,
shell_socket,
iopub_socket,
stdin_socket,
heartbeat_socket,
session: connection.session,
}
}

/// Completes initialization of the frontend (usually done after the kernel
/// is ready and connected)
pub fn complete_initialization(&self) {
self.iopub_socket.subscribe().unwrap();
}

/// Sends a Jupyter message on the Shell socket; returns the ID of the newly
/// created message
pub fn send_shell<T: ProtocolMessage>(&self, msg: T) -> String {
let message = JupyterMessage::create(msg, None, &self.session);
let id = message.header.msg_id.clone();
message.send(&self.shell_socket).unwrap();
id
Self::send(&self.shell_socket, &self.session, msg)
}

pub fn send_execute_request(&self, code: &str, options: ExecuteRequestOptions) -> String {
@@ -157,11 +224,17 @@ impl DummyFrontend {

/// Sends a Jupyter message on the Stdin socket
pub fn send_stdin<T: ProtocolMessage>(&self, msg: T) {
let message = JupyterMessage::create(msg, None, &self.session);
message.send(&self.stdin_socket).unwrap();
Self::send(&self.stdin_socket, &self.session, msg);
}

pub fn recv(&self, socket: &Socket) -> Message {
fn send<T: ProtocolMessage>(socket: &Socket, session: &Session, msg: T) -> String {
let message = JupyterMessage::create(msg, None, session);
let id = message.header.msg_id.clone();
message.send(socket).unwrap();
id
}

pub fn recv(socket: &Socket) -> Message {
// It's important to wait with a timeout because the kernel thread might
// have panicked, preventing it from sending the expected message. The
// tests would then hang indefinitely.
@@ -177,17 +250,17 @@ impl DummyFrontend {

/// Receives a Jupyter message from the Shell socket
pub fn recv_shell(&self) -> Message {
self.recv(&self.shell_socket)
Self::recv(&self.shell_socket)
}

/// Receives a Jupyter message from the IOPub socket
pub fn recv_iopub(&self) -> Message {
self.recv(&self.iopub_socket)
Self::recv(&self.iopub_socket)
}

/// Receives a Jupyter message from the Stdin socket
pub fn recv_stdin(&self) -> Message {
self.recv(&self.stdin_socket)
Self::recv(&self.stdin_socket)
}

/// Receive from Shell and assert `ExecuteReply` message.
@@ -334,22 +407,6 @@ impl DummyFrontend {
self.heartbeat_socket.send(msg).unwrap();
}

/// Gets a connection file for the Amalthea kernel that will connect it to
/// this synthetic frontend.
pub fn get_connection_file(&self) -> ConnectionFile {
ConnectionFile {
control_port: self.control_port,
shell_port: self.shell_port,
stdin_port: self.stdin_port,
iopub_port: self.iopub_port,
hb_port: self.heartbeat_port,
transport: String::from("tcp"),
signature_scheme: String::from("hmac-sha256"),
ip: String::from("127.0.0.1"),
key: self.key.clone(),
}
}

/// Asserts that no socket has incoming data
pub fn assert_no_incoming(&mut self) {
let mut has_incoming = false;