From 74e243958260f3a38b06f693f64b72596aad2d06 Mon Sep 17 00:00:00 2001 From: Matthias Einwag Date: Wed, 22 Jan 2020 00:13:56 -0800 Subject: [PATCH] Add task::scope This change adds `task::scope` as a mechanism for supporting structured concurrency as described in #1879. The change adds a `task::scope` function which will forcefully cancel all child tasks when the scope is exited, as well as a `task::scope_with_options` function which allows to override the default cancellation and drop behavior. The `scope` implementations makes use of 2 primitives: - CancellationToken: This allows to signal an arbitrary amount of tasks to cancel - WaitGroup: This allows to wait for outstanding tasks to complete Both primitives are implemented using mechansims and code from futures-intrusive. --- tokio/Cargo.toml | 4 +- tokio/src/task/mod.rs | 3 + tokio/src/task/scope/cancellation_token.rs | 255 +++++++ .../scope/intrusive_double_linked_list.rs | 628 ++++++++++++++++++ tokio/src/task/scope/manual_scope.rs | 513 ++++++++++++++ tokio/src/task/scope/mod.rs | 12 + tokio/src/task/scope/scope.rs | 354 ++++++++++ tokio/src/task/scope/scope_state.rs | 518 +++++++++++++++ tokio/src/task/scope/wait_group.rs | 271 ++++++++ tokio/tests/task_scope.rs | 336 ++++++++++ 10 files changed, 2893 insertions(+), 1 deletion(-) create mode 100644 tokio/src/task/scope/cancellation_token.rs create mode 100644 tokio/src/task/scope/intrusive_double_linked_list.rs create mode 100644 tokio/src/task/scope/manual_scope.rs create mode 100644 tokio/src/task/scope/mod.rs create mode 100644 tokio/src/task/scope/scope.rs create mode 100644 tokio/src/task/scope/scope_state.rs create mode 100644 tokio/src/task/scope/wait_group.rs create mode 100644 tokio/tests/task_scope.rs diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index d701e511654..6b74d14257b 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -45,6 +45,7 @@ full = [ "stream", "sync", "time", + "futures" ] blocking = ["rt-core"] @@ -100,6 +101,7 @@ 
pin-project-lite = "0.1.1" # Everything else is optional... fnv = { version = "1.0.6", optional = true } futures-core = { version = "0.3.0", optional = true } +futures = { version = "0.3.0", optional = true } lazy_static = { version = "1.0.2", optional = true } memchr = { version = "2.2", optional = true } mio = { version = "0.6.20", optional = true } @@ -123,7 +125,7 @@ optional = true [dev-dependencies] tokio-test = { version = "0.2.0" } -futures = { version = "0.3.0", features = ["async-await"] } +futures = { version = "0.3.0", features = ["async-await", "executor"] } proptest = "0.9.4" tempfile = "3.1.0" diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index 073215e6e69..d7aa0a07885 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -246,6 +246,9 @@ cfg_rt_core! { mod raw; use self::raw::RawTask; + mod scope; + pub use scope::{scope, scope_with_options, ScopeOptions, ScopedJoinHandle, ScopeHandle, ScopeCancelBehavior, ScopeDropBehavior}; + mod spawn; pub use spawn::spawn; diff --git a/tokio/src/task/scope/cancellation_token.rs b/tokio/src/task/scope/cancellation_token.rs new file mode 100644 index 00000000000..b9d0e464b57 --- /dev/null +++ b/tokio/src/task/scope/cancellation_token.rs @@ -0,0 +1,255 @@ +//! An asynchronously awaitable event for signalization between tasks + +use super::intrusive_double_linked_list::{LinkedList, ListNode}; +use std::{ + future::Future, + pin::Pin, + sync::Mutex, + task::{Context, Poll, Waker}, +}; + +/// Tracks how the future had interacted with the event +#[derive(PartialEq)] +enum PollState { + /// The task has never interacted with the event. + New, + /// The task was added to the wait queue at the event. + Waiting, + /// The task has been polled to completion. + Done, +} + +/// Tracks the WaitForCancellationFuture waiting state. +/// Access to this struct is synchronized through the mutex in the Event. 
+struct WaitQueueEntry { + /// The task handle of the waiting task + task: Option, + /// Current polling state + state: PollState, +} + +impl WaitQueueEntry { + /// Creates a new WaitQueueEntry + fn new() -> WaitQueueEntry { + WaitQueueEntry { + task: None, + state: PollState::New, + } + } +} + +/// Internal state of the `CancellationToken` pair above +struct CancellationTokenState { + is_cancelled: bool, + waiters: LinkedList, +} + +impl CancellationTokenState { + fn new(is_cancelled: bool) -> CancellationTokenState { + CancellationTokenState { + is_cancelled, + waiters: LinkedList::new(), + } + } + + fn cancel(&mut self) { + if self.is_cancelled != true { + self.is_cancelled = true; + + // Wakeup all waiters + // This happens inside the lock to make cancellation reliable + // If we would access waiters outside of the lock, the pointers + // may no longer be valid. + // Typically this shouldn't be an issue, since waking a task should + // only move it from the blocked into the ready state and not have + // further side effects. + + let waiters = self.waiters.take(); + + unsafe { + // Use a reverse iterator, so that the oldest waiter gets + // scheduled first + for waiter in waiters.into_reverse_iter() { + if let Some(handle) = (*waiter).task.take() { + handle.wake(); + } + (*waiter).state = PollState::Done; + } + } + } + } + + fn is_cancelled(&self) -> bool { + self.is_cancelled + } + + /// Checks if the cancellation has occured. If it is this returns immediately. + /// If the event isn't set, the WaitForCancellationFuture gets added to the wait + /// queue at the event, and will be signalled once ready. + /// This function is only safe as long as the `wait_node`s address is guaranteed + /// to be stable until it gets removed from the queue. 
+ unsafe fn try_wait( + &mut self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + match wait_node.state { + PollState::New => { + if self.is_cancelled { + // The event is already signaled + wait_node.state = PollState::Done; + Poll::Ready(()) + } else { + // Added the task to the wait queue + wait_node.task = Some(cx.waker().clone()); + wait_node.state = PollState::Waiting; + self.waiters.add_front(wait_node); + Poll::Pending + } + } + PollState::Waiting => { + // The WaitForCancellationFuture is already in the queue. + // The event can't have been set, since this would change the + // waitstate inside the mutex. + // We don't need to update the Waker here, since we assume the + // Future is only ever polled from the same task. If this wouldn't + // hold, true, this wouldn't be safe. + Poll::Pending + } + PollState::Done => { + // We have been woken up by the event. + // This does not guarantee that the event is still set. It could + // have been reset it in the meantime. + Poll::Ready(()) + } + } + } + + fn remove_waiter(&mut self, wait_node: &mut ListNode) { + // WaitForCancellationFuture only needs to get removed if it has been added to + // the wait queue of the Event. This has happened in the PollState::Waiting case. + if let PollState::Waiting = wait_node.state { + if !unsafe { self.waiters.remove(wait_node) } { + // Panic if the address isn't found. This can only happen if the contract was + // violated, e.g. the WaitQueueEntry got moved after the initial poll. + panic!("Future could not be removed from wait queue"); + } + wait_node.state = PollState::Done; + } + } +} + +/// A synchronization primitive which can be either in the set or reset state. +/// +/// Tasks can wait for the event to get set by obtaining a Future via `wait`. +/// This Future will get fulfilled when the event has been set. 
+pub(crate) struct CancellationToken { + inner: Mutex, +} + +// The Event is can be sent to other threads as long as it's not borrowed +unsafe impl Send for CancellationToken {} +// The Event is thread-safe as long as the utilized Mutex is thread-safe +unsafe impl Sync for CancellationToken {} + +impl core::fmt::Debug for CancellationToken { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("CancellationToken").finish() + } +} + +impl CancellationToken { + /// Creates a new CancellationToken in the given state + pub(crate) fn new(is_cancelled: bool) -> CancellationToken { + CancellationToken { + inner: Mutex::new(CancellationTokenState::new(is_cancelled)), + } + } + + /// Sets the cancellation. + /// + /// Setting the cancellation will notify all pending waiters. + pub(crate) fn cancel(&self) { + self.inner.lock().unwrap().cancel() + } + + /// Returns whether the CancellationToken is set + pub(crate) fn is_cancelled(&self) -> bool { + self.inner.lock().unwrap().is_cancelled() + } + + /// Returns a future that gets fulfilled when the token is cancelled + pub(crate) fn wait(&self) -> WaitForCancellationFuture<'_> { + WaitForCancellationFuture { + token: Some(self), + wait_node: ListNode::new(WaitQueueEntry::new()), + } + } + + unsafe fn try_wait( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + self.inner.lock().unwrap().try_wait(wait_node, cx) + } + + fn remove_waiter(&self, wait_node: &mut ListNode) { + self.inner.lock().unwrap().remove_waiter(wait_node) + } +} + +/// A Future that is resolved once the corresponding CancellationToken has been set +#[must_use = "futures do nothing unless polled"] +pub(crate) struct WaitForCancellationFuture<'a> { + /// The CancellationToken that is associated with this WaitForCancellationFuture + token: Option<&'a CancellationToken>, + /// Node for waiting at the event + wait_node: ListNode, +} + +unsafe impl<'a> Send for WaitForCancellationFuture<'a> {} + 
+impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitForCancellationFuture").finish() + } +} + +impl<'a> Future for WaitForCancellationFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // It might be possible to use Pin::map_unchecked here instead of the two unsafe APIs. + // However this didn't seem to work for some borrow checker reasons + + // Safety: The next operations are safe, because Pin promises us that + // the address of the wait queue entry inside MutexLocalFuture is stable, + // and we don't move any fields inside the future until it gets dropped. + let mut_self: &mut WaitForCancellationFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; + + let token = mut_self + .token + .expect("polled WaitForCancellationFuture after completion"); + + let poll_res = unsafe { token.try_wait(&mut mut_self.wait_node, cx) }; + + if let Poll::Ready(()) = poll_res { + // The token was set + mut_self.token = None; + } + + poll_res + } +} + +impl<'a> Drop for WaitForCancellationFuture<'a> { + fn drop(&mut self) { + // If this WaitForCancellationFuture has been polled and it was added to the + // wait queue at the event, it must be removed before dropping. + // Otherwise the event would access invalid memory. + if let Some(token) = self.token { + token.remove_waiter(&mut self.wait_node); + } + } +} diff --git a/tokio/src/task/scope/intrusive_double_linked_list.rs b/tokio/src/task/scope/intrusive_double_linked_list.rs new file mode 100644 index 00000000000..afe1e261889 --- /dev/null +++ b/tokio/src/task/scope/intrusive_double_linked_list.rs @@ -0,0 +1,628 @@ +//! 
An intrusive double linked list of data + +use core::marker::PhantomPinned; +use core::ops::{Deref, DerefMut}; +use core::ptr::null_mut; + +/// A node which carries data of type `T` and is stored in an intrusive list +#[derive(Debug)] +pub(crate) struct ListNode { + /// The previous node in the list. null if there is no previous node. + prev: *mut ListNode, + /// The next node in the list. null if there is no previous node. + next: *mut ListNode, + /// The data which is associated to this list item + data: T, + /// Prevents `ListNode`s from being `Unpin`. They may never be moved, since + /// the list semantics require addresses to be stable. + _pin: PhantomPinned, +} + +impl ListNode { + /// Creates a new node with the associated data + pub(crate) fn new(data: T) -> ListNode { + ListNode:: { + prev: null_mut(), + next: null_mut(), + data, + _pin: PhantomPinned, + } + } +} + +impl Deref for ListNode { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for ListNode { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + +/// An intrusive linked list of nodes, where each node carries associated data +/// of type `T`. +#[derive(Debug)] +pub(crate) struct LinkedList { + head: *mut ListNode, + tail: *mut ListNode, +} + +impl LinkedList { + /// Creates an empty linked list + pub(crate) fn new() -> Self { + LinkedList:: { + head: null_mut(), + tail: null_mut(), + } + } + + /// Consumes the list and creates an iterator over the linked list. + /// This function is only safe as long as all pointers which are stored inside + /// the linked list are valid. + #[allow(dead_code)] + pub(crate) unsafe fn into_iter(self) -> LinkedListIterator { + LinkedListIterator { current: self.head } + } + + /// Consumes the list and creates an iterator over the linked list in reverse + /// direction. + /// This function is only safe as long as all pointers which are stored inside + /// the linked list are valid. 
+ pub(crate) unsafe fn into_reverse_iter(self) -> LinkedListReverseIterator { + LinkedListReverseIterator { current: self.tail } + } + + /// Adds an item to the front of the linked list. + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + pub(crate) unsafe fn add_front(&mut self, item: *mut ListNode) { + assert!(!item.is_null(), "Can not add null pointers"); + (*item).next = self.head; + (*item).prev = null_mut(); + if !self.head.is_null() { + (*self.head).prev = item; + } + self.head = item; + if self.tail.is_null() { + self.tail = item; + } + } + + /// Returns the first item in the linked list without removing it from the list + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + /// The returned pointer is only guaranteed to be valid as long as the list + /// is not mutated + #[allow(dead_code)] + pub(crate) fn peek_first(&self) -> *mut ListNode { + self.head + } + + /// Returns the last item in the linked list without removing it from the list + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + /// The returned pointer is only guaranteed to be valid as long as the list + /// is not mutated + #[allow(dead_code)] + pub(crate) fn peek_last(&self) -> *mut ListNode { + self.tail + } + + /// Removes the last item from the linked list and returns it + #[allow(dead_code)] + pub(crate) unsafe fn remove_last(&mut self) -> *mut ListNode { + if self.tail.is_null() { + return null_mut(); + } + + let last = self.tail; + self.tail = (*last).prev; + if !(*last).prev.is_null() { + (*(*last).prev).next = null_mut(); + } else { + // This was the last item in the list + self.head = null_mut(); + } + + (*last).prev = null_mut(); + (*last).next = null_mut(); + last + } + + /// Removes all items from the linked list and returns a LinkedList which + /// contains all the items. 
+ pub(crate) fn take(&mut self) -> LinkedList { + let head = self.head; + let tail = self.tail; + self.head = null_mut(); + self.tail = null_mut(); + LinkedList:: { head, tail } + } + + /// Returns whether the linked list doesn not contain any node + #[allow(dead_code)] + pub(crate) fn is_empty(&self) -> bool { + if !self.head.is_null() { + return false; + } + + assert!(self.tail.is_null()); + true + } + + /// Removes the given item from the linked list. + /// Returns whether the item was removed. + /// The function is only safe as long as valid pointers are stored inside + /// the linked list. + pub(crate) unsafe fn remove(&mut self, item: *mut ListNode) -> bool { + if item.is_null() { + return false; + } + + let prev = (*item).prev; + if prev.is_null() { + // This might be the first item in the list + if self.head != item { + return false; + } + self.head = (*item).next; + } else { + debug_assert_eq!((*prev).next, item); + (*prev).next = (*item).next; + } + + let next = (*item).next; + if next.is_null() { + // This might be the last item in the list + if self.tail != item { + return false; + } + self.tail = (*item).prev; + } else { + debug_assert_eq!((*next).prev, item); + (*next).prev = (*item).prev; + } + + (*item).next = null_mut(); + (*item).prev = null_mut(); + + true + } +} + +/// An iterator over an intrusively linked list +pub(crate) struct LinkedListIterator { + current: *mut ListNode, +} + +impl Iterator for LinkedListIterator { + type Item = *mut ListNode; + + fn next(&mut self) -> Option { + if self.current.is_null() { + return None; + } + + let node = self.current; + // Safety: This is safe as long as the linked list is intact, which was + // already required through the unsafe when creating the iterator. 
+ unsafe { + self.current = (*self.current).next; + } + Some(node) + } +} + +/// An iterator over an intrusively linked list +pub(crate) struct LinkedListReverseIterator { + current: *mut ListNode, +} + +impl Iterator for LinkedListReverseIterator { + type Item = *mut ListNode; + + fn next(&mut self) -> Option { + if self.current.is_null() { + return None; + } + + let node = self.current; + // Safety: This is safe as long as the linked list is intact, which was + // already required through the unsafe when creating the iterator. + unsafe { + self.current = (*self.current).prev; + } + Some(node) + } +} + +#[cfg(test)] +#[cfg(feature = "std")] // Tests make use of Vec at the moment +mod tests { + use super::*; + + unsafe fn collect_list(list: LinkedList) -> Vec { + list.into_iter().map(|item| (*(*item).deref())).collect() + } + + unsafe fn collect_reverse_list(list: LinkedList) -> Vec { + list.into_reverse_iter() + .map(|item| (*(*item).deref())) + .collect() + } + + unsafe fn add_nodes(list: &mut LinkedList, nodes: &mut [&mut ListNode]) { + for node in nodes.iter_mut() { + list.add_front(*node); + } + } + + unsafe fn assert_clean(node: *mut ListNode) { + assert!((*node).next.is_null()); + assert!((*node).prev.is_null()); + } + + #[test] + fn insert_and_iterate() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut setup = |list: &mut LinkedList| { + assert_eq!(true, list.is_empty()); + list.add_front(&mut c); + assert_eq!(31, *(*list.peek_first()).deref()); + assert_eq!(false, list.is_empty()); + list.add_front(&mut b); + assert_eq!(7, *(*list.peek_first()).deref()); + list.add_front(&mut a); + assert_eq!(5, *(*list.peek_first()).deref()); + }; + + let mut list = LinkedList::new(); + setup(&mut list); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + setup(&mut list); + let items: Vec = collect_reverse_list(list); + 
assert_eq!([31, 7, 5].to_vec(), items); + } + } + + #[test] + fn add_sorted() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + let mut d = ListNode::new(99); + + let mut list = LinkedList::new(); + list.add_sorted(&mut a); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + list.add_sorted(&mut a); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); + list.add_sorted(&mut a); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut b]); + list.add_sorted(&mut a); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); + list.add_sorted(&mut b); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut c, &mut a]); + list.add_sorted(&mut b); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); + list.add_sorted(&mut c); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut d, &mut b, &mut a]); + list.add_sorted(&mut c); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + list.add_sorted(&mut d); + let items: Vec = collect_list(list); + assert_eq!([5, 7, 31, 99].to_vec(), items); + + let mut list = 
LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + list.add_sorted(&mut d); + let items: Vec = collect_reverse_list(list); + assert_eq!([99, 31, 7, 5].to_vec(), items); + } + } + + #[test] + fn take_items() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + + let taken = list.take(); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + let taken_items: Vec = collect_list(taken); + assert_eq!([5, 7, 31].to_vec(), taken_items); + } + } + + #[test] + fn peek_last() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + + let last = list.peek_last(); + assert_eq!(31, *(*last).deref()); + list.remove_last(); + + let last = list.peek_last(); + assert_eq!(7, *(*last).deref()); + list.remove_last(); + + let last = list.peek_last(); + assert_eq!(5, *(*last).deref()); + list.remove_last(); + + let last = list.peek_last(); + assert!(last.is_null()); + } + } + + #[test] + fn remove_last() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_last(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([5, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + let removed = list.remove_last(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7, 5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_last(); + 
assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + let removed = list.remove_last(); + assert_clean(removed); + assert!(!list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_last(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + let removed = list.remove_last(); + assert_clean(removed); + assert!(list.is_empty()); + let items: Vec = collect_reverse_list(list); + assert!(items.is_empty()); + } + } + + #[test] + fn remove_by_address() { + unsafe { + let mut a = ListNode::new(5); + let mut b = ListNode::new(7); + let mut c = ListNode::new(31); + + { + // Remove first + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean(&mut a); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(&mut b as *mut ListNode, list.head); + assert_eq!(&mut c as *mut ListNode, b.next); + assert_eq!(&mut b as *mut ListNode, c.prev); + let items: Vec = collect_list(list); + assert_eq!([7, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean(&mut a); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(&mut c as *mut ListNode, b.next); + assert_eq!(&mut b as *mut ListNode, c.prev); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 7].to_vec(), items); + } + + { + // Remove middle + let mut 
list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean(&mut b); + assert_eq!(&mut c as *mut ListNode, a.next); + assert_eq!(&mut a as *mut ListNode, c.prev); + let items: Vec = collect_list(list); + assert_eq!([5, 31].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean(&mut b); + assert_eq!(&mut c as *mut ListNode, a.next); + assert_eq!(&mut a as *mut ListNode, c.prev); + let items: Vec = collect_reverse_list(list); + assert_eq!([31, 5].to_vec(), items); + } + + { + // Remove last + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut c)); + assert_clean(&mut c); + assert!(b.next.is_null()); + assert_eq!(&mut b as *mut ListNode, list.tail); + let items: Vec = collect_list(list); + assert_eq!([5, 7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut c, &mut b, &mut a]); + assert_eq!(true, list.remove(&mut c)); + assert_clean(&mut c); + assert!(b.next.is_null()); + assert_eq!(&mut b as *mut ListNode, list.tail); + let items: Vec = collect_reverse_list(list); + assert_eq!([7, 5].to_vec(), items); + } + + { + // Remove first of two + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean(&mut a); + // a should be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(&mut b as *mut ListNode, list.head); + assert_eq!(&mut b as *mut ListNode, list.tail); + assert!(b.next.is_null()); + assert!(b.prev.is_null()); + let items: Vec = collect_list(list); + assert_eq!([7].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean(&mut a); + // a should 
be no longer there and can't be removed twice + assert_eq!(false, list.remove(&mut a)); + assert_eq!(&mut b as *mut ListNode, list.head); + assert_eq!(&mut b as *mut ListNode, list.tail); + assert!(b.next.is_null()); + assert!(b.prev.is_null()); + let items: Vec = collect_reverse_list(list); + assert_eq!([7].to_vec(), items); + } + + { + // Remove last of two + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean(&mut b); + assert_eq!(&mut a as *mut ListNode, list.head); + assert_eq!(&mut a as *mut ListNode, list.tail); + assert!(a.next.is_null()); + assert!(a.prev.is_null()); + let items: Vec = collect_list(list); + assert_eq!([5].to_vec(), items); + + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut b, &mut a]); + assert_eq!(true, list.remove(&mut b)); + assert_clean(&mut b); + assert_eq!(&mut a as *mut ListNode, list.head); + assert_eq!(&mut a as *mut ListNode, list.tail); + assert!(a.next.is_null()); + assert!(a.prev.is_null()); + let items: Vec = collect_reverse_list(list); + assert_eq!([5].to_vec(), items); + } + + { + // Remove last item + let mut list = LinkedList::new(); + add_nodes(&mut list, &mut [&mut a]); + assert_eq!(true, list.remove(&mut a)); + assert_clean(&mut a); + assert!(list.head.is_null()); + assert!(list.tail.is_null()); + let items: Vec = collect_list(list); + assert!(items.is_empty()); + } + + { + // Remove missing + let mut list = LinkedList::new(); + list.add_front(&mut b); + list.add_front(&mut a); + assert_eq!(false, list.remove(&mut c)); + } + + { + // Remove null + let mut list = LinkedList::new(); + list.add_front(&mut b); + list.add_front(&mut a); + assert_eq!(false, list.remove(null_mut())); + } + } + } +} diff --git a/tokio/src/task/scope/manual_scope.rs b/tokio/src/task/scope/manual_scope.rs new file mode 100644 index 00000000000..ecd918277ff --- /dev/null +++ b/tokio/src/task/scope/manual_scope.rs @@ -0,0 +1,513 @@ +use 
crate::task::{JoinError, JoinHandle}; +use crate::sync::oneshot::{channel, Receiver, Sender}; +use core::{ + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll, Waker}, +}; +use pin_project_lite::pin_project; +use std::{ + collections::HashMap, + sync::atomic::{AtomicUsize, Ordering}, + sync::{Arc, Mutex, Weak}, +}; + +struct TaskListState { + next_id: u64, + tasks: HashMap, +} + +struct TaskList { + state: Mutex, +} + +impl TaskList { + fn new() -> Self { + Self { + state: Mutex::new(TaskListState { + next_id: 0, + tasks: HashMap::new(), + }), + } + } + + fn cancel_all(&self) { + let mut guard = self.state.lock().unwrap(); + + for (_, task_info) in guard.tasks.iter_mut() { + if let Some(cancel_sender) = task_info.cancel_sender.take() { + let _ = cancel_sender.send(()); + } + } + } + + fn gen_task_id(&self) -> u64 { + let mut guard = self.state.lock().unwrap(); + + let id = guard.next_id; + guard.next_id += 1; + id + } + + fn add_task(&self, id: u64, task_info: TaskInfo) { + let mut guard = self.state.lock().unwrap(); + + guard.tasks.insert(id, task_info); + } + + fn remove_task(&self, id: u64) { + let mut guard = self.state.lock().unwrap(); + + guard.tasks.remove(&id); + } +} + +struct ScopeCore { + /// The list of tasks which have been spawned into the scope + tasks: Arc, + /// Allows to wait until all tasks have completed + wait_group: WaitGroup, +} + +impl ScopeCore { + fn new() -> Self { + Self { + tasks: Arc::new(TaskList::new()), + wait_group: WaitGroup::new(), + } + } +} + +/// A scope +pub struct Scope<'scope, F, Fut, R> { + core: ScopeCore, + wait_group_fut: Option, + scope_fut: Option, + scope_func: Option, + result: Option, + _phantom: PhantomData<&'scope ()>, +} + +unsafe impl<'scope, F, Fut, R> Send for Scope<'scope, F, Fut, R> {} + +impl<'scope, F, Fut, R> Scope<'scope, F, Fut, R> { + pub fn new(scope_func: F) -> Self { + Self { + core: ScopeCore::new(), + wait_group_fut: None, + scope_func: Some(scope_func), + scope_fut: None, 
+ result: None, + _phantom: PhantomData, + } + } +} + +impl<'scope, F, Fut, R> Future for Scope<'scope, F, Fut, R> +where + F: FnOnce(ScopeHandle<'scope>) -> Fut + 'scope, + Fut: Future + Send, +{ + type Output = R; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut_self = unsafe { self.get_unchecked_mut() }; + + if mut_self.wait_group_fut.is_none() { + mut_self.wait_group_fut = Some(mut_self.core.wait_group.wait_future()); + } + + if mut_self.scope_fut.is_none() && mut_self.scope_func.is_some() { + let scope_func = mut_self.scope_func.take().unwrap(); + let handle = ScopeHandle { + core: &mut mut_self.core as *mut ScopeCore, + _phantom: core::marker::PhantomData, + }; + mut_self.scope_fut = Some(scope_func(handle)); + } + + // First poll the actual scope as long as it is alive + if let Some(fut) = mut_self.scope_fut.as_mut() { + let pinned_fut = unsafe { std::pin::Pin::new_unchecked(fut) }; + match pinned_fut.poll(cx) { + Poll::Ready(res) => { + mut_self.result = Some(res); + mut_self.scope_fut = None; + } + Poll::Pending => { + // As long as the actual future didn't resolve there is no + // need to poll the list of pending tasks. Actually there + // might currently be no pending tasks + return Poll::Pending; + } + } + } + + // Once we are done with the scope poll the WaitGroup + let fut = mut_self + .wait_group_fut + .as_mut() + .expect("Waiter is available"); + let fut = std::pin::Pin::new(fut); + match fut.poll(cx) { + Poll::Ready(()) => { + Poll::Ready(mut_self.result.take().expect("Result must be available")) + } + Poll::Pending => Poll::Pending, + } + } +} + +pin_project! 
{ + /// Allows to wait for a child task to join + pub struct ScopedJoinHandle<'scope, T> { + #[pin] + handle: JoinHandle>, + phantom: core::marker::PhantomData<&'scope ()>, + } +} + +impl<'scope, T> Future for ScopedJoinHandle<'scope, T> { + type Output = Result, JoinError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().handle.poll(cx) + } +} + +struct TaskInfo { + cancel_sender: Option>, +} + +/// A handle to the scope, which allows to spawn child tasks +pub struct ScopeHandle<'scope> { + core: *mut ScopeCore, + _phantom: core::marker::PhantomData<&'scope ()>, +} + +unsafe impl<'scope> Send for ScopeHandle<'scope> {} +unsafe impl<'scope> Sync for ScopeHandle<'scope> {} + + +impl<'scope> Drop for ScopeHandle<'scope> { + fn drop(&mut self) { + // Cancel all tasks, but do not await them + let core = unsafe { &mut *self.core }; + core.tasks.cancel_all(); + } +} + +impl<'scope> ScopeHandle<'scope> { + /// spawns a task on the scope + pub fn spawn<'inner, T, R>(&'inner self, task: T) -> ScopedJoinHandle<'inner, R> + where + T: Future + Send + 'scope, + R: Send + 'static, + T: 'static, + 'scope: 'inner, + { + let spawn_handle = + crate::runtime::context::spawn_handle().expect("Spawn handle must be available"); + + let (cancel_sender, cancel_receiver) = channel(); + + let core = unsafe { &mut *self.core }; + let task_id = core.tasks.gen_task_id(); + let weak_list = Arc::downgrade(&core.tasks); + + let releaser = core.wait_group.add_task(); + + let child_task = spawn_handle.spawn(async move { + CancellableFuture { + fut: task, + cancel_token: cancel_receiver, + remove_from_task_list_guard: RemoveFromTaskListIfDropped { + task_id, + task_list: weak_list, + _wait_group_releaser: releaser, + }, + } + .await + }); + + let task_info = TaskInfo { + cancel_sender: Some(cancel_sender), + }; + + core.tasks.add_task(task_id, task_info); + + ScopedJoinHandle { + handle: child_task, + phantom: core::marker::PhantomData, + } + } +} + +/// Creates a 
new scope +// pub fn scope<'scope, F, Fut, R>(scope_func: F) -> Scope<'scope, F, Fut, R> +// where +// F: FnOnce(ScopeHandle<'scope>) -> Fut + 'scope, +// Fut: Future + Send, +// // 'scope: 'scope, +// { +// Scope::new(scope_func) +// } + +pub async fn scope<'scope, F, Fut, R>(scope_func: F) -> R +where + F: FnOnce(ScopeHandle<'scope>) -> Fut + 'scope, + Fut: Future + Send + 'scope, +{ + let scope = Scope::new(scope_func); + + scope.await +} + +struct RemoveFromTaskListIfDropped { + task_id: u64, + task_list: Weak, + // Notifies the scope that the task finished if dropped + _wait_group_releaser: WaitGroupReleaser, +} + +impl Drop for RemoveFromTaskListIfDropped { + fn drop(&mut self) { + // The task must have finished if it got dropped, so remove it from the + // task list + if let Some(task_list) = self.task_list.upgrade() { + task_list.remove_task(self.task_id); + } + } +} + +pin_project! { + struct CancellableFuture { + #[pin] + fut: T, + #[pin] + cancel_token: Receiver<()>, + // Auto-removes the task from the scope if this Future gets dropped. + // This must be the last field in the struct. Rusts drop order guarantees + // will lead it to being destroyed last, which means we only signal + // completion to the scope if the future is already destructed + remove_from_task_list_guard: RemoveFromTaskListIfDropped, + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum CancellableFutureError { + Cancelled, +} + +impl Future for CancellableFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + + match project.fut.poll(cx) { + Poll::Ready(res) => return Poll::Ready(Ok(res)), + _ => {} + } + + match project.cancel_token.poll(cx) { + Poll::Ready(_) => return Poll::Ready(Err(CancellableFutureError::Cancelled)), + _ => {} + } + + Poll::Pending + } +} + +enum AtomicCancellableFutureResult { + Ok(T), + Cancelled, +} + +pin_project! 
{ + struct AtomicCancellableFuture { + #[pin] + fut: T, + #[pin] + state: AtomicUsize, + } +} + +impl AtomicCancellableFuture +where + T: Future, +{ + fn cancel(&mut self) { + let mut old = self.state.load(Ordering::Acquire); + if old & FLAG_CANCELLED != 0 { + // Cancellation was already requested + return; + } + + while old & FLAG_CANCELLED == 0 { + let next = old | FLAG_CANCELLED; + match self + .state + .compare_exchange(old, next, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => return, + Err(actual) => old = actual, + } + } + } +} + +const STATE_IDLE: usize = 0; +const STATE_POLLING: usize = 1; +const STATE_COMPLETED: usize = 2; +const FLAG_CANCELLED: usize = 4; + +impl Future for AtomicCancellableFuture +where + T: Future, +{ + type Output = AtomicCancellableFutureResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + + match project.state.compare_exchange( + STATE_IDLE, + STATE_POLLING, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Obtained the poll lock. Poll the inner future + let res = project.fut.poll(cx); + + match res { + Poll::Ready(res) => { + // Since the Future completed, we don't care about + // cancellations anymore. 
Therefore we can just overwrite + // and CANCELLED flag which might be stored + project.state.store(STATE_COMPLETED, Ordering::Relaxed); + return Poll::Ready(AtomicCancellableFutureResult::Ok(res)); + } + _ => { + const POLLING_AND_CANCELLED: usize = FLAG_CANCELLED | STATE_POLLING; + return match project.state.compare_exchange( + STATE_POLLING, + STATE_IDLE, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => Poll::Pending, + Err(actual) => { + project.state.store(FLAG_CANCELLED, Ordering::Relaxed); + Poll::Ready(AtomicCancellableFutureResult::Cancelled) + } + }; + } + } + } + Err(STATE_CANCELLED) => { + return Poll::Ready(AtomicCancellableFutureResult::Cancelled); + } + Err(STATE_COMPLETED) => unreachable!("Future was polled after completion"), + } + } +} + +struct WaitGroupInner { + running_tasks: AtomicUsize, + waiter: Mutex>, +} + +struct WaitGroupWaiter { + wait_group: Arc, +} + +impl Future for WaitGroupWaiter { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.wait_group.running_tasks.load(Ordering::SeqCst) == 0 { + return Poll::Ready(()); + } + + { + let mut guard = self.wait_group.waiter.lock().unwrap(); + *guard = Some(cx.waker().clone()); + } + + if self.wait_group.running_tasks.load(Ordering::SeqCst) == 0 { + // Completed in between. 
Waker is no longer required + { + let mut guard = self.wait_group.waiter.lock().unwrap(); + *guard = None; + } + return Poll::Ready(()); + } + + Poll::Pending + } +} + +impl WaitGroupInner { + fn new() -> Self { + Self { + running_tasks: AtomicUsize::new(0), + waiter: Mutex::new(None), + } + } + + fn add_task(&self) { + println!("add_task"); + self.running_tasks.fetch_add(1, Ordering::SeqCst); + } + + fn release_task(&self) { + println!("release_task"); + if self.running_tasks.fetch_sub(1, Ordering::SeqCst) != 1 { + return; + } + + let waker = self.waiter.lock().unwrap().take(); + if let Some(waker) = waker { + waker.wake(); + } + } +} + +struct WaitGroup { + inner: Arc, +} + +impl WaitGroup { + fn new() -> Self { + Self { + inner: Arc::new(WaitGroupInner::new()), + } + } + + #[must_use] + fn add_task(&mut self) -> WaitGroupReleaser { + self.inner.add_task(); + WaitGroupReleaser { + inner: self.inner.clone(), + } + } + + fn wait_future(&self) -> WaitGroupWaiter { + WaitGroupWaiter { + wait_group: self.inner.clone(), + } + } +} + +struct WaitGroupReleaser { + inner: Arc, +} + +impl Drop for WaitGroupReleaser { + fn drop(&mut self) { + self.inner.release_task(); + } +} diff --git a/tokio/src/task/scope/mod.rs b/tokio/src/task/scope/mod.rs new file mode 100644 index 00000000000..921e1e4e147 --- /dev/null +++ b/tokio/src/task/scope/mod.rs @@ -0,0 +1,12 @@ +//! 
The scope module provides the `scope` method, which enables structured concurrency + +mod cancellation_token; +mod intrusive_double_linked_list; +mod wait_group; + +mod scope; +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use scope::{ + scope, scope_with_options, ScopeCancelBehavior, ScopeDropBehavior, ScopeHandle, ScopeOptions, + ScopedJoinHandle, +}; diff --git a/tokio/src/task/scope/scope.rs b/tokio/src/task/scope/scope.rs new file mode 100644 index 00000000000..755868fc0bf --- /dev/null +++ b/tokio/src/task/scope/scope.rs @@ -0,0 +1,354 @@ +use super::{ + cancellation_token::CancellationToken, + wait_group::{WaitGroup, WaitGroupFuture}, +}; +use crate::task::{JoinError, JoinHandle}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +#[derive(Clone)] +struct ScopeState { + cancel_token: Arc, + wait_group: SharedWaitGroup, + options: ScopeOptions, +} + +impl ScopeState { + fn new(options: ScopeOptions) -> Self { + Self { + cancel_token: Arc::new(CancellationToken::new(false)), + wait_group: SharedWaitGroup::new(), + options, + } + } + + fn is_cancelled(&self) -> bool { + self.cancel_token.is_cancelled() + } +} + +pin_project! { + /// Allows to wait for a child task to join + pub struct ScopedJoinHandle<'scope, T> { + #[pin] + handle: JoinHandle>, + phantom: core::marker::PhantomData<&'scope ()>, + } +} + +impl<'scope, T> Future for ScopedJoinHandle<'scope, T> { + // The actual type is Result, JoinError> + // However the cancellation will only happen at the exit of the scope. This + // means in all cases the user still has a handle to the task, the task can + // not be cancelled yet. 
+ type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project() + .handle + .poll(cx) + .map(|poll_res| poll_res.map(|poll_ok| poll_ok.expect("Task can not be cancelled"))) + } +} + +/// A handle to the scope, which allows to spawn child tasks +#[derive(Clone)] +pub struct ScopeHandle { + scope: ScopeState, +} + +impl core::fmt::Debug for ScopeHandle { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("ScopeHandle").finish() + } +} + +struct CancelTasksGuard<'a> { + scope: &'a CancellationToken, +} + +impl<'a> Drop for CancelTasksGuard<'a> { + fn drop(&mut self) { + self.scope.cancel(); + } +} + +struct WaitForTasksToJoinGuard<'a> { + wait_group: &'a SharedWaitGroup, + drop_behavior: ScopeDropBehavior, + enabled: bool, +} + +impl<'a> WaitForTasksToJoinGuard<'a> { + fn disarm(&mut self) { + self.enabled = false; + } +} + +impl<'a> Drop for WaitForTasksToJoinGuard<'a> { + fn drop(&mut self) { + if !self.enabled { + return; + } + + match self.drop_behavior { + ScopeDropBehavior::BlockToCompletion => { + let wait_fut = self.wait_group.wait_future(); + + // TODOs: + // - This should not have a futures dependency + // - This might block multithreaded runtimes, since the tasks might need + // the current executor thread to make progress, due to dependening on + // its IO handles. We need to do something along task::block_in_place + // to solve this. 
+ futures::executor::block_on(wait_fut); + } + ScopeDropBehavior::Panic => { + panic!("Scope was dropped before child tasks run to completion"); + } + ScopeDropBehavior::Abort => { + eprintln!("[ERROR] A scope was dropped without being awaited"); + std::process::abort(); + } + ScopeDropBehavior::ContinueTasks => { + // Do nothing + } + } + } +} + +impl ScopeHandle { + /// spawns a task on the scope + pub fn spawn<'inner, T, R>(&'inner self, task: T) -> ScopedJoinHandle<'inner, R> + where + T: Future + Send + 'static, + R: Send + 'static, + T: 'inner, + { + let spawn_handle = + crate::runtime::context::spawn_handle().expect("Spawn handle must be available"); + + let releaser = self.scope.wait_group.add(); + let cancel_token = self.scope.cancel_token.clone(); + let cancel_behavior = self.scope.options.cancel_behavior; + + let child_task = { + spawn_handle.spawn(async move { + // Drop this at the end of the task to signal we are done and unblock + // the WaitGroup + let _wait_group_releaser = releaser; + + if cancel_behavior == ScopeCancelBehavior::CancelChildTasks { + futures::pin_mut!(task); + use futures::FutureExt; + + futures::select! { + _ = cancel_token.wait().fuse() => { + // The child task was cancelled + return Err(CancellableFutureError::Cancelled); + }, + res = task.fuse() => { + return Ok(res); + } + } + } else { + Ok(task.await) + } + }) + }; + + // Since `Scope` is `Sync` and `Send` cancellations can happen at any time + // in case of invalid use. Therefore we only check cancellations once: + // After the task has been spawned. Since the cancellation is already set, + // we need to wait for the task to complete. Then we panic due to invalid + // API usage. 
+ if self.scope.is_cancelled() { + futures::executor::block_on(async { + let _ = child_task.await; + }); + panic!("Spawn on cancelled Scope"); + } + + ScopedJoinHandle { + handle: child_task, + phantom: core::marker::PhantomData, + } + } +} + +/// Defines how a scope will behave if the `Future` it returns get dropped +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum ScopeDropBehavior { + /// When a scope is dropped while tasks are outstanding, the current thread + /// will panic. Since this will not wait for child tasks to complete, the + /// child tasks can outlive the parent in this case. + Panic, + /// When a scope is dropped while tasks are outstanding, the process will be + /// aborted. + Abort, + /// When a scope is dropped while tasks are outstanding, the current thread + /// will be blocked until the tasks in the `scope` completed. This option + /// is only available in multithreaded tokio runtimes, and is the default there. + BlockToCompletion, + /// Ignore that the scope got dropped and continue to run the child tasks. + /// Choosing this option will break structured concurrency. It is therefore + /// not recommended to pick the option. + ContinueTasks, +} + +/// Defines how a scope should cancel its child task once the scope is exited +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum ScopeCancelBehavior { + /// Once the scope is exited, all still running child tasks will get cancelled. + /// The cancellation is asynchronous: Tasks will only notice the + /// cancellation the next time they are scheduled. + /// This option is the default option. + CancelChildTasks, + /// Child tasks are allowed to continue to run. + /// This option should only be chosen if it is either guaranteed that child + /// tasks will join on their own, or if the application uses an additional + /// mechanism (like a `CancellationToken`) to signal child tasks to return. 
+ ContinueChildTasks, +} + +/// Advanced configuration options for `scope` +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct ScopeOptions { + /// Whether tasks should be cancelled once the scope is exited + pub cancel_behavior: ScopeCancelBehavior, + /// How the scope should behave if it gets dropped instead of being `await`ed + pub drop_behavior: ScopeDropBehavior, +} + +impl Default for ScopeOptions { + fn default() -> Self { + // TODO: We need to identify a mechanism whether this runs in a singlethreaded + // runtime. In that case we should use the panic strategy, since blocking + // is not viable in a singlethreaded context. + const IS_SINGLE_THREADED: bool = false; + + let drop_behavior = if IS_SINGLE_THREADED { + ScopeDropBehavior::Panic + } else { + ScopeDropBehavior::BlockToCompletion + }; + + Self { + cancel_behavior: ScopeCancelBehavior::CancelChildTasks, + drop_behavior, + } + } +} + +/// Creates a task scope with default options. +/// +/// The `scope` allows to spawn child tasks so that the lifetime of child tasks +/// is constrained within the scope. +/// +/// A closure which accepts a [`ScopeHandle`] object and which returns a [`Future`] +/// needs to be passed to `scope`. The [`ScopeHandle`] can be used to spawn child +/// tasks. +/// +/// If the provided `Future` had been polled to completion, all child tasks which +/// have been spawned via the `ScopeHandle` will be cancelled. +/// +/// `scope` returns a [`Future`] which should be awaited. The `await` will only +/// complete once all child tasks that have been spawned via the provided +/// [`ScopeHandle`] have joined. Thereby the `scope` does not allow child tasks +/// to escape their parent task, as long as the `scope` is awaited. 
+pub async fn scope(scope_func: F) -> R +where + F: FnOnce(ScopeHandle) -> Fut, + Fut: Future + Send, +{ + let options = ScopeOptions::default(); + scope_with_options(options, scope_func).await +} + +/// Creates a [`scope`] with custom options +/// +/// The method behaves like [`scope`], but the cancellation and `Drop` behavior +/// for the [`scope`] are configurable. See [`ScopeOptions`] for details. +pub async fn scope_with_options(options: ScopeOptions, scope_func: F) -> R +where + F: FnOnce(ScopeHandle) -> Fut, + Fut: Future + Send, +{ + let scope_state = ScopeState::new(options); + let wait_fut = scope_state.wait_group.wait_future(); + + let mut wait_for_tasks_guard = WaitForTasksToJoinGuard { + wait_group: &scope_state.wait_group, + enabled: true, + drop_behavior: options.drop_behavior, + }; + + let scoped_result = { + let _cancel_guard = CancelTasksGuard { + scope: &scope_state.cancel_token, + }; + if options.cancel_behavior == ScopeCancelBehavior::ContinueChildTasks { + std::mem::forget(_cancel_guard); + } + + let handle = ScopeHandle { + scope: scope_state.clone(), + }; + let scoped_result = scope_func(handle).await; + + scoped_result + }; + + wait_fut.await; + + // The tasks have completed. We do not need to wait for them to complete + // in the `Drop` guard. 
+ wait_for_tasks_guard.disarm(); + + scoped_result +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CancellableFutureError { + Cancelled, +} + +#[derive(Clone)] +struct SharedWaitGroup { + inner: Arc, +} + +impl SharedWaitGroup { + fn new() -> Self { + Self { + inner: Arc::new(WaitGroup::new()), + } + } + + #[must_use] + fn add(&self) -> WaitGroupReleaser { + self.inner.add(); + WaitGroupReleaser { + inner: self.inner.clone(), + } + } + + fn wait_future<'a>(&'a self) -> WaitGroupFuture<'a> { + self.inner.wait() + } +} + +struct WaitGroupReleaser { + inner: Arc, +} + +impl Drop for WaitGroupReleaser { + fn drop(&mut self) { + self.inner.remove(); + } +} diff --git a/tokio/src/task/scope/scope_state.rs b/tokio/src/task/scope/scope_state.rs new file mode 100644 index 00000000000..c099e76d2f9 --- /dev/null +++ b/tokio/src/task/scope/scope_state.rs @@ -0,0 +1,518 @@ +use crate::task::{JoinError, JoinHandle}; +use crate::sync::oneshot::{channel, Receiver, Sender}; +use core::{ + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll, Waker}, +}; +use pin_project_lite::pin_project; +use std::{ + collections::HashMap, + sync::atomic::{AtomicUsize, Ordering}, + sync::{Arc, Mutex, Weak}, +}; + +struct TaskListState { + next_id: u64, + tasks: HashMap, +} + +struct TaskList { + state: Mutex, +} + +impl TaskList { + fn new() -> Self { + Self { + state: Mutex::new(TaskListState { + next_id: 0, + tasks: HashMap::new(), + }), + } + } + + fn cancel_all(&self) { + let mut guard = self.state.lock().unwrap(); + + for (_, task_info) in guard.tasks.iter_mut() { + if let Some(cancel_sender) = task_info.cancel_sender.take() { + let _ = cancel_sender.send(()); + } + } + } + + fn gen_task_id(&self) -> u64 { + let mut guard = self.state.lock().unwrap(); + + let id = guard.next_id; + guard.next_id += 1; + id + } + + fn add_task(&self, id: u64, task_info: TaskInfo) { + let mut guard = self.state.lock().unwrap(); + + guard.tasks.insert(id, task_info); + } + + fn 
remove_task(&self, id: u64) { + let mut guard = self.state.lock().unwrap(); + + guard.tasks.remove(&id); + } +} + +struct ScopeCore { + /// The list of tasks which have been spawned into the scope + tasks: Arc, + /// Allows to wait until all tasks have completed + wait_group: WaitGroup, +} + +impl ScopeCore { + fn new() -> Self { + Self { + tasks: Arc::new(TaskList::new()), + wait_group: WaitGroup::new(), + } + } +} + +/// A scope +pub struct Scope<'scope, F, Fut, R> { + core: ScopeCore, + wait_group_fut: Option, + scope_fut: Option, + scope_func: Option, + scope_handle: Option>, + result: Option, + _phantom: PhantomData<&'scope ()>, +} + +unsafe impl<'scope, F, Fut, R> Send for Scope<'scope, F, Fut, R> {} + +impl<'scope, F, Fut, R> Scope<'scope, F, Fut, R> { + pub fn new(scope_func: F) -> Self { + Self { + core: ScopeCore::new(), + wait_group_fut: None, + scope_func: Some(scope_func), + scope_fut: None, + scope_handle: None, + result: None, + _phantom: PhantomData, + } + } +} + +impl<'scope, F, Fut, R> Future for Scope<'scope, F, Fut, R> +where + F: FnOnce(ScopeHandle<'scope>) -> Fut + 'scope, + Fut: Future + Send, +{ + type Output = R; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut_self = unsafe { self.get_unchecked_mut() }; + + if mut_self.wait_group_fut.is_none() { + mut_self.wait_group_fut = Some(mut_self.core.wait_group.wait_future()); + } + + if mut_self.scope_fut.is_none() && mut_self.scope_func.is_some() { + let scope_func = mut_self.scope_func.take().unwrap(); + mut_self.scope_handle = Some(ScopeHandle { + core: &mut mut_self.core as *mut ScopeCore, + _phantom: core::marker::PhantomData, + }); // unsafe { core::mem::transmute(&mut mut_self.core) }}); + let handle = ScopeHandle { + core: &mut mut_self.core as *mut ScopeCore, + _phantom: core::marker::PhantomData, + }; + mut_self.scope_fut = Some(scope_func(handle)); + } + + // First poll the actual scope as long as it is alive + if let Some(fut) = mut_self.scope_fut.as_mut() { 
+ let pinned_fut = unsafe { std::pin::Pin::new_unchecked(fut) }; + match pinned_fut.poll(cx) { + Poll::Ready(res) => { + mut_self.result = Some(res); + mut_self.scope_fut = None; + } + Poll::Pending => { + // As long as the actual future didn't resolve there is no + // need to poll the list of pending tasks. Actually there + // might currently be no pending tasks + return Poll::Pending; + } + } + } + + // Once we are done with the scope poll the WaitGroup + let fut = mut_self + .wait_group_fut + .as_mut() + .expect("Waiter is available"); + let fut = std::pin::Pin::new(fut); + match fut.poll(cx) { + Poll::Ready(()) => { + Poll::Ready(mut_self.result.take().expect("Result must be available")) + } + Poll::Pending => Poll::Pending, + } + } +} + +pin_project! { + /// Allows to wait for a child task to join + pub struct ScopedJoinHandle<'scope, T> { + #[pin] + handle: JoinHandle>, + phantom: core::marker::PhantomData<&'scope ()>, + } +} + +impl<'scope, T> Future for ScopedJoinHandle<'scope, T> { + type Output = Result, JoinError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().handle.poll(cx) + } +} + +struct TaskInfo { + cancel_sender: Option>, +} + +/// A handle to the scope, which allows to spawn child tasks +pub struct ScopeHandle<'scope> { + core: *mut ScopeCore, + _phantom: core::marker::PhantomData<&'scope ()>, +} + +unsafe impl<'scope> Send for ScopeHandle<'scope> {} +unsafe impl<'scope> Sync for ScopeHandle<'scope> {} + +impl<'scope> Drop for ScopeHandle<'scope> { + fn drop(&mut self) { + // Cancel all tasks, but do not await them + let core = unsafe { &mut *self.core }; + core.tasks.cancel_all(); + } +} + +impl<'scope> ScopeHandle<'scope> { + /// spawns a task on the scope + pub fn spawn<'inner, T, R>(&'inner self, task: T) -> ScopedJoinHandle<'inner, R> + where + T: Future + Send + 'scope, + R: Send + 'static, + T: 'static, + 'scope: 'inner, + { + let spawn_handle = + crate::runtime::context::spawn_handle().expect("Spawn 
handle must be available"); + + let (cancel_sender, cancel_receiver) = channel(); + + let core = unsafe { &mut *self.core }; + let task_id = core.tasks.gen_task_id(); + let weak_list = Arc::downgrade(&core.tasks); + + let releaser = core.wait_group.add_task(); + + let child_task = spawn_handle.spawn(async move { + CancellableFuture { + fut: task, + cancel_token: cancel_receiver, + remove_from_task_list_guard: RemoveFromTaskListIfDropped { + task_id, + task_list: weak_list, + _wait_group_releaser: releaser, + }, + } + .await + }); + + let task_info = TaskInfo { + cancel_sender: Some(cancel_sender), + }; + + core.tasks.add_task(task_id, task_info); + + ScopedJoinHandle { + handle: child_task, + phantom: core::marker::PhantomData, + } + } +} + +/// Creates a new scope +// pub fn scope<'scope, F, Fut, R>(scope_func: F) -> Scope<'scope, F, Fut, R> +// where +// F: FnOnce(ScopeHandle<'scope>) -> Fut + 'scope, +// Fut: Future + Send, +// // 'scope: 'scope, +// { +// Scope::new(scope_func) +// } + +pub async fn scope<'scope, F, Fut, R>(scope_func: F) -> R +where + F: FnOnce(ScopeHandle<'scope>) -> Fut + 'scope, + Fut: Future + Send + 'scope, +{ + let scope = Scope::new(scope_func); + + scope.await +} + +struct RemoveFromTaskListIfDropped { + task_id: u64, + task_list: Weak, + // Notifies the scope that the task finished if dropped + _wait_group_releaser: WaitGroupReleaser, +} + +impl Drop for RemoveFromTaskListIfDropped { + fn drop(&mut self) { + // The task must have finished if it got dropped, so remove it from the + // task list + if let Some(task_list) = self.task_list.upgrade() { + task_list.remove_task(self.task_id); + } + } +} + +pin_project! { + struct CancellableFuture { + #[pin] + fut: T, + #[pin] + cancel_token: Receiver<()>, + // Auto-removes the task from the scope if this Future gets dropped. + // This must be the last field in the struct. 
Rusts drop order guarantees + // will lead it to being destroyed last, which means we only signal + // completion to the scope if the future is already destructed + remove_from_task_list_guard: RemoveFromTaskListIfDropped, + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum CancellableFutureError { + Cancelled, +} + +impl Future for CancellableFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + + match project.fut.poll(cx) { + Poll::Ready(res) => return Poll::Ready(Ok(res)), + _ => {} + } + + match project.cancel_token.poll(cx) { + Poll::Ready(_) => return Poll::Ready(Err(CancellableFutureError::Cancelled)), + _ => {} + } + + Poll::Pending + } +} + +enum AtomicCancellableFutureResult { + Ok(T), + Cancelled, +} + +pin_project! { + struct AtomicCancellableFuture { + #[pin] + fut: T, + #[pin] + state: AtomicUsize, + } +} + +impl AtomicCancellableFuture +where + T: Future, +{ + fn cancel(&mut self) { + let mut old = self.state.load(Ordering::Acquire); + if old & FLAG_CANCELLED != 0 { + // Cancellation was already requested + return; + } + + while old & FLAG_CANCELLED == 0 { + let next = old | FLAG_CANCELLED; + match self + .state + .compare_exchange(old, next, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => return, + Err(actual) => old = actual, + } + } + } +} + +const STATE_IDLE: usize = 0; +const STATE_POLLING: usize = 1; +const STATE_COMPLETED: usize = 2; +const FLAG_CANCELLED: usize = 4; + +impl Future for AtomicCancellableFuture +where + T: Future, +{ + type Output = AtomicCancellableFutureResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + + match project.state.compare_exchange( + STATE_IDLE, + STATE_POLLING, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Obtained the poll lock. 
Poll the inner future + let res = project.fut.poll(cx); + + match res { + Poll::Ready(res) => { + // Since the Future completed, we don't care about + // cancellations anymore. Therefore we can just overwrite + // and CANCELLED flag which might be stored + project.state.store(STATE_COMPLETED, Ordering::Relaxed); + return Poll::Ready(AtomicCancellableFutureResult::Ok(res)); + } + _ => { + const POLLING_AND_CANCELLED: usize = FLAG_CANCELLED | STATE_POLLING; + return match project.state.compare_exchange( + STATE_POLLING, + STATE_IDLE, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => Poll::Pending, + Err(actual) => { + project.state.store(FLAG_CANCELLED, Ordering::Relaxed); + Poll::Ready(AtomicCancellableFutureResult::Cancelled) + } + }; + } + } + } + Err(STATE_CANCELLED) => { + return Poll::Ready(AtomicCancellableFutureResult::Cancelled); + } + Err(STATE_COMPLETED) => unreachable!("Future was polled after completion"), + } + } +} + +struct WaitGroupInner { + running_tasks: AtomicUsize, + waiter: Mutex>, +} + +struct WaitGroupWaiter { + wait_group: Arc, +} + +impl Future for WaitGroupWaiter { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.wait_group.running_tasks.load(Ordering::SeqCst) == 0 { + return Poll::Ready(()); + } + + { + let mut guard = self.wait_group.waiter.lock().unwrap(); + *guard = Some(cx.waker().clone()); + } + + if self.wait_group.running_tasks.load(Ordering::SeqCst) == 0 { + // Completed in between. 
Waker is no longer required + { + let mut guard = self.wait_group.waiter.lock().unwrap(); + *guard = None; + } + return Poll::Ready(()); + } + + Poll::Pending + } +} + +impl WaitGroupInner { + fn new() -> Self { + Self { + running_tasks: AtomicUsize::new(0), + waiter: Mutex::new(None), + } + } + + fn add_task(&self) { + println!("add_task"); + self.running_tasks.fetch_add(1, Ordering::SeqCst); + } + + fn release_task(&self) { + println!("release_task"); + if self.running_tasks.fetch_sub(1, Ordering::SeqCst) != 1 { + return; + } + + let waker = self.waiter.lock().unwrap().take(); + if let Some(waker) = waker { + waker.wake(); + } + } +} + +struct WaitGroup { + inner: Arc, +} + +impl WaitGroup { + fn new() -> Self { + Self { + inner: Arc::new(WaitGroupInner::new()), + } + } + + #[must_use] + fn add_task(&mut self) -> WaitGroupReleaser { + self.inner.add_task(); + WaitGroupReleaser { + inner: self.inner.clone(), + } + } + + fn wait_future(&self) -> WaitGroupWaiter { + WaitGroupWaiter { + wait_group: self.inner.clone(), + } + } +} + +struct WaitGroupReleaser { + inner: Arc, +} + +impl Drop for WaitGroupReleaser { + fn drop(&mut self) { + self.inner.release_task(); + } +} diff --git a/tokio/src/task/scope/wait_group.rs b/tokio/src/task/scope/wait_group.rs new file mode 100644 index 00000000000..cd309fa7158 --- /dev/null +++ b/tokio/src/task/scope/wait_group.rs @@ -0,0 +1,271 @@ +//! An asynchronously awaitable WaitGroup which allows to wait for running tasks + +use super::intrusive_double_linked_list::{LinkedList, ListNode}; +use std::{ + future::Future, + pin::Pin, + sync::Mutex, + task::{Context, Poll, Waker}, +}; + +/// Updates a `Waker` which is stored inside a `Option` to the newest value +/// which is delivered via a `Context`. 
/// Refreshes the `Waker` stored in `waker_option` with the one carried by
/// `cx`, cloning only when necessary.
///
/// The stored waker is replaced when the slot is empty or when the stored
/// waker would not wake the same task as the current one (`will_wake`
/// returns `false`); otherwise the existing clone is kept to avoid an
/// unnecessary allocation/refcount bump.
fn update_waker_ref(waker_option: &mut Option<Waker>, cx: &Context<'_>) {
    let needs_refresh = match waker_option {
        None => true,
        Some(stored) => !stored.will_wake(cx.waker()),
    };

    if needs_refresh {
        *waker_option = Some(cx.waker().clone());
    }
}
+ + let waiters = self.waiters.take(); + + unsafe { + // Use a reverse iterator, so that the oldest waiter gets + // scheduled first + for waiter in waiters.into_reverse_iter() { + if let Some(handle) = (*waiter).task.take() { + handle.wake(); + } + (*waiter).state = PollState::Done; + } + } + } + + /// Checks how many tasks are running. If none are running, this returns + /// `Poll::Ready` immediately. + /// If tasks are running, the WaitGroupFuture gets added to the wait + /// queue at the group, and will be signalled once the tasks completed. + /// This function is only safe as long as the `wait_node`s address is guaranteed + /// to be stable until it gets removed from the queue. + unsafe fn try_wait( + &mut self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + match wait_node.state { + PollState::New => { + if self.count == 0 { + // The group is already signaled + wait_node.state = PollState::Done; + Poll::Ready(()) + } else { + // Added the task to the wait queue + wait_node.task = Some(cx.waker().clone()); + wait_node.state = PollState::Waiting; + self.waiters.add_front(wait_node); + Poll::Pending + } + } + PollState::Waiting => { + // The WaitGroupFuture is already in the queue. + // The group can't have reached 0 tasks, since this would change the + // waitstate inside the mutex. However the caller might have + // passed a different `Waker`. In this case we need to update it. + update_waker_ref(&mut wait_node.task, cx); + Poll::Pending + } + PollState::Done => { + // We have been woken up by the group. + // This does not guarantee that the group still has 0 running tasks. + Poll::Ready(()) + } + } + } + + fn remove_waiter(&mut self, wait_node: &mut ListNode) { + // WaitGroupFuture only needs to get removed if it has been added to + // the wait queue of the WaitGroup. This has happened in the PollState::Waiting case. 
+ if let PollState::Waiting = wait_node.state { + if !unsafe { self.waiters.remove(wait_node) } { + // Panic if the address isn't found. This can only happen if the contract was + // violated, e.g. the WaitQueueEntry got moved after the initial poll. + panic!("Future could not be removed from wait queue"); + } + wait_node.state = PollState::Done; + } + } +} + +/// A synchronization primitive which allows to wait until all tracked tasks +/// have finished. +/// +/// Tasks can wait for tracked tasks to finish by obtaining a Future via `wait`. +/// This Future will get fulfilled when no tasks are running anymore. +pub(crate) struct WaitGroup { + inner: Mutex, +} + +// The Group is can be sent to other threads as long as it's not borrowed +unsafe impl Send for WaitGroup {} +// The Group is thread-safe as long as the utilized Mutex is thread-safe +unsafe impl Sync for WaitGroup {} + +impl core::fmt::Debug for WaitGroup { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitGroup").finish() + } +} + +impl WaitGroup { + /// Creates a new WaitGroup + pub(crate) fn new() -> WaitGroup { + WaitGroup { + inner: Mutex::new(GroupState::new(0)), + } + } + + /// Adds a pending task to the WaitGroup + pub(crate) fn add(&self) { + self.inner.lock().unwrap().add() + } + + /// Removes a task that has finished from the WaitGroup + pub(crate) fn remove(&self) { + self.inner.lock().unwrap().remove() + } + + /// Returns a future that gets fulfilled when all tracked tasks complete + pub(crate) fn wait(&self) -> WaitGroupFuture<'_> { + WaitGroupFuture { + group: Some(self), + wait_node: ListNode::new(WaitQueueEntry::new()), + } + } + + unsafe fn try_wait( + &self, + wait_node: &mut ListNode, + cx: &mut Context<'_>, + ) -> Poll<()> { + self.inner.lock().unwrap().try_wait(wait_node, cx) + } + + fn remove_waiter(&self, wait_node: &mut ListNode) { + self.inner.lock().unwrap().remove_waiter(wait_node) + } +} + +/// A Future that is resolved once the 
corresponding WaitGroup has reached +/// 0 active tasks. +#[must_use = "futures do nothing unless polled"] +pub(crate) struct WaitGroupFuture<'a> { + /// The WaitGroup that is associated with this WaitGroupFuture + group: Option<&'a WaitGroup>, + /// Node for waiting at the group + wait_node: ListNode, +} + +// Safety: Futures can be sent between threads as long as the underlying +// group is thread-safe (Sync), which allows to poll/register/unregister from +// a different thread. +unsafe impl<'a> Send for WaitGroupFuture<'a> {} + +impl<'a> core::fmt::Debug for WaitGroupFuture<'a> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitGroupFuture").finish() + } +} + +impl<'a> Future for WaitGroupFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + // It might be possible to use Pin::map_unchecked here instead of the two unsafe APIs. + // However this didn't seem to work for some borrow checker reasons + + // Safety: The next operations are safe, because Pin promises us that + // the address of the wait queue entry inside WaitGroupFuture is stable, + // and we don't move any fields inside the future until it gets dropped. + let mut_self: &mut WaitGroupFuture<'_> = unsafe { Pin::get_unchecked_mut(self) }; + + let group = mut_self + .group + .expect("polled WaitGroupFuture after completion"); + + let poll_res = unsafe { group.try_wait(&mut mut_self.wait_node, cx) }; + + if let Poll::Ready(()) = poll_res { + mut_self.group = None; + } + + poll_res + } +} + +impl<'a> Drop for WaitGroupFuture<'a> { + fn drop(&mut self) { + // If this WaitGroupFuture has been polled and it was added to the + // wait queue at the group, it must be removed before dropping. + // Otherwise the group would access invalid memory. 
+ if let Some(ev) = self.group { + ev.remove_waiter(&mut self.wait_node); + } + } +} diff --git a/tokio/tests/task_scope.rs b/tokio/tests/task_scope.rs new file mode 100644 index 00000000000..227a779c5a5 --- /dev/null +++ b/tokio/tests/task_scope.rs @@ -0,0 +1,336 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use futures::{select, FutureExt}; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::{Duration, Instant}, +}; +use tokio::{ + task::{ + scope, scope_with_options, ScopeCancelBehavior, ScopeDropBehavior, ScopeHandle, + ScopeOptions, ScopedJoinHandle, + }, + time::delay_for, +}; + +#[tokio::test] +async fn unused_scope() { + let scope = scope(|_scope| async {}); + drop(scope); +} + +#[tokio::test] +async fn spawn_and_return_result() { + let result = scope(|scope| { + async move { + let handle = scope.spawn(async { 123u32 }); + handle.await + } + }) + .await; + assert_eq!(123u32, result.unwrap()); +} + +#[derive(Clone)] +struct AtomicFlag(Arc); + +impl AtomicFlag { + fn new() -> Self { + AtomicFlag(Arc::new(AtomicBool::new(false))) + } + + fn is_set(&self) -> bool { + self.0.load(Ordering::Acquire) + } + + fn set(&self) { + self.0.store(true, Ordering::Release); + } +} + +struct SetFlagOnDropGuard { + flag: AtomicFlag, +} + +impl Drop for SetFlagOnDropGuard { + fn drop(&mut self) { + self.flag.set(); + } +} + +#[tokio::test] +async fn cancel_and_wait_for_child_task() { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let result = scope(|scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async { + let _guard = SetFlagOnDropGuard { flag: flag_clone }; + loop { + tokio::task::yield_now().await; + } + }); + + handle.await + } + }) + .await; + assert_eq!(123u32, result.unwrap()); + + // Check that the second task was cancelled + assert_eq!(true, flag.is_set()); +} + +#[tokio::test] +async fn 
run_task_to_completion_if_configured() { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let options = ScopeOptions { + drop_behavior: ScopeDropBehavior::Panic, + cancel_behavior: ScopeCancelBehavior::ContinueChildTasks, + }; + + let result = scope_with_options(options, |scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async move { + // This should run to completion - even if it takes longer + delay_for(Duration::from_millis(50)).await; + flag_clone.set(); + }); + + handle.await + } + }) + .await; + assert_eq!(123u32, result.unwrap()); + + // Check that the second task run to completion + assert_eq!(true, flag.is_set()); +} + +#[test] +fn block_until_non_joined_tasks_complete() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let start_time = Instant::now(); + let scope_fut = scope(|scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async move { + // Use block_in_place makes the task not cancellable + tokio::task::block_in_place(|| { + std::thread::sleep(Duration::from_millis(100)); + }); + flag_clone.set(); + }); + + handle.await + } + }); + + select! 
{ + _ = scope_fut.fuse() => { + panic!("Scope should not complete"); + }, + _ = delay_for(Duration::from_millis(50)).fuse() => { + // Drop the scope here + }, + }; + + assert!(start_time.elapsed() >= Duration::from_millis(100)); + + // Check that the second task run to completion + assert_eq!(true, flag.is_set()); + }); +} + +#[should_panic] +#[test] +fn panic_if_active_scope_is_dropped() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let mut options = ScopeOptions::default(); + options.drop_behavior = ScopeDropBehavior::Panic; + + let scope_fut = scope_with_options(options, |scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async move { + // Use block_in_place makes the task not cancellable + tokio::task::block_in_place(|| { + std::thread::sleep(Duration::from_millis(100)); + }); + }); + + handle.await + } + }); + + select! { + _ = scope_fut.fuse() => { + panic!("Scope should not complete"); + }, + _ = delay_for(Duration::from_millis(50)).fuse() => { + // Drop the scope here + }, + }; + }); +} + +#[test] +fn child_tasks_can_continue_to_run_if_configured() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let flag = AtomicFlag::new(); + let flag_clone = flag.clone(); + + let mut options = ScopeOptions::default(); + options.drop_behavior = ScopeDropBehavior::ContinueTasks; + + let start_time = Instant::now(); + let scope_fut = scope_with_options(options, |scope| { + async move { + let handle = scope.spawn(async { + delay_for(Duration::from_millis(20)).await; + 123u32 + }); + + scope.spawn(async move { + // Use block_in_place makes the task not cancellable + tokio::task::block_in_place(|| { + std::thread::sleep(Duration::from_millis(100)); + }); + flag_clone.set(); + }); + + handle.await + } + }); + + select! 
{ + _ = scope_fut.fuse() => { + panic!("Scope should not complete"); + }, + _ = delay_for(Duration::from_millis(50)).fuse() => { + // Drop the scope here + }, + }; + + let elapsed = start_time.elapsed(); + assert!(elapsed >= Duration::from_millis(50) && elapsed < Duration::from_millis(100)); + assert_eq!(false, flag.is_set()); + + // Wait until the leaked task run to completion + delay_for(Duration::from_millis(60)).await; + assert_eq!(true, flag.is_set()); + }); +} + +#[test] +fn clone_scope_handles_and_cancel_child() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let drop_flag = AtomicFlag::new(); + let drop_flag_clone = drop_flag.clone(); + let completion_flag = AtomicFlag::new(); + let completion_flag_clone = completion_flag.clone(); + + scope(|scope| { + async move { + let cloned_handle = scope.clone(); + + let join_handle = scope.spawn(async move { + delay_for(Duration::from_millis(20)).await; + // Spawn another task - which is not awaited + let _join_handle = cloned_handle.spawn(async move { + let _guard = SetFlagOnDropGuard { + flag: drop_flag_clone, + }; + + delay_for(Duration::from_millis(50)).await; + // This should not get executed, since the inital task exits before + // and this task gets cancelled. 
+ completion_flag_clone.set(); + }); + }); + + let _ = join_handle.await; + } + }) + .await; + + assert_eq!(true, drop_flag.is_set()); + assert_eq!(false, completion_flag.is_set()); + }); +} + +#[test] +fn clone_scope_handles_and_wait_for_child() { + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let drop_flag = AtomicFlag::new(); + let drop_flag_clone = drop_flag.clone(); + let completion_flag = AtomicFlag::new(); + let completion_flag_clone = completion_flag.clone(); + + let mut options = ScopeOptions::default(); + options.cancel_behavior = ScopeCancelBehavior::ContinueChildTasks; + + let start_time = Instant::now(); + scope_with_options(options, |scope| { + async move { + let cloned_handle = scope.clone(); + + let join_handle = scope.spawn(async move { + delay_for(Duration::from_millis(20)).await; + // Spawn another task - which is not awaited + let _join_handle = cloned_handle.spawn(async move { + let _guard = SetFlagOnDropGuard { + flag: drop_flag_clone, + }; + + delay_for(Duration::from_millis(50)).await; + // This should get executed, since tasks are allowed to run + // to completion. + completion_flag_clone.set(); + }); + }); + + let _ = join_handle.await; + } + }) + .await; + + assert!(start_time.elapsed() >= Duration::from_millis(70)); + + assert_eq!(true, drop_flag.is_set()); + assert_eq!(true, completion_flag.is_set()); + }); +}