diff --git a/spellcheck.dic b/spellcheck.dic index 81430dde286..b9155707000 100644 --- a/spellcheck.dic +++ b/spellcheck.dic @@ -280,6 +280,7 @@ unparks Unparks unreceived unsafety +unsets Unsets unsynchronized untrusted diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 996f0f2d9b4..9bf73b74fbf 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -284,9 +284,11 @@ where } pub(super) fn drop_join_handle_slow(self) { - // Try to unset `JOIN_INTEREST`. This must be done as a first step in + // Try to unset `JOIN_INTEREST` and `JOIN_WAKER`. This must be done as a first step in // case the task concurrently completed. - if self.state().unset_join_interested().is_err() { + let transition = self.state().transition_to_join_handle_dropped(); + + if transition.drop_output { // It is our responsibility to drop the output. This is critical as // the task output may not be `Send` and as such must remain with // the scheduler or `JoinHandle`. i.e. if the output remains in the @@ -301,6 +303,23 @@ where })); } + if transition.drop_waker { + // If the JOIN_WAKER flag is unset at this point, the task is either + // already terminal or not complete so the `JoinHandle` is responsible + // for dropping the waker. + // Safety: + // If the JOIN_WAKER bit is not set the join handle has exclusive + // access to the waker as per rule 2 in task/mod.rs. + // This can only be the case at this point in two scenarios: + // 1. The task completed and the runtime unset `JOIN_WAKER` flag + // after accessing the waker during task completion. So the + // `JoinHandle` is the only one to access the join waker here. + // 2. The task is not completed so the `JoinHandle` was able to unset + // `JOIN_WAKER` bit itself to get mutable access to the waker. + // The runtime will not access the waker when this flag is unset. + unsafe { self.trailer().set_waker(None) }; + } + // Drop the `JoinHandle` reference, possibly deallocating the task self.drop_reference(); } @@ -311,7 +330,6 @@ where fn complete(self) { // The future has completed and its output has been written to the task // stage. We transition from running to complete. - let snapshot = self.state().transition_to_complete(); // We catch panics here in case dropping the future or waking the @@ -320,13 +338,28 @@ where if !snapshot.is_join_interested() { // The `JoinHandle` is not interested in the output of // this task. It is our responsibility to drop the - // output. + // output. The join waker was already dropped by the + // `JoinHandle` before. self.core().drop_future_or_output(); } else if snapshot.is_join_waker_set() { // Notify the waker. Reading the waker field is safe per rule 4 // in task/mod.rs, since the JOIN_WAKER bit is set and the call // to transition_to_complete() above set the COMPLETE bit. self.trailer().wake_join(); + + // Inform the `JoinHandle` that we are done waking the waker by + // unsetting the `JOIN_WAKER` bit. If the `JoinHandle` has + // already been dropped and `JOIN_INTEREST` is unset, then we must + // drop the waker ourselves. + if !self + .state() + .unset_waker_after_complete() + .is_join_interested() + { + // SAFETY: We have COMPLETE=1 and JOIN_INTEREST=0, so + // we have exclusive access to the waker. + unsafe { self.trailer().set_waker(None) }; + } } })); diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 33f54003d38..15c5a8f4afe 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -94,16 +94,30 @@ //! `JoinHandle` needs to (i) successfully set `JOIN_WAKER` to zero if it is //! not already zero to gain exclusive access to the waker field per rule //! 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER` to one. +//! If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped +//! to clear the waker field, only steps (i) and (ii) are relevant. //! //! 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e. -//! the task hasn't yet completed). +//! the task hasn't yet completed). The runtime can change `JOIN_WAKER` only +//! if COMPLETE is one. +//! +//! 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the runtime has +//! exclusive (mutable) access to the waker field. This might happen if the +//! `JoinHandle` gets dropped right after the task completes and the runtime +//! sets the `COMPLETE` bit. In this case the runtime needs the mutable access +//! to the waker field to drop it. //! //! Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a //! race. If step (i) fails, then the attempt to write a waker is aborted. If //! step (iii) fails because COMPLETE is set to one by another thread after //! step (i), then the waker field is cleared. Once COMPLETE is one (i.e. //! task has completed), the `JoinHandle` will not modify `JOIN_WAKER`. After the -//! runtime sets COMPLETE to one, it invokes the waker if there is one. +//! runtime sets COMPLETE to one, it invokes the waker if there is one so in this +//! case when a task completes the `JOIN_WAKER` bit implicates to the runtime +//! whether it should invoke the waker or not. After the runtime is done with +//! using the waker during task completion, it unsets the `JOIN_WAKER` bit to give +//! the `JoinHandle` exclusive access again so that it is able to drop the waker +//! at a later point. //! //! All other fields are immutable and can be accessed immutably without //! synchronization by anyone. diff --git a/tokio/src/runtime/task/state.rs b/tokio/src/runtime/task/state.rs index 0fc7bb0329b..037c1c90c61 100644 --- a/tokio/src/runtime/task/state.rs +++ b/tokio/src/runtime/task/state.rs @@ -89,6 +89,12 @@ pub(crate) enum TransitionToNotifiedByRef { Submit, } +#[must_use] +pub(super) struct TransitionToJoinHandleDrop { + pub(super) drop_waker: bool, + pub(super) drop_output: bool, +} + /// All transitions are performed via RMW operations. This establishes an /// unambiguous modification order. impl State { @@ -371,22 +377,45 @@ impl State { .map_err(|_| ()) } - /// Tries to unset the `JOIN_INTEREST` flag. - /// - /// Returns `Ok` if the operation happens before the task transitions to a - /// completed state, `Err` otherwise. - pub(super) fn unset_join_interested(&self) -> UpdateResult { - self.fetch_update(|curr| { - assert!(curr.is_join_interested()); + /// Unsets the `JOIN_INTEREST` flag. If `COMPLETE` is not set, the `JOIN_WAKER` + /// flag is also unset. + /// The returned `TransitionToJoinHandleDrop` indicates whether the `JoinHandle` should drop + /// the output of the future or the join waker after the transition. + pub(super) fn transition_to_join_handle_dropped(&self) -> TransitionToJoinHandleDrop { + self.fetch_update_action(|mut snapshot| { + assert!(snapshot.is_join_interested()); - if curr.is_complete() { - return None; + let mut transition = TransitionToJoinHandleDrop { + drop_waker: false, + drop_output: false, + }; + + snapshot.unset_join_interested(); + + if !snapshot.is_complete() { + // If `COMPLETE` is unset we also unset `JOIN_WAKER` to give the + // `JoinHandle` exclusive access to the waker following rule 6 in task/mod.rs. + // The `JoinHandle` will drop the waker if it has exclusive access + // to drop it. + snapshot.unset_join_waker(); + } else { + // If `COMPLETE` is set the task is completed so the `JoinHandle` is responsible + // for dropping the output. + transition.drop_output = true; } - let mut next = curr; - next.unset_join_interested(); + if !snapshot.is_join_waker_set() { + // If the `JOIN_WAKER` bit is unset and the `JOIN_HANDLE` has exclusive access to + // the join waker and should drop it following this transition. + // This might happen in two situations: + // 1. The task is not completed and we just unset the `JOIN_WAKer` above in this + // function. + // 2. The task is completed. In that case the `JOIN_WAKER` bit was already unset + // by the runtime during completion. + transition.drop_waker = true; + } - Some(next) + (transition, Some(snapshot)) }) } @@ -430,6 +459,16 @@ impl State { }) } + /// Unsets the `JOIN_WAKER` bit unconditionally after task completion. + /// + /// This operation requires the task to be completed. + pub(super) fn unset_waker_after_complete(&self) -> Snapshot { + let prev = Snapshot(self.val.fetch_and(!JOIN_WAKER, AcqRel)); + assert!(prev.is_complete()); + assert!(prev.is_join_waker_set()); + Snapshot(prev.0 & !JOIN_WAKER) + } + pub(super) fn ref_inc(&self) { use std::process; use std::sync::atomic::Ordering::Relaxed; diff --git a/tokio/src/runtime/tests/loom_current_thread.rs b/tokio/src/runtime/tests/loom_current_thread.rs index edda6e49954..fd0a44314f8 100644 --- a/tokio/src/runtime/tests/loom_current_thread.rs +++ b/tokio/src/runtime/tests/loom_current_thread.rs @@ -1,6 +1,6 @@ mod yield_now; -use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::atomic::{AtomicUsize, Ordering}; use crate::loom::sync::Arc; use crate::loom::thread; use crate::runtime::{Builder, Runtime}; @@ -9,7 +9,7 @@ use crate::task; use std::future::Future; use std::pin::Pin; use std::sync::atomic::Ordering::{Acquire, Release}; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; fn assert_at_most_num_polls(rt: Arc, at_most_polls: usize) { let (tx, rx) = oneshot::channel(); @@ -106,6 +106,60 @@ fn assert_no_unnecessary_polls() { }); } +#[test] +fn drop_jh_during_schedule() { + unsafe fn waker_clone(ptr: *const ()) -> RawWaker { + let atomic = unsafe { &*(ptr as *const AtomicUsize) }; + atomic.fetch_add(1, Ordering::Relaxed); + RawWaker::new(ptr, &VTABLE) + } + unsafe fn waker_drop(ptr: *const ()) { + let atomic = unsafe { &*(ptr as *const AtomicUsize) }; + atomic.fetch_sub(1, Ordering::Relaxed); + } + unsafe fn waker_nop(_ptr: *const ()) {} + + static VTABLE: RawWakerVTable = + RawWakerVTable::new(waker_clone, waker_drop, waker_nop, waker_drop); + + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + + let mut jh = rt.spawn(async {}); + // Using AbortHandle to increment task refcount. This ensures that the waker is not + // destroyed due to the refcount hitting zero. + let task_refcnt = jh.abort_handle(); + + let waker_refcnt = AtomicUsize::new(1); + { + // Set up the join waker. + use std::future::Future; + use std::pin::Pin; + + // SAFETY: Before `waker_refcnt` goes out of scope, this test asserts that the refcnt + // has dropped to zero. + let join_waker = unsafe { + Waker::from_raw(RawWaker::new( + (&waker_refcnt) as *const AtomicUsize as *const (), + &VTABLE, + )) + }; + + assert!(Pin::new(&mut jh) + .poll(&mut Context::from_waker(&join_waker)) + .is_pending()); + } + assert_eq!(waker_refcnt.load(Ordering::Relaxed), 1); + + let bg_thread = loom::thread::spawn(move || drop(jh)); + rt.block_on(crate::task::yield_now()); + bg_thread.join().unwrap(); + + assert_eq!(waker_refcnt.load(Ordering::Relaxed), 0); + drop(task_refcnt); + }); +} + struct BlockedFuture { rx: Receiver<()>, num_polls: Arc, diff --git a/tokio/tests/rt_handle.rs b/tokio/tests/rt_handle.rs index 9efe9b4bde9..3773b8972af 100644 --- a/tokio/tests/rt_handle.rs +++ b/tokio/tests/rt_handle.rs @@ -2,7 +2,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use std::sync::Arc; use tokio::runtime::Runtime; +use tokio::sync::{mpsc, Barrier}; #[test] #[cfg_attr(panic = "abort", ignore)] @@ -65,6 +67,40 @@ fn interleave_then_enter() { let _enter = rt3.enter(); } +// If the cycle causes a leak, then miri will catch it. +#[test] +fn drop_tasks_with_reference_cycle() { + rt().block_on(async { + let (tx, mut rx) = mpsc::channel(1); + + let barrier = Arc::new(Barrier::new(3)); + let barrier_a = barrier.clone(); + let barrier_b = barrier.clone(); + + let a = tokio::spawn(async move { + let b = rx.recv().await.unwrap(); + + // Poll the JoinHandle once. This registers the waker. + // The other task cannot have finished at this point due to the barrier below. + futures::future::select(b, std::future::ready(())).await; + + barrier_a.wait().await; + }); + + let b = tokio::spawn(async move { + // Poll the JoinHandle once. This registers the waker. + // The other task cannot have finished at this point due to the barrier below. + futures::future::select(a, std::future::ready(())).await; + + barrier_b.wait().await; + }); + + tx.send(b).await.unwrap(); + + barrier.wait().await; + }); +} + #[cfg(tokio_unstable)] mod unstable { use super::*;