From 895d55a9dc8b7fc8345df1f18775b4e41ad5dd7f Mon Sep 17 00:00:00 2001 From: Paul Loyd Date: Sat, 3 Aug 2024 19:34:56 +0400 Subject: [PATCH] Fix memory leak if tasks contain wakers Closes #30 --- src/task.rs | 46 ++++++++++++++++++++++++++++------------- tests/stream_test.rs | 49 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 14 deletions(-) diff --git a/src/task.rs b/src/task.rs index 209643e..f59cd9f 100644 --- a/src/task.rs +++ b/src/task.rs @@ -60,7 +60,7 @@ pub(crate) struct Header { impl Header { /// Construct a new waker. - pub(crate) fn new(shared: Arc, index: usize) -> Self { + fn new(shared: Arc, index: usize) -> Self { Self { shared, index, @@ -200,15 +200,9 @@ impl Storage { None => return false, }; - // SAFETY: We have mutable access to the given entry, but we are careful - // not to dereference the header mutably, since that might be shared. - unsafe { - let value = match *ptr::addr_of_mut!((*task.as_ptr()).entry) { - ref mut value @ Entry::Some(..) => value, - _ => return false, - }; - - *value = Entry::Vacant(self.next); + // SAFETY: The `task` pointer is valid, since we got it from the slab. + if !unsafe { make_slot_vacant(task, self.next) } { + return false; } self.len -= 1; @@ -221,11 +215,18 @@ impl Storage { // SAFETY: We're just decrementing the reference count of each entry // before dropping the storage of the slab. unsafe { - for &entry in &self.tasks { - if entry.as_ref().header.decrement_ref() { + for &task in &self.tasks { + // We must drop a task's entry _before_ decrementing the reference counter + // because the task can be accessed by wakers in parallel now. + // + // Also, we violate the linked list of vacant slots by passing `0` here + // because the whole `tasks` vector will be cleared below anyway. + make_slot_vacant(task, 0); + + if task.as_ref().header.decrement_ref() { // SAFETY: We're the only ones holding a reference to the - // entry, so it's safe to drop it. - _ = Box::from_raw(entry.as_ptr()); + // task, so it's safe to drop it. + _ = Box::from_raw(task.as_ptr()); } } @@ -263,6 +264,23 @@ impl Storage { } } +/// Returns `true` if the entry was removed, `false` otherwise. +/// +/// # Safety +/// The `task` pointer must point to a valid entry. +unsafe fn make_slot_vacant(task: NonNull>, next: usize) -> bool { + // SAFETY: We have mutable access to the given entry, but we are careful + // not to dereference the header mutably, since that might be shared. + let entry = unsafe { &mut *ptr::addr_of_mut!((*task.as_ptr()).entry) }; + + if !matches!(entry, Entry::Some(_)) { + return false; + } + + *entry = Entry::Vacant(next); + true +} + impl Default for Storage { fn default() -> Self { Self::new() diff --git a/tests/stream_test.rs b/tests/stream_test.rs index 377dd17..2b85cae 100644 --- a/tests/stream_test.rs +++ b/tests/stream_test.rs @@ -1,5 +1,11 @@ #![cfg(feature = "futures-rs")] +use std::{ + pin::Pin, + sync::{atomic, Arc}, + task, +}; + use tokio_stream::iter; use unicycle::StreamsUnordered; @@ -19,3 +25,46 @@ async fn test_unicycle_streams() { assert_eq!(vec![5, 1, 6, 2, 7, 3, 8, 4], received); } + +// See #30 for details. +#[tokio::test] +async fn test_drop_with_stored_waker() { + struct Testee { + waker: Option, + dropped: Arc, + } + + impl futures::Stream for Testee { + type Item = u32; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll> { + println!("testee polled"); + unsafe { self.get_unchecked_mut() }.waker = Some(cx.waker().clone()); + task::Poll::Pending + } + } + + impl Drop for Testee { + fn drop(&mut self) { + println!("testee dropped"); + self.dropped.store(true, atomic::Ordering::SeqCst); + } + } + + let mut streams = StreamsUnordered::new(); + + let dropped = Arc::new(atomic::AtomicBool::new(false)); + streams.push(Testee { + waker: None, + dropped: dropped.clone(), + }); + + { + let fut = streams.next(); + let res = futures::future::poll_immediate(fut).await; + assert!(res.is_none()); + } + + drop(streams); + assert!(dropped.load(atomic::Ordering::SeqCst)); +}