Skip to content

Commit

Permalink
Fix memory leak if tasks contain wakers
Browse files Browse the repository at this point in the history
Closes #30
  • Loading branch information
loyd committed Aug 3, 2024
1 parent 53137e3 commit 895d55a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 14 deletions.
46 changes: 32 additions & 14 deletions src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub(crate) struct Header {

impl Header {
/// Construct a new waker.
pub(crate) fn new(shared: Arc<Shared>, index: usize) -> Self {
fn new(shared: Arc<Shared>, index: usize) -> Self {
Self {
shared,
index,
Expand Down Expand Up @@ -200,15 +200,9 @@ impl<T> Storage<T> {
None => return false,
};

// SAFETY: We have mutable access to the given entry, but we are careful
// not to dereference the header mutably, since that might be shared.
unsafe {
let value = match *ptr::addr_of_mut!((*task.as_ptr()).entry) {
ref mut value @ Entry::Some(..) => value,
_ => return false,
};

*value = Entry::Vacant(self.next);
// SAFETY: The `task` pointer is valid, since we got it from the slab.
if !unsafe { make_slot_vacant(task, self.next) } {
return false;
}

self.len -= 1;
Expand All @@ -221,11 +215,18 @@ impl<T> Storage<T> {
// SAFETY: We're just decrementing the reference count of each entry
// before dropping the storage of the slab.
unsafe {
for &entry in &self.tasks {
if entry.as_ref().header.decrement_ref() {
for &task in &self.tasks {
// We must drop a task's entry _before_ decrementing the reference counter
// because the task can be accessed by wakers in parallel now.
//
// Also, we violate the linked list of vacant slots by passing `0` here
// because the whole `tasks` vector will be cleared below anyway.
make_slot_vacant(task, 0);

if task.as_ref().header.decrement_ref() {
// SAFETY: We're the only ones holding a reference to the
// entry, so it's safe to drop it.
_ = Box::from_raw(entry.as_ptr());
// task, so it's safe to drop it.
_ = Box::from_raw(task.as_ptr());
}
}

Expand Down Expand Up @@ -263,6 +264,23 @@ impl<T> Storage<T> {
}
}

/// Marks the task's slot as vacant, chaining it to `next` in the free list.
///
/// Returns `true` if the slot held an entry that was vacated, `false` if it
/// was already vacant (or otherwise not occupied).
///
/// # Safety
/// The `task` pointer must point to a valid entry.
unsafe fn make_slot_vacant<T>(task: NonNull<Task<T>>, next: usize) -> bool {
    // SAFETY: The caller guarantees the entry is valid. We obtain the field
    // pointer via `addr_of_mut!` so we never create a mutable reference to the
    // whole task — the header may be aliased by wakers on other threads.
    let entry = unsafe { &mut *ptr::addr_of_mut!((*task.as_ptr()).entry) };

    match entry {
        // Occupied: replace the value with a vacant marker, dropping it.
        Entry::Some(..) => {
            *entry = Entry::Vacant(next);
            true
        }
        // Nothing to vacate.
        _ => false,
    }
}

impl<T> Default for Storage<T> {
fn default() -> Self {
Self::new()
Expand Down
49 changes: 49 additions & 0 deletions tests/stream_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#![cfg(feature = "futures-rs")]

use std::{
pin::Pin,
sync::{atomic, Arc},
task,
};

use tokio_stream::iter;
use unicycle::StreamsUnordered;

Expand All @@ -19,3 +25,46 @@ async fn test_unicycle_streams() {

assert_eq!(vec![5, 1, 6, 2, 7, 3, 8, 4], received);
}

// Regression test for the memory leak fixed in this commit: a task that
// clones and stores its waker keeps a reference into its own slot, so the
// slab must vacate the entry before decrementing refcounts on drop.
// See #30 for details.
#[tokio::test]
async fn test_drop_with_stored_waker() {
// A stream that never yields: on every poll it clones and stores the waker
// (mimicking a real pending stream), and flips `dropped` when destroyed so
// the test can observe whether its destructor actually ran.
struct Testee {
// Holds the most recently seen waker alive past the poll — this stored
// waker is what previously caused the task to leak.
waker: Option<task::Waker>,
// Shared flag set by `Drop` so the test can assert destruction.
dropped: Arc<atomic::AtomicBool>,
}

impl futures::Stream for Testee {
type Item = u32;

fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Option<u32>> {
println!("testee polled");
// SAFETY: we do not move the pinned value; we only write the `waker`
// field in place.
unsafe { self.get_unchecked_mut() }.waker = Some(cx.waker().clone());
task::Poll::Pending
}
}

impl Drop for Testee {
fn drop(&mut self) {
println!("testee dropped");
self.dropped.store(true, atomic::Ordering::SeqCst);
}
}

let mut streams = StreamsUnordered::new();

let dropped = Arc::new(atomic::AtomicBool::new(false));
streams.push(Testee {
waker: None,
dropped: dropped.clone(),
});

// Poll once so the testee captures a waker. It stays `Pending`, hence
// `poll_immediate` resolves to `None`.
{
let fut = streams.next();
let res = futures::future::poll_immediate(fut).await;
assert!(res.is_none());
}

// Dropping the set must run the testee's destructor even though its stored
// waker still references the task slot — this assertion failed before the fix.
drop(streams);
assert!(dropped.load(atomic::Ordering::SeqCst));
}

0 comments on commit 895d55a

Please sign in to comment.