Skip to content

Commit

Permalink
Added Receiver support to broadcast channels
Browse files Browse the repository at this point in the history
I originally thought the need to keep everything lock-step was too much
of a burden, but I realized by removing the usage of mpsc, I could
support a Receiver interface by using send_async in an on_receive_async
callback. While this adds overhead if the receiver is then turned into a
callback, that's not the intended workflow -- and it still works just
fine, just has a little more overhead.
  • Loading branch information
ecton committed Jan 22, 2025
1 parent 5741d4e commit af1c026
Showing 1 changed file with 50 additions and 10 deletions.
60 changes: 50 additions & 10 deletions src/reactive/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ use std::fmt::{self, Debug};
use std::future::Future;
use std::ops::ControlFlow;
use std::pin::Pin;
use std::sync::{mpsc, Arc};
use std::sync::Arc;
use std::task::{ready, Context, Poll, Waker};

use builder::Builder;
Expand Down Expand Up @@ -412,8 +412,8 @@ impl<T> Debug for SingleCallback<T> {

enum BroadcastCallback<T> {
Blocking {
sender: mpsc::SyncSender<(T, Waker)>,
result: mpsc::Receiver<()>,
sender: Sender<(T, Autowaker)>,
result: Receiver<()>,
},
NonBlocking(Box<dyn AnyChannelCallback<T>>),
}
Expand All @@ -425,10 +425,10 @@ impl<T> BroadcastCallback<T> {
where
T: Send + 'static,
{
let (value_sender, value_receiver) = mpsc::sync_channel::<(T, Waker)>(1);
let (result_sender, result_receiver) = mpsc::sync_channel(1);
let (value_sender, value_receiver) = bounded::<(T, Autowaker)>(1);
let (result_sender, result_receiver) = bounded(1);
std::thread::spawn(move || {
while let Ok((value, waker)) = value_receiver.recv() {
while let Some((value, waker)) = value_receiver.receive() {
if let Ok(()) = cb(value) {
if result_sender.send(()).is_err() {
return;
Expand Down Expand Up @@ -587,12 +587,12 @@ where
else {
unreachable!("valid state");
};
match result.try_recv() {
match result.try_receive() {
Ok(()) => {
self.next_recipient += 1;
}
Err(mpsc::TryRecvError::Empty) => return Poll::Pending,
Err(mpsc::TryRecvError::Disconnected) => {
Err(TryReceiveError::Empty) => return Poll::Pending,
Err(TryReceiveError::Disconnected) => {
data.behavior.0.remove(self.next_recipient);
}
}
Expand Down Expand Up @@ -620,7 +620,7 @@ where
BroadcastCallback::Blocking { sender, .. } => {
if let Ok(()) = sender.send((
this.value.next().expect("enough value clones"),
cx.waker().clone(),
Autowaker(Some(cx.waker().clone())),
)) {
this.current_is_blocking = true;
drop(data_mutex);
Expand Down Expand Up @@ -897,6 +897,25 @@ where
self.data.force_send_inner(value, channel_id(&self.data))
}

/// Creates a new receiver for this channel.
///
/// All receivers and callbacks must receive each value before the next
/// value is able to be received.
#[must_use]
pub fn create_receiver(&self) -> Receiver<T> {
let (sender, receiver) = bounded(1);
self.on_receive_async_try(move |value| {
let sender = sender.clone();
async move {
sender
.send_async(value)
.await
.map_err(|_| CallbackDisconnected)
}
});
receiver
}

/// Invokes `on_receive` each time a value is sent to this channel.
///
/// This function assumes `on_receive` may block while waiting on another
Expand Down Expand Up @@ -1464,6 +1483,27 @@ pub enum TryReceiveError {
Disconnected,
}

struct Autowaker(Option<Waker>);

impl Autowaker {
fn wake_by_ref(&mut self) {
let Some(waker) = self.0.take() else {
return;
};
waker.wake();
}

fn wake(mut self) {
self.wake_by_ref();
}
}

impl Drop for Autowaker {
fn drop(&mut self) {
self.wake_by_ref();
}
}

#[test]
fn channel_basics() {
let (result_sender, result_receiver) = unbounded();
Expand Down

0 comments on commit af1c026

Please sign in to comment.