diff --git a/epaint/src/mutex.rs b/epaint/src/mutex.rs index c189284c097b..ef23337fd918 100644 --- a/epaint/src/mutex.rs +++ b/epaint/src/mutex.rs @@ -6,13 +6,24 @@ /// The lock you get from [`Mutex`]. #[cfg(feature = "multi_threaded")] +#[cfg(not(debug_assertions))] pub use parking_lot::MutexGuard; +/// The lock you get from [`Mutex`]. +#[cfg(feature = "multi_threaded")] +#[cfg(debug_assertions)] +pub struct MutexGuard<'a, T>(parking_lot::MutexGuard<'a, T>, *const ()); + /// Provides interior mutability. Only thread-safe if the `multi_threaded` feature is enabled. #[cfg(feature = "multi_threaded")] #[derive(Default)] pub struct Mutex(parking_lot::Mutex); +#[cfg(debug_assertions)] +thread_local! { + static HELD_LOCKS_TLS: std::cell::RefCell> = std::cell::RefCell::new(std::collections::HashSet::new()); +} + #[cfg(feature = "multi_threaded")] impl Mutex { #[inline(always)] @@ -22,12 +33,21 @@ impl Mutex { #[cfg(debug_assertions)] pub fn lock(&self) -> MutexGuard<'_, T> { - // TODO: detect if we are trying to lock the same mutex from the same thread (bad) - // vs locking it from another thread (fine). - // At the moment we just panic on any double-locking of a mutex (so no multithreaded support in debug builds) - self.0 - .try_lock() - .expect("The Mutex is already locked. Probably a bug") + // Detect if we are recursively taking out a lock on this mutex. + + // use a pointer to the inner data as an id for this lock + let ptr = (&self.0 as *const parking_lot::Mutex<_>).cast::<()>(); + + // Store it in thread local storage while we have a lock guard taken out + HELD_LOCKS_TLS.with(|locks| { + if locks.borrow().contains(&ptr) { + panic!("Recursively locking a Mutex in the same thread is not supported") + } else { + locks.borrow_mut().insert(ptr); + } + }); + + MutexGuard(self.0.lock(), ptr) } #[inline(always)] @@ -37,6 +57,35 @@ impl Mutex { } } +#[cfg(debug_assertions)] +#[cfg(feature = "multi_threaded")] +impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + let ptr = self.1; + HELD_LOCKS_TLS.with(|locks| { + locks.borrow_mut().remove(&ptr); + }); + } +} + +#[cfg(debug_assertions)] +#[cfg(feature = "multi_threaded")] +impl std::ops::Deref for MutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(debug_assertions)] +#[cfg(feature = "multi_threaded")] +impl std::ops::DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + // --------------------- /// The lock you get from [`RwLock::read`]. @@ -140,3 +189,41 @@ where Self::new(self.lock().clone()) } } + +#[cfg(test)] +mod tests { + use crate::mutex::Mutex; + use std::time::Duration; + + #[test] + fn lock_two_different_mutexes_single_thread() { + let one = Mutex::new(()); + let two = Mutex::new(()); + let _a = one.lock(); + let _b = two.lock(); + } + + #[test] + #[should_panic] + fn lock_reentry_single_thread() { + let one = Mutex::new(()); + let _a = one.lock(); + let _a2 = one.lock(); // panics + } + + #[test] + fn lock_multiple_threads() { + use std::sync::Arc; + let one = Arc::new(Mutex::new(())); + let our_lock = one.lock(); + let other_thread = { + let one = Arc::clone(&one); + std::thread::spawn(move || { + let _ = one.lock(); + }) + }; + std::thread::sleep(Duration::from_millis(200)); + drop(our_lock); + other_thread.join().unwrap(); + } +}