Skip to content

Commit

Permalink
UnorderedReceiver can be closed now
Browse files Browse the repository at this point in the history
This is required to properly implement resource cleanup and install debugging infrastructure to Gateway.

One issue we noticed while working on it, is that it is impossible to distinguish between messages awaiting receive and streams that have been fully exhausted.

After this change, observers can check `is_closed` status and exclude this stream.

In the future, when resource cleanup will be a thing, such receivers can be safely removed from the channel map.
  • Loading branch information
akoshelev committed Oct 16, 2023
1 parent c75dfd7 commit f63cbd3
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 30 deletions.
123 changes: 97 additions & 26 deletions src/helpers/buffers/unordered_receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ use std::{
mem::take,
num::NonZeroUsize,
pin::Pin,
task::{Context, Poll},
task::{ready, Context, Poll},
};

use futures::{task::Waker, Future, Stream};
use futures::{stream::Fuse, task::Waker, Future, Stream};
use futures_util::StreamExt;
use generic_array::GenericArray;
use typenum::Unsigned;

Expand Down Expand Up @@ -105,6 +106,10 @@ impl Spare {
};
Some(m)
}
/// Returns `true` if there are no bytes currently awaiting a read.
fn is_empty(&self) -> bool {
self.offset == self.buf.len()
}
}

pub struct OperatingState<S, C>
Expand All @@ -113,7 +118,7 @@ where
C: AsRef<[u8]>,
{
/// The stream we're reading from.
stream: Pin<Box<S>>,
stream: Pin<Box<Fuse<S>>>,
/// The absolute index of the next value that will be received.
next: usize,
/// The underlying stream can provide chunks of data larger than a single
Expand Down Expand Up @@ -204,29 +209,54 @@ where
/// Poll for the next record. This should only be invoked when
/// the future for the next message is polled.
fn poll_next<M: Message>(&mut self, cx: &mut Context<'_>) -> Poll<Result<M, Error>> {
if let Some(m) = self.spare.read() {
self.wake_next();
return Poll::Ready(Ok(m));
}

loop {
match self.stream.as_mut().poll_next(cx) {
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Some(b)) => {
if let Some(m) = self.spare.extend(b.as_ref()) {
self.wake_next();
return Poll::Ready(Ok(m));
// If spare has enough data for us, poll it first
// otherwise, poll the underlying stream until it returns pending or it provides enough
// data to return a value.

let message = self.spare.read::<M>();
let next = match message {
Some(m) => {
// this check exists to make sure the inner stream is eventually moved to
// the closed state. We don't want to poll it too often, but we also need to know
// when it is done and `UnorderedReceiver` can be dropped.
if self.spare.is_empty() {
// we don't want to be woken up here, control loop is driven by the client.
// They decide when they want the next message and must issue a `poll` for it.

// TODO: https://github.com/rust-lang/rust/issues/98286
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
match self.stream.as_mut().poll_next(&mut cx) {
Poll::Ready(Some(bytes)) => {
// Spare is empty because of the check above.
self.spare.replace(bytes.as_ref());
}
Poll::Ready(None) | Poll::Pending => {}
}
}
Poll::Ready(None) => {
return Poll::Ready(Err(Error::EndOfStream {
record_id: RecordId::from(self.next),
}));
}

Poll::Ready(Ok(m))
}
None => loop {
match ready!(self.stream.as_mut().poll_next(cx)) {
Some(bytes) => {
if let Some(m) = self.spare.extend(bytes.as_ref()) {
break Poll::Ready(Ok(m));
}
}
None => {
break Poll::Ready(Err(Error::EndOfStream {
record_id: RecordId::from(self.next),
}));
}
}
},
};

if next.is_ready() {
self.wake_next();
}

next
}
}

Expand Down Expand Up @@ -254,13 +284,13 @@ where
/// # Panics
///
/// The `capacity` needs to be at least 2.
pub fn new(stream: Pin<Box<S>>, capacity: NonZeroUsize) -> Self {
pub fn new(stream: S, capacity: NonZeroUsize) -> Self {
// We use `c/2` as a divisor, so `c == 1` would be bad.
assert!(capacity.get() > 1, "a capacity of 1 is too small");
let wakers = vec![None; capacity.get()];
Self {
inner: Arc::new(Mutex::new(OperatingState {
stream,
stream: Box::pin(stream.fuse()),
next: 0,
spare: Spare::default(),
wakers,
Expand All @@ -284,6 +314,19 @@ where
_marker: PhantomData,
}
}

/// Returns `true` when this receiver is closed. Closed means the underlying stream is done and
/// there is no more data inside receiver's buffers.
///
/// Calling `poll_next` on closed receivers will result in [`EOS`] error.
///
/// [`EOS`]: crate::helpers::Error::EndOfStream
pub fn is_closed(&self) -> bool {
// If this function is ever called on the hot path, consider caching closed status.
// Closed streams cannot move back to open.
let inner = self.inner.lock().unwrap();
inner.stream.is_done() && inner.spare.is_empty()
}
}

impl<S, C> Clone for UnorderedReceiver<S, C>
Expand Down Expand Up @@ -317,7 +360,7 @@ mod test {

use crate::{
ff::{Field, Fp31, Fp32BitPrime, Serializable},
helpers::buffers::unordered_receiver::UnorderedReceiver,
helpers::{buffers::unordered_receiver::UnorderedReceiver, Error::EndOfStream},
};

fn receiver<I, T>(it: I) -> UnorderedReceiver<impl Stream<Item = T>, T>
Expand All @@ -328,7 +371,7 @@ mod test {
{
// Use a small capacity so that we can overflow it easily.
let capacity = NonZeroUsize::new(3).unwrap();
UnorderedReceiver::new(Box::pin(iter(it)), capacity)
UnorderedReceiver::new(iter(it), capacity)
}

#[cfg(not(feature = "shuttle"))]
Expand Down Expand Up @@ -510,4 +553,32 @@ mod test {
}
});
}

#[test]
fn close() {
const DATA: &[u8] = &[1u8, 2, 3];
run(|| async move {
let recv = receiver([DATA]);
for i in 0..DATA.len() {
assert!(!recv.is_closed());
let _: Fp31 = recv.recv(i).await.unwrap();
}

assert!(recv.is_closed());
});
}

#[test]
fn end_of_stream() {
const DATA: &[u8] = &[1u8, 2, 3, 4, 5];
run(|| async move {
let recv = receiver([DATA]);
let _: Fp32BitPrime = recv.recv(0u8).await.unwrap();

assert!(matches!(
recv.recv::<Fp32BitPrime, _>(1u8).await,
Err(EndOfStream { .. })
));
});
}
}
6 changes: 2 additions & 4 deletions src/helpers/gateway/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ impl<T: Transport> RoleResolvingTransport<T> {
);

UnorderedReceiver::new(
Box::pin(
self.inner
.receive(peer, (self.query_id, channel_id.gate.clone())),
),
self.inner
.receive(peer, (self.query_id, channel_id.gate.clone())),
self.config.active_work(),
)
}
Expand Down

0 comments on commit f63cbd3

Please sign in to comment.