diff --git a/serial_test/src/code_lock.rs b/serial_test/src/code_lock.rs index cbb2161..19cb89e 100644 --- a/serial_test/src/code_lock.rs +++ b/serial_test/src/code_lock.rs @@ -30,7 +30,6 @@ impl UniqueReentrantMutex { self.locks.parallel_count() } - #[cfg(test)] pub fn is_locked(&self) -> bool { self.locks.is_locked() } @@ -44,6 +43,63 @@ pub(crate) fn global_locks() -> &'static HashMap { LOCKS.get_or_init(HashMap::new) } +/// Check if we are holding a serial lock +/// +/// Can be used to assert that a piece of code can only be called +/// from a test marked `#[serial]`. +/// +/// Example, with `#[serial]`: +/// +/// ``` +/// use serial_test::{is_locked_serially, serial}; +/// +/// fn do_something_in_need_of_serialization() { +/// assert!(is_locked_serially(None)); +/// +/// // ... +/// } +/// +/// #[test] +/// # fn unused() {} +/// #[serial] +/// fn main() { +/// do_something_in_need_of_serialization(); +/// } +/// ``` +/// +/// Example, missing `#[serial]`: +/// +/// ```should_panic +/// use serial_test::{is_locked_serially, serial}; +/// +/// #[test] +/// # fn unused() {} +/// // #[serial] // <-- missing +/// fn main() { +/// assert!(is_locked_serially(None)); +/// } +/// ``` +/// +/// Example, `#[test(some_key)]`: +/// +/// ``` +/// use serial_test::{is_locked_serially, serial}; +/// +/// #[test] +/// # fn unused() {} +/// #[serial(some_key)] +/// fn main() { +/// assert!(is_locked_serially(Some("some_key"))); +/// assert!(!is_locked_serially(None)); +/// } +/// ``` +pub fn is_locked_serially(name: Option<&str>) -> bool { + global_locks() + .get(name.unwrap_or_default()) + .map(|lock| lock.get().is_locked()) + .unwrap_or_default() +} + static MUTEX_ID: AtomicU32 = AtomicU32::new(1); impl UniqueReentrantMutex { @@ -68,3 +124,55 @@ pub(crate) fn check_new_key(name: &str) { Entry::Vacant(v) => v.insert_entry(UniqueReentrantMutex::new_mutex(name)), }; } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{local_parallel_core, local_serial_core}; + + const NAME1: &str = "NAME1"; + const NAME2: &str = "NAME2"; + + #[test] + fn assert_serially_locked_without_name() { + local_serial_core(vec![""], None, || { + assert!(is_locked_serially(None)); + assert!(!is_locked_serially(Some("no_such_name"))); + }); + } + + #[test] + fn assert_serially_locked_with_multiple_names() { + local_serial_core(vec![NAME1, NAME2], None, || { + assert!(is_locked_serially(Some(NAME1))); + assert!(is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(Some("no_such_name"))); + assert!(!is_locked_serially(None)); + }); + } + + #[test] + fn assert_serially_locked_when_actually_locked_parallel() { + local_parallel_core(vec![NAME1, NAME2], None, || { + assert!(!is_locked_serially(Some(NAME1))); + assert!(!is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(Some("no_such_name"))); + assert!(!is_locked_serially(None)); + }); + } + + #[test] + fn assert_serially_locked_outside_serial_lock() { + assert!(!is_locked_serially(Some(NAME1))); + assert!(!is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(None)); + + local_serial_core(vec![NAME1, NAME2], None, || { + // ... + }); + + assert!(!is_locked_serially(Some(NAME1))); + assert!(!is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(None)); + } +} diff --git a/serial_test/src/lib.rs b/serial_test/src/lib.rs index d6519f0..a5d71c0 100644 --- a/serial_test/src/lib.rs +++ b/serial_test/src/lib.rs @@ -112,3 +112,5 @@ pub use serial_test_derive::{parallel, serial}; #[cfg(feature = "file_locks")] pub use serial_test_derive::{file_parallel, file_serial}; + +pub use code_lock::is_locked_serially; diff --git a/serial_test/src/rwlock.rs b/serial_test/src/rwlock.rs index 0be0c90..95ed34a 100644 --- a/serial_test/src/rwlock.rs +++ b/serial_test/src/rwlock.rs @@ -49,7 +49,6 @@ impl Locks { } } - #[cfg(test)] pub fn is_locked(&self) -> bool { self.arc.serial.is_locked() }