Skip to content

Commit

Permalink
Add optionally_get_with method for sync cache
Browse files Browse the repository at this point in the history
  • Loading branch information
LMJW committed Oct 23, 2022
1 parent 22419d7 commit a256c42
Show file tree
Hide file tree
Showing 6 changed files with 440 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/future/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2597,7 +2597,6 @@ mod tests {

#[tokio::test]
async fn optionally_get_with() {

let cache = Cache::new(100);
const KEY: u32 = 0;

Expand Down Expand Up @@ -2665,7 +2664,9 @@ mod tests {
async move {
// Wait for 500 ms before calling `try_get_with`.
Timer::after(Duration::from_millis(500)).await;
let v = cache4.optionally_get_with(KEY, async { unreachable!() }).await;
let v = cache4
.optionally_get_with(KEY, async { unreachable!() })
.await;
assert_eq!(v.unwrap(), "task3");
}
};
Expand Down
5 changes: 2 additions & 3 deletions src/future/value_initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ where

// This closure will be called after the init closure has returned a value.
// It will convert the returned value (from init) into an InitResult.
let post_init = |key, value: Option<V>, mut guard: WaiterGuard<'_, K, V, S>| match value
{
let post_init = |key, value: Option<V>, mut guard: WaiterGuard<'_, K, V, S>| match value {
Some(value) => {
guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone())));
InitResult::Initialized(value)
Expand All @@ -182,7 +181,7 @@ where
// `value` can be either `Some` or `None`. For `None` case, without
// change the existing API too much, we will need to convert `None`
// to Arc<E> here. `Infalliable` could not be instantiated. So it
// might be good to use an empty struct to indicate the error type.
// might be good to use an empty struct to indicate the error type.
let err: ErrorObject = Arc::new(OptionallyNone);
guard.set_waiter_value(WaiterValue::Ready(Err(Arc::clone(&err))));
self.remove_waiter(key, type_id);
Expand Down
4 changes: 4 additions & 0 deletions src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ pub trait ConcurrentCacheExt<K, V> {
/// Performs any pending maintenance operations needed by the cache.
fn sync(&self);
}

// Empty internal struct to be used in optionally_get_with to represent the None
// results.
struct OptionallyNone;
256 changes: 255 additions & 1 deletion src/sync/cache.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
value_initializer::{InitResult, ValueInitializer},
CacheBuilder, ConcurrentCacheExt,
CacheBuilder, ConcurrentCacheExt, OptionallyNone,
};
use crate::{
common::{
Expand Down Expand Up @@ -1191,6 +1191,131 @@ where
}
}

/// Try to ensure the value of the key exists by inserting an `Some` result of
/// the init closure if not exist, and returns a _clone_ of the value or `None`
/// returned by the closure.
///
/// This method prevents to evaluate the init closure multiple times on the same
/// key even if the method is concurrently called by many threads; only one of
/// the calls evaluates its closure (as long as these closures return the same
/// Option type), and other calls wait for that closure to complete.
///
/// # Example
///
/// ```rust
/// use moka::sync::Cache;
/// use std::{path::Path, time::Duration, thread};
///
/// /// This function tries to get the file size in bytes.
/// fn get_file_size(thread_id: u8, path: impl AsRef<Path>) -> Option<u64> {
/// println!("get_file_size() called by thread {}.", thread_id);
/// std::fs::metadata(path).ok().and_then(|m|Some(m.len()))
/// }
///
/// let cache = Cache::new(100);
///
/// // Spawn four threads.
/// let threads: Vec<_> = (0..4_u8)
/// .map(|thread_id| {
/// let my_cache = cache.clone();
/// thread::spawn(move || {
/// println!("Thread {} started.", thread_id);
///
/// // Try to insert and get the value for key1. Although all four
/// // threads will call `optionally_get_with` at the same time,
/// // get_file_size() must be called only once.
/// let value = my_cache.optionally_get_with(
/// "key1",
/// || get_file_size(thread_id, "./Cargo.toml"),
/// );
///
/// // Ensure the value exists now.
/// assert!(value.is_some());
/// thread::sleep(Duration::from_millis(10));
/// assert!(my_cache.get(&"key1").is_some());
///
/// println!(
/// "Thread {} got the value. (len: {})",
/// thread_id,
/// value.unwrap()
/// );
/// })
/// })
/// .collect();
///
/// // Wait all threads to complete.
/// threads
/// .into_iter()
/// .for_each(|t| t.join().expect("Thread failed"));
/// ```
///
/// **Result**
///
/// - `get_file_size()` was called exactly once by thread 0.
/// - Other threads were blocked until thread 0 inserted the value.
///
/// ```console
/// Thread 0 started.
/// Thread 1 started.
/// Thread 2 started.
/// get_file_size() called by thread 0.
/// Thread 3 started.
/// Thread 2 got the value. (len: 1466)
/// Thread 0 got the value. (len: 1466)
/// Thread 1 got the value. (len: 1466)
/// Thread 3 got the value. (len: 1466)
/// ```
///
/// # Panics
///
/// This method panics when the `init` closure has panicked. When it happens,
/// only the caller whose `init` closure panicked will get the panic (e.g. only
/// thread 1 in the above sample). If there are other calls in progress (e.g.
/// thread 0, 2 and 3 above), this method will restart and resolve one of the
/// remaining `init` closure.
///
pub fn optionally_get_with<F>(&self, key: K, init: F) -> Option<V>
where
F: FnOnce() -> Option<V>,
{
let hash = self.base.hash(&key);
let key = Arc::new(key);
self.get_or_optionally_insert_with_hash_and_fun(key, hash, init)
}

pub(super) fn get_or_optionally_insert_with_hash_and_fun<F>(
&self,
key: Arc<K>,
hash: u64,
init: F,
) -> Option<V>
where
F: FnOnce() -> Option<V>,
{
let res = self.get_with_hash(&key, hash);
if res.is_some() {
return res;
}

match self
.value_initializer
.optionally_init_or_read(Arc::clone(&key), init)
{
InitResult::Initialized(v) => {
self.insert_with_hash(Arc::clone(&key), hash, v.clone());
self.value_initializer
.remove_waiter(&key, TypeId::of::<OptionallyNone>());
crossbeam_epoch::pin().flush();
Some(v)
}
InitResult::ReadExisting(v) => Some(v),
InitResult::InitErr(_) => {
crossbeam_epoch::pin().flush();
None
}
}
}

/// Inserts a key-value pair into the cache.
///
/// If the cache has this key present, the value is updated.
Expand Down Expand Up @@ -2556,6 +2681,135 @@ mod tests {
}
}

#[test]
fn optionally_get_with() {
use std::thread::{sleep, spawn};

let cache = Cache::new(100);
const KEY: u32 = 0;

// This test will run eight threads:
//
// Thread1 will be the first thread to call `optionally_get_with` for a key, so its
// init closure will be evaluated and then an error will be returned. Nothing
// will be inserted to the cache.
let thread1 = {
let cache1 = cache.clone();
spawn(move || {
// Call `optionally_get_with` immediately.
let v = cache1.optionally_get_with(KEY, || {
// Wait for 300 ms and return an error.
sleep(Duration::from_millis(300));
None
});
assert!(v.is_none());
})
};

// Thread2 will be the second thread to call `optionally_get_with` for the same key,
// so its init closure will not be evaluated. Once thread1's init closure
// finishes, it will get the same error value returned by thread1's init
// closure.
let thread2 = {
let cache2 = cache.clone();
spawn(move || {
// Wait for 100 ms before calling `optionally_get_with`.
sleep(Duration::from_millis(100));
let v = cache2.optionally_get_with(KEY, || unreachable!());
assert!(v.is_none());
})
};

// Thread3 will be the third thread to call `get_with` for the same key. By
// the time it calls, thread1's init closure should have finished already,
// but the key still does not exist in the cache. So its init closure will be
// evaluated and then an okay &str value will be returned. That value will be
// inserted to the cache.
let thread3 = {
let cache3 = cache.clone();
spawn(move || {
// Wait for 400 ms before calling `optionally_get_with`.
sleep(Duration::from_millis(400));
let v = cache3.optionally_get_with(KEY, || {
// Wait for 300 ms and return an Ok(&str) value.
sleep(Duration::from_millis(300));
Some("thread3")
});
assert_eq!(v.unwrap(), "thread3");
})
};

// thread4 will be the fourth thread to call `optionally_get_with` for the same
// key. So its init closure will not be evaluated. Once thread3's init
// closure finishes, it will get the same okay &str value.
let thread4 = {
let cache4 = cache.clone();
spawn(move || {
// Wait for 500 ms before calling `optionally_get_with`.
sleep(Duration::from_millis(500));
let v = cache4.optionally_get_with(KEY, || unreachable!());
assert_eq!(v.unwrap(), "thread3");
})
};

// Thread5 will be the fifth thread to call `optionally_get_with` for the same
// key. So its init closure will not be evaluated. By the time it calls,
// thread3's init closure should have finished already, so its init closure
// will not be evaluated and will get the value insert by thread3's init
// closure immediately.
let thread5 = {
let cache5 = cache.clone();
spawn(move || {
// Wait for 800 ms before calling `optionally_get_with`.
sleep(Duration::from_millis(800));
let v = cache5.optionally_get_with(KEY, || unreachable!());
assert_eq!(v.unwrap(), "thread3");
})
};

// Thread6 will call `get` for the same key. It will call when thread1's init
// closure is still running, so it will get none for the key.
let thread6 = {
let cache6 = cache.clone();
spawn(move || {
// Wait for 200 ms before calling `get`.
sleep(Duration::from_millis(200));
let maybe_v = cache6.get(&KEY);
assert!(maybe_v.is_none());
})
};

// Thread7 will call `get` for the same key. It will call after thread1's init
// closure finished with an error. So it will get none for the key.
let thread7 = {
let cache7 = cache.clone();
spawn(move || {
// Wait for 400 ms before calling `get`.
sleep(Duration::from_millis(400));
let maybe_v = cache7.get(&KEY);
assert!(maybe_v.is_none());
})
};

// Thread8 will call `get` for the same key. It will call after thread3's init
// closure finished, so it will get the value insert by thread3's init closure.
let thread8 = {
let cache8 = cache.clone();
spawn(move || {
// Wait for 800 ms before calling `get`.
sleep(Duration::from_millis(800));
let maybe_v = cache8.get(&KEY);
assert_eq!(maybe_v, Some("thread3"));
})
};

for t in vec![
thread1, thread2, thread3, thread4, thread5, thread6, thread7, thread8,
] {
t.join().expect("Failed to join");
}
}

#[test]
// https://github.com/moka-rs/moka/issues/43
fn handle_panic_in_get_with() {
Expand Down
Loading

0 comments on commit a256c42

Please sign in to comment.