diff --git a/client/utils/src/mpsc.rs b/client/utils/src/mpsc.rs index 3f783b10060bd..7e06bd203b010 100644 --- a/client/utils/src/mpsc.rs +++ b/client/utils/src/mpsc.rs @@ -141,6 +141,7 @@ impl TracingUnboundedReceiver { impl Drop for TracingUnboundedReceiver { fn drop(&mut self) { + // Close the channel to prevent any further messages to be sent into the channel self.close(); // the number of messages about to be dropped let count = self.inner.len(); @@ -150,6 +151,10 @@ impl Drop for TracingUnboundedReceiver { .with_label_values(&[self.name, "dropped"]) .inc_by(count.saturated_into()); } + // Drain all the pending messages in the channel since they can never be accessed, + // this can be removed once https://github.com/smol-rs/async-channel/issues/23 is + // resolved + while let Ok(_) = self.inner.try_recv() {} } } @@ -177,3 +182,22 @@ impl FusedStream for TracingUnboundedReceiver { self.inner.is_terminated() } } + +#[cfg(test)] +mod tests { + use super::tracing_unbounded; + use async_channel::{self, RecvError, TryRecvError}; + + #[test] + fn test_tracing_unbounded_receiver_drop() { + let (tracing_unbounded_sender, tracing_unbounded_receiver) = + tracing_unbounded("test-receiver-drop", 10); + let (tx, rx) = async_channel::unbounded::(); + + tracing_unbounded_sender.unbounded_send(tx).unwrap(); + drop(tracing_unbounded_receiver); + + assert_eq!(rx.try_recv(), Err(TryRecvError::Closed)); + assert_eq!(rx.recv_blocking(), Err(RecvError)); + } +}