Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop all messages in bounded channel when destroying the last receiver #108164

Merged
merged 4 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 108 additions & 25 deletions library/std/src/sync/mpmc/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::utils::{Backoff, CachePadded};
use super::waker::SyncWaker;

use crate::cell::UnsafeCell;
use crate::mem::MaybeUninit;
use crate::mem::{self, MaybeUninit};
use crate::ptr;
use crate::sync::atomic::{self, AtomicUsize, Ordering};
use crate::time::Instant;
Expand All @@ -25,7 +25,8 @@ struct Slot<T> {
/// The current stamp.
stamp: AtomicUsize,

/// The message in this slot.
/// The message in this slot. Either read out in `read` or dropped through
/// `discard_all_messages`.
msg: UnsafeCell<MaybeUninit<T>>,
}

Expand Down Expand Up @@ -439,21 +440,123 @@ impl<T> Channel<T> {
Some(self.cap)
}

/// Disconnects the channel and wakes up all blocked senders and receivers.
/// Disconnects senders and wakes up all blocked receivers.
///
/// Returns `true` if this call disconnected the channel.
pub(crate) fn disconnect(&self) -> bool {
pub(crate) fn disconnect_senders(&self) -> bool {
let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);

if tail & self.mark_bit == 0 {
self.senders.disconnect();
self.receivers.disconnect();
true
} else {
false
}
}

/// Disconnects receivers and wakes up all blocked senders.
///
/// Returns `true` if this call disconnected the channel.
///
/// # Safety
/// May only be called once upon dropping the last receiver. The
/// destruction of all other receivers must have been observed with acquire
/// ordering or stronger.
pub(crate) unsafe fn disconnect_receivers(&self) -> bool {
let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
let disconnected = if tail & self.mark_bit == 0 {
self.senders.disconnect();
true
} else {
false
};

self.discard_all_messages(tail);
disconnected
}

/// Discards all messages.
///
/// `tail` should be the current (and therefore last) value of `tail`.
///
/// # Safety
/// This method must only be called when dropping the last receiver. The
/// destruction of all other receivers must have been observed with acquire
/// ordering or stronger.
unsafe fn discard_all_messages(&self, tail: usize) {
debug_assert!(self.is_disconnected());

/// Use a helper struct with a custom `Drop` to ensure all messages are
/// dropped, even if a destructor panicks.
struct DiscardState<'a, T> {
channel: &'a Channel<T>,
head: usize,
tail: usize,
backoff: Backoff,
}

impl<'a, T> DiscardState<'a, T> {
fn discard(&mut self) {
loop {
// Deconstruct the head.
let index = self.head & (self.channel.mark_bit - 1);
let lap = self.head & !(self.channel.one_lap - 1);

// Inspect the corresponding slot.
debug_assert!(index < self.channel.buffer.len());
let slot = unsafe { self.channel.buffer.get_unchecked(index) };
let stamp = slot.stamp.load(Ordering::Acquire);

// If the stamp is ahead of the head by 1, we may drop the message.
if self.head + 1 == stamp {
self.head = if index + 1 < self.channel.cap {
// Same lap, incremented index.
// Set to `{ lap: lap, mark: 0, index: index + 1 }`.
self.head + 1
} else {
// One lap forward, index wraps around to zero.
// Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
lap.wrapping_add(self.channel.one_lap)
};

// We updated the head, so even if this descrutor panics,
// we will not attempt to destroy the slot again.
unsafe {
(*slot.msg.get()).assume_init_drop();
}
// If the tail equals the head, that means the channel is empty.
} else if self.tail == self.head {
return;
// Otherwise, a sender is about to write into the slot, so we need
// to wait for it to update the stamp.
} else {
self.backoff.spin_heavy();
}
}
}
}

impl<'a, T> Drop for DiscardState<'a, T> {
fn drop(&mut self) {
self.discard();
joboet marked this conversation as resolved.
Show resolved Hide resolved
}
}

let mut state = DiscardState {
channel: self,
// Only receivers modify `head`, so since we are the last one,
// this value will not change and will not be observed (since
// no new messages can be sent after disconnection).
head: self.head.load(Ordering::Relaxed),
tail: tail & !self.mark_bit,
backoff: Backoff::new(),
};
state.discard();
// This point is only reached if no destructor panics, so all messages
// have already been dropped.
mem::forget(state);
}

/// Returns `true` if the channel is disconnected.
pub(crate) fn is_disconnected(&self) -> bool {
self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
Expand Down Expand Up @@ -483,23 +586,3 @@ impl<T> Channel<T> {
head.wrapping_add(self.one_lap) == tail & !self.mark_bit
}
}

impl<T> Drop for Channel<T> {
fn drop(&mut self) {
// Get the index of the head.
let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);

// Loop over all slots that hold a message and drop them.
for i in 0..self.len() {
// Compute the index of the next slot holding a message.
let index = if hix + i < self.cap { hix + i } else { hix + i - self.cap };

unsafe {
debug_assert!(index < self.buffer.len());
let slot = self.buffer.get_unchecked_mut(index);
let msg = &mut *slot.msg.get();
msg.as_mut_ptr().drop_in_place();
}
}
}
}
4 changes: 2 additions & 2 deletions library/std/src/sync/mpmc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ impl<T> Drop for Sender<T> {
fn drop(&mut self) {
unsafe {
match &self.flavor {
SenderFlavor::Array(chan) => chan.release(|c| c.disconnect()),
SenderFlavor::Array(chan) => chan.release(|c| c.disconnect_senders()),
SenderFlavor::List(chan) => chan.release(|c| c.disconnect_senders()),
SenderFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
}
Expand Down Expand Up @@ -403,7 +403,7 @@ impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
unsafe {
match &self.flavor {
ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect()),
ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect_receivers()),
ReceiverFlavor::List(chan) => chan.release(|c| c.disconnect_receivers()),
ReceiverFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
}
Expand Down
13 changes: 13 additions & 0 deletions library/std/src/sync/mpsc/sync_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::*;
use crate::env;
use crate::rc::Rc;
use crate::sync::mpmc::SendTimeoutError;
use crate::thread;
use crate::time::Duration;
Expand Down Expand Up @@ -656,3 +657,15 @@ fn issue_15761() {
repro()
}
}

#[test]
fn drop_unreceived() {
let (tx, rx) = sync_channel::<Rc<()>>(1);
let msg = Rc::new(());
let weak = Rc::downgrade(&msg);
assert!(tx.send(msg).is_ok());
drop(rx);
// Messages should be dropped immediately when the last receiver is destroyed.
assert!(weak.upgrade().is_none());
drop(tx);
}