From 05fecb95df2edeb0e14529791e2a412aa9bfab11 Mon Sep 17 00:00:00 2001 From: joboet Date: Thu, 7 Nov 2024 17:13:33 +0100 Subject: [PATCH] std: allow after-main use of synchronization primitives By creating an unnamed thread handle when the actual one has already been destroyed, synchronization primitives using thread parking can be used even outside the Rust runtime. This also fixes an inefficiency in the queue-based `RwLock`: if `thread::current` was not initialized yet, it will create a new handle on every parking attempt without initializing `thread::current`. The private `current_or_unnamed` function introduced here fixes this. --- std/src/lib.rs | 3 --- std/src/sync/mpmc/array.rs | 6 ++++-- std/src/sync/mpmc/context.rs | 13 +++++++++---- std/src/sync/mpmc/list.rs | 3 ++- std/src/sync/mpmc/zero.rs | 6 ++++-- std/src/sys/sync/once/queue.rs | 9 +++++---- std/src/sys/sync/rwlock/queue.rs | 6 ++---- std/src/thread/current.rs | 17 +++++++++++++++++ std/src/thread/mod.rs | 15 ++++++++++++--- std/src/thread/scoped.rs | 7 ++++--- 10 files changed, 59 insertions(+), 26 deletions(-) diff --git a/std/src/lib.rs b/std/src/lib.rs index 5b94f036248cb..73fae60baff3d 100644 --- a/std/src/lib.rs +++ b/std/src/lib.rs @@ -174,9 +174,6 @@ //! //! - after-main use of thread-locals, which also affects additional features: //! - [`thread::current()`] -//! - [`thread::scope()`] -//! - [`sync::mpmc`] -//! - [`sync::mpsc`] //! - before-main stdio file descriptors are not guaranteed to be open on unix platforms //! //! diff --git a/std/src/sync/mpmc/array.rs b/std/src/sync/mpmc/array.rs index 2c8ba411f3023..a467237fef152 100644 --- a/std/src/sync/mpmc/array.rs +++ b/std/src/sync/mpmc/array.rs @@ -346,7 +346,8 @@ impl Channel { } // Block the current thread. - let sel = cx.wait_until(deadline); + // SAFETY: the context belongs to the current thread. + let sel = unsafe { cx.wait_until(deadline) }; match sel { Selected::Waiting => unreachable!(), @@ -397,7 +398,8 @@ impl Channel { } // Block the current thread. - let sel = cx.wait_until(deadline); + // SAFETY: the context belongs to the current thread. + let sel = unsafe { cx.wait_until(deadline) }; match sel { Selected::Waiting => unreachable!(), diff --git a/std/src/sync/mpmc/context.rs b/std/src/sync/mpmc/context.rs index 2371d32d4ea0d..51aa7e82e7890 100644 --- a/std/src/sync/mpmc/context.rs +++ b/std/src/sync/mpmc/context.rs @@ -69,7 +69,7 @@ impl Context { inner: Arc::new(Inner { select: AtomicUsize::new(Selected::Waiting.into()), packet: AtomicPtr::new(ptr::null_mut()), - thread: thread::current(), + thread: thread::current_or_unnamed(), thread_id: current_thread_id(), }), } @@ -112,8 +112,11 @@ impl Context { /// Waits until an operation is selected and returns it. /// /// If the deadline is reached, `Selected::Aborted` will be selected. + /// + /// # Safety + /// This may only be called from the thread this `Context` belongs to. #[inline] - pub fn wait_until(&self, deadline: Option) -> Selected { + pub unsafe fn wait_until(&self, deadline: Option) -> Selected { loop { // Check whether an operation has been selected. let sel = Selected::from(self.inner.select.load(Ordering::Acquire)); @@ -126,7 +129,8 @@ impl Context { let now = Instant::now(); if now < end { - thread::park_timeout(end - now); + // SAFETY: guaranteed by caller. + unsafe { self.inner.thread.park_timeout(end - now) }; } else { // The deadline has been reached. Try aborting select. return match self.try_select(Selected::Aborted) { @@ -135,7 +139,8 @@ impl Context { }; } } else { - thread::park(); + // SAFETY: guaranteed by caller. + unsafe { self.inner.thread.park() }; } } } diff --git a/std/src/sync/mpmc/list.rs b/std/src/sync/mpmc/list.rs index 523e6d2f3bb37..d88914f529142 100644 --- a/std/src/sync/mpmc/list.rs +++ b/std/src/sync/mpmc/list.rs @@ -444,7 +444,8 @@ impl Channel { } // Block the current thread. - let sel = cx.wait_until(deadline); + // SAFETY: the context belongs to the current thread. + let sel = unsafe { cx.wait_until(deadline) }; match sel { Selected::Waiting => unreachable!(), diff --git a/std/src/sync/mpmc/zero.rs b/std/src/sync/mpmc/zero.rs index 446881291e68e..577997c07a636 100644 --- a/std/src/sync/mpmc/zero.rs +++ b/std/src/sync/mpmc/zero.rs @@ -190,7 +190,8 @@ impl Channel { drop(inner); // Block the current thread. - let sel = cx.wait_until(deadline); + // SAFETY: the context belongs to the current thread. + let sel = unsafe { cx.wait_until(deadline) }; match sel { Selected::Waiting => unreachable!(), @@ -257,7 +258,8 @@ impl Channel { drop(inner); // Block the current thread. - let sel = cx.wait_until(deadline); + // SAFETY: the context belongs to the current thread. + let sel = unsafe { cx.wait_until(deadline) }; match sel { Selected::Waiting => unreachable!(), diff --git a/std/src/sys/sync/once/queue.rs b/std/src/sys/sync/once/queue.rs index 177d0d7744a6d..87837915b396e 100644 --- a/std/src/sys/sync/once/queue.rs +++ b/std/src/sys/sync/once/queue.rs @@ -93,7 +93,7 @@ const QUEUE_MASK: usize = !STATE_MASK; // use interior mutability. #[repr(align(4))] // Ensure the two lower bits are free to use as state bits. struct Waiter { - thread: Cell>, + thread: Thread, signaled: AtomicBool, next: Cell<*const Waiter>, } @@ -238,7 +238,7 @@ fn wait( return_on_poisoned: bool, ) -> StateAndQueue { let node = &Waiter { - thread: Cell::new(Some(thread::current())), + thread: thread::current_or_unnamed(), signaled: AtomicBool::new(false), next: Cell::new(ptr::null()), }; @@ -277,7 +277,8 @@ fn wait( // can park ourselves, the result could be this thread never gets // unparked. Luckily `park` comes with the guarantee that if it got // an `unpark` just before on an unparked thread it does not park. - thread::park(); + // SAFETY: we retrieved this handle on the current thread above. + unsafe { node.thread.park() } } return state_and_queue.load(Acquire); @@ -309,7 +310,7 @@ impl Drop for WaiterQueue<'_> { let mut queue = to_queue(current); while !queue.is_null() { let next = (*queue).next.get(); - let thread = (*queue).thread.take().unwrap(); + let thread = (*queue).thread.clone(); (*queue).signaled.store(true, Release); thread.unpark(); queue = next; diff --git a/std/src/sys/sync/rwlock/queue.rs b/std/src/sys/sync/rwlock/queue.rs index 51330f8fafe57..bd15f8ee952c9 100644 --- a/std/src/sys/sync/rwlock/queue.rs +++ b/std/src/sys/sync/rwlock/queue.rs @@ -118,7 +118,7 @@ use crate::mem; use crate::ptr::{self, NonNull, null_mut, without_provenance_mut}; use crate::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; use crate::sync::atomic::{AtomicBool, AtomicPtr}; -use crate::thread::{self, Thread, ThreadId}; +use crate::thread::{self, Thread}; /// The atomic lock state. type AtomicState = AtomicPtr<()>; @@ -217,9 +217,7 @@ impl Node { /// Prepare this node for waiting. fn prepare(&mut self) { // Fall back to creating an unnamed `Thread` handle to allow locking in TLS destructors. - self.thread.get_or_init(|| { - thread::try_current().unwrap_or_else(|| Thread::new_unnamed(ThreadId::new())) - }); + self.thread.get_or_init(thread::current_or_unnamed); self.completed = AtomicBool::new(false); } diff --git a/std/src/thread/current.rs b/std/src/thread/current.rs index e6eb90c4c30a5..b9b959f98946b 100644 --- a/std/src/thread/current.rs +++ b/std/src/thread/current.rs @@ -165,6 +165,23 @@ pub(crate) fn try_current() -> Option { } } +/// Gets a handle to the thread that invokes it. If the handle stored in thread- +/// local storage was already destroyed, this creates a new unnamed temporary +/// handle to allow thread parking in nearly all situations. +pub(crate) fn current_or_unnamed() -> Thread { + let current = CURRENT.get(); + if current > DESTROYED { + unsafe { + let current = ManuallyDrop::new(Thread::from_raw(current)); + (*current).clone() + } + } else if current == DESTROYED { + Thread::new_unnamed(id::get_or_init()) + } else { + init_current(current) + } +} + /// Gets a handle to the thread that invokes it. /// /// # Examples diff --git a/std/src/thread/mod.rs b/std/src/thread/mod.rs index 227ee9d64f375..27d7d2395816d 100644 --- a/std/src/thread/mod.rs +++ b/std/src/thread/mod.rs @@ -186,7 +186,7 @@ mod current; #[stable(feature = "rust1", since = "1.0.0")] pub use current::current; -pub(crate) use current::{current_id, drop_current, set_current, try_current}; +pub(crate) use current::{current_id, current_or_unnamed, drop_current, set_current, try_current}; //////////////////////////////////////////////////////////////////////////////// // Thread-local storage @@ -1126,9 +1126,9 @@ pub fn park_timeout_ms(ms: u32) { #[stable(feature = "park_timeout", since = "1.4.0")] pub fn park_timeout(dur: Duration) { let guard = PanicGuard; - // SAFETY: park_timeout is called on the parker owned by this thread. + // SAFETY: park_timeout is called on a handle owned by this thread. unsafe { - current().0.parker().park_timeout(dur); + current().park_timeout(dur); } // No panic occurred, do not abort. forget(guard); @@ -1426,6 +1426,15 @@ impl Thread { unsafe { self.0.parker().park() } } + /// Like the public [`park_timeout`], but callable on any handle. This is + /// used to allow parking in TLS destructors. + /// + /// # Safety + /// May only be called from the thread to which this handle belongs. + pub(crate) unsafe fn park_timeout(&self, dur: Duration) { + unsafe { self.0.parker().park_timeout(dur) } + } + /// Atomically makes the handle's token available if it is not already. /// /// Every thread is equipped with some basic low-level blocking support, via diff --git a/std/src/thread/scoped.rs b/std/src/thread/scoped.rs index b2305b1eda7e1..22cae3dfc21f2 100644 --- a/std/src/thread/scoped.rs +++ b/std/src/thread/scoped.rs @@ -1,4 +1,4 @@ -use super::{Builder, JoinInner, Result, Thread, current, park}; +use super::{Builder, JoinInner, Result, Thread, current_or_unnamed}; use crate::marker::PhantomData; use crate::panic::{AssertUnwindSafe, catch_unwind, resume_unwind}; use crate::sync::Arc; @@ -140,7 +140,7 @@ where let scope = Scope { data: Arc::new(ScopeData { num_running_threads: AtomicUsize::new(0), - main_thread: current(), + main_thread: current_or_unnamed(), a_thread_panicked: AtomicBool::new(false), }), env: PhantomData, @@ -152,7 +152,8 @@ where // Wait until all the threads are finished. while scope.data.num_running_threads.load(Ordering::Acquire) != 0 { - park(); + // SAFETY: this is the main thread, the handle belongs to us. + unsafe { scope.data.main_thread.park() }; } // Throw any panic from `f`, or the return value of `f` if no thread panicked.