From 311daba25531af74129036681ffe5635a0f44da9 Mon Sep 17 00:00:00 2001 From: Tatsuya Kawano Date: Tue, 3 Aug 2021 22:55:52 +0800 Subject: [PATCH 1/4] Change the signature of get_or_try_insert_with to avoid redundant allocation Replace `Arc>` in the return type of `get_or_try_insert_with` with `Arc`. --- Cargo.toml | 2 +- src/future/cache.rs | 12 ++---------- src/future/value_initializer.rs | 10 +++------- src/sync/cache.rs | 12 ++---------- src/sync/segment.rs | 6 +----- src/sync/value_initializer.rs | 10 +++------- 6 files changed, 12 insertions(+), 40 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bedd279a..eade3309 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "moka" -version = "0.5.1" +version = "0.6.0" authors = ["Tatsuya Kawano "] edition = "2018" diff --git a/src/future/cache.rs b/src/future/cache.rs index f06f695e..462b286a 100644 --- a/src/future/cache.rs +++ b/src/future/cache.rs @@ -292,15 +292,11 @@ where /// key even if the method is concurrently called by many async tasks; only one /// of the calls resolves its future, and other calls wait for that future to /// complete. - #[allow(clippy::redundant_allocation)] - // https://rust-lang.github.io/rust-clippy/master/index.html#redundant_allocation - // `Arc>` in the return type creates an extra heap allocation. - // This will be addressed by Moka v0.6.0. pub async fn get_or_try_insert_with( &self, key: K, init: F, - ) -> Result>> + ) -> Result> where F: Future>>, { @@ -486,16 +482,12 @@ where } } - #[allow(clippy::redundant_allocation)] - // https://rust-lang.github.io/rust-clippy/master/index.html#redundant_allocation - // `Arc>` in the return type creates an extra heap allocation. - // This will be addressed by Moka v0.6.0. async fn get_or_try_insert_with_hash_and_fun( &self, key: Arc, hash: u64, init: F, - ) -> Result>> + ) -> Result> where F: Future>>, { diff --git a/src/future/value_initializer.rs b/src/future/value_initializer.rs index 7e2ebc72..763fe59a 100644 --- a/src/future/value_initializer.rs +++ b/src/future/value_initializer.rs @@ -6,16 +6,12 @@ use std::{ sync::Arc, }; -type Waiter = Arc>>>>>; +type Waiter = Arc>>>>; -#[allow(clippy::redundant_allocation)] -// https://rust-lang.github.io/rust-clippy/master/index.html#redundant_allocation pub(crate) enum InitResult { Initialized(V), ReadExisting(V), - // This `Arc>` creates an extra heap allocation. This will be - // addressed by Moka v0.6.0. - InitErr(Arc>), + InitErr(Arc), } pub(crate) struct ValueInitializer { @@ -80,7 +76,7 @@ where Initialized(value) } Err(e) => { - let err = Arc::new(e); + let err = Arc::from(e); *lock = Some(Err(Arc::clone(&err))); self.remove_waiter(&key); InitErr(err) diff --git a/src/sync/cache.rs b/src/sync/cache.rs index 0dd9cccb..270fca11 100644 --- a/src/sync/cache.rs +++ b/src/sync/cache.rs @@ -293,15 +293,11 @@ where /// key even if the method is concurrently called by many threads; only one of /// the calls evaluates its function, and other calls wait for that function to /// complete. - #[allow(clippy::redundant_allocation)] - // https://rust-lang.github.io/rust-clippy/master/index.html#redundant_allocation - // `Arc>` in the return type creates an extra heap allocation. - // This will be addressed by Moka v0.6.0. pub fn get_or_try_insert_with( &self, key: K, init: F, - ) -> Result>> + ) -> Result> where F: FnOnce() -> Result>, { @@ -310,16 +306,12 @@ where self.get_or_try_insert_with_hash_and_fun(key, hash, init) } - #[allow(clippy::redundant_allocation)] - // https://rust-lang.github.io/rust-clippy/master/index.html#redundant_allocation - // `Arc>` in the return type creates an extra heap allocation. - // This will be addressed by Moka v0.6.0. pub(crate) fn get_or_try_insert_with_hash_and_fun( &self, key: Arc, hash: u64, init: F, - ) -> Result>> + ) -> Result> where F: FnOnce() -> Result>, { diff --git a/src/sync/segment.rs b/src/sync/segment.rs index 67c6bd77..3a913c3d 100644 --- a/src/sync/segment.rs +++ b/src/sync/segment.rs @@ -155,15 +155,11 @@ where /// key even if the method is concurrently called by many threads; only one of /// the calls evaluates its function, and other calls wait for that function to /// complete. - #[allow(clippy::redundant_allocation)] - // https://rust-lang.github.io/rust-clippy/master/index.html#redundant_allocation - // `Arc>` in the return type creates an extra heap allocation. - // This will be addressed by Moka v0.6.0. pub fn get_or_try_insert_with( &self, key: K, init: F, - ) -> Result>> + ) -> Result> where F: FnOnce() -> Result>, { diff --git a/src/sync/value_initializer.rs b/src/sync/value_initializer.rs index 44dfd282..dc178a0a 100644 --- a/src/sync/value_initializer.rs +++ b/src/sync/value_initializer.rs @@ -5,16 +5,12 @@ use std::{ sync::Arc, }; -type Waiter = Arc>>>>>; +type Waiter = Arc>>>>; -#[allow(clippy::redundant_allocation)] -// https://rust-lang.github.io/rust-clippy/master/index.html#redundant_allocation pub(crate) enum InitResult { Initialized(V), ReadExisting(V), - // This `Arc>` creates an extra heap allocation. This will be - // addressed by Moka v0.6.0. - InitErr(Arc>), + InitErr(Arc), } pub(crate) struct ValueInitializer { @@ -76,7 +72,7 @@ where Initialized(value) } Err(e) => { - let err = Arc::new(e); + let err = Arc::from(e); *lock = Some(Err(Arc::clone(&err))); self.remove_waiter(&key); InitErr(err) From 8607db7918ec40306cb522bfcfaa6313cd890483 Mon Sep 17 00:00:00 2001 From: Tatsuya Kawano Date: Sun, 8 Aug 2021 13:06:26 +0800 Subject: [PATCH 2/4] Change `get_or_try_insert_with` to return a concrete error type rather than a trait object Now the return type is `Result>` where `E: Error + Send + Sync + 'static`. --- CHANGELOG.md | 9 ++++++ README.md | 6 ++-- src/future/cache.rs | 45 ++++++++++++++++++------------ src/future/value_initializer.rs | 49 +++++++++++++++++++++------------ src/sync/cache.rs | 46 +++++++++++++++++++------------ src/sync/segment.rs | 32 ++++++++++++--------- src/sync/value_initializer.rs | 48 ++++++++++++++++++++------------ 7 files changed, 149 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c1a0008c..02ca0dc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Moka — Change Log +## Version 0.6.0 (Unreleased) + +### Changed + +- Change `get_or_try_insert_with` to return a concrete error type rather + than a trait object. ([#23][gh-pull-0023]) + + ## Version 0.5.1 ### Changed @@ -81,6 +89,7 @@ [caffeine-git]: https://github.com/ben-manes/caffeine +[gh-pull-0023]: https://github.com/moka-rs/moka/pull/23/ [gh-pull-0022]: https://github.com/moka-rs/moka/pull/22/ [gh-pull-0020]: https://github.com/moka-rs/moka/pull/20/ [gh-pull-0019]: https://github.com/moka-rs/moka/pull/19/ diff --git a/README.md b/README.md index c4b87c0f..e2b60e44 100644 --- a/README.md +++ b/README.md @@ -61,14 +61,14 @@ Add this to your `Cargo.toml`: ```toml [dependencies] -moka = "0.5" +moka = "0.6" ``` To use the asynchronous cache, enable a crate feature called "future". ```toml [dependencies] -moka = { version = "0.5", features = ["future"] } +moka = { version = "0.6", features = ["future"] } ``` @@ -164,7 +164,7 @@ Here is a similar program to the previous example, but using asynchronous cache // Cargo.toml // // [dependencies] -// moka = { version = "0.5", features = ["future"] } +// moka = { version = "0.6", features = ["future"] } // tokio = { version = "1", features = ["rt-multi-thread", "macros" ] } // futures = "0.3" diff --git a/src/future/cache.rs b/src/future/cache.rs index 462b286a..ad9417c5 100644 --- a/src/future/cache.rs +++ b/src/future/cache.rs @@ -13,6 +13,7 @@ use crate::{ use crossbeam_channel::{Sender, TrySendError}; use std::{ + any::TypeId, borrow::Borrow, collections::hash_map::RandomState, error::Error, @@ -57,7 +58,7 @@ use std::{ /// // Cargo.toml /// // /// // [dependencies] -/// // moka = { version = "0.5", features = ["future"] } +/// // moka = { version = "0.6", features = ["future"] } /// // tokio = { version = "1", features = ["rt-multi-thread", "macros" ] } /// // futures = "0.3" /// @@ -292,13 +293,10 @@ where /// key even if the method is concurrently called by many async tasks; only one /// of the calls resolves its future, and other calls wait for that future to /// complete. - pub async fn get_or_try_insert_with( - &self, - key: K, - init: F, - ) -> Result> + pub async fn get_or_try_insert_with(&self, key: K, init: F) -> Result> where - F: Future>>, + F: Future>, + E: Error + Send + Sync + 'static, { let hash = self.base.hash(&key); let key = Arc::new(key); @@ -474,7 +472,8 @@ where InitResult::Initialized(v) => { self.insert_with_hash(Arc::clone(&key), hash, v.clone()) .await; - self.value_initializer.remove_waiter(&key); + self.value_initializer + .remove_waiter(&key, TypeId::of::<()>()); v } InitResult::ReadExisting(v) => v, @@ -482,14 +481,15 @@ where } } - async fn get_or_try_insert_with_hash_and_fun( + async fn get_or_try_insert_with_hash_and_fun( &self, key: Arc, hash: u64, init: F, - ) -> Result> + ) -> Result> where - F: Future>>, + F: Future>, + E: Error + Send + Sync + 'static, { if let Some(v) = self.base.get_with_hash(&key, hash) { return Ok(v); @@ -504,7 +504,8 @@ where let hash = self.base.hash(&key); self.insert_with_hash(Arc::clone(&key), hash, v.clone()) .await; - self.value_initializer.remove_waiter(&key); + self.value_initializer + .remove_waiter(&key, TypeId::of::()); Ok(v) } InitResult::ReadExisting(v) => Ok(v), @@ -1020,6 +1021,14 @@ mod tests { #[tokio::test] async fn get_or_try_insert_with() { + use std::sync::Arc; + + #[derive(thiserror::Error, Debug)] + #[error("{}", _0)] + pub struct MyError(String); + + type MyResult = Result>; + let cache = Cache::new(100); const KEY: u32 = 0; @@ -1032,11 +1041,11 @@ mod tests { let cache1 = cache.clone(); async move { // Call `get_or_try_insert_with` immediately. - let v = cache1 + let v: MyResult<_> = cache1 .get_or_try_insert_with(KEY, async { // Wait for 300 ms and return an error. Timer::after(Duration::from_millis(300)).await; - Err("task1 error".into()) + Err(MyError("task1 error".into())) }) .await; assert!(v.is_err()); @@ -1052,7 +1061,7 @@ mod tests { async move { // Wait for 100 ms before calling `get_or_try_insert_with`. Timer::after(Duration::from_millis(100)).await; - let v = cache2 + let v: MyResult<_> = cache2 .get_or_try_insert_with(KEY, async { unreachable!() }) .await; assert!(v.is_err()); @@ -1069,7 +1078,7 @@ mod tests { async move { // Wait for 400 ms before calling `get_or_try_insert_with`. Timer::after(Duration::from_millis(400)).await; - let v = cache3 + let v: MyResult<_> = cache3 .get_or_try_insert_with(KEY, async { // Wait for 300 ms and return an Ok(&str) value. Timer::after(Duration::from_millis(300)).await; @@ -1088,7 +1097,7 @@ mod tests { async move { // Wait for 500 ms before calling `get_or_try_insert_with`. Timer::after(Duration::from_millis(500)).await; - let v = cache4 + let v: MyResult<_> = cache4 .get_or_try_insert_with(KEY, async { unreachable!() }) .await; assert_eq!(v.unwrap(), "task3"); @@ -1105,7 +1114,7 @@ mod tests { async move { // Wait for 800 ms before calling `get_or_try_insert_with`. Timer::after(Duration::from_millis(800)).await; - let v = cache5 + let v: MyResult<_> = cache5 .get_or_try_insert_with(KEY, async { unreachable!() }) .await; assert_eq!(v.unwrap(), "task3"); diff --git a/src/future/value_initializer.rs b/src/future/value_initializer.rs index 763fe59a..3ad8c6ac 100644 --- a/src/future/value_initializer.rs +++ b/src/future/value_initializer.rs @@ -1,21 +1,27 @@ use async_lock::RwLock; use std::{ + any::{Any, TypeId}, error::Error, future::Future, hash::{BuildHasher, Hash}, sync::Arc, }; -type Waiter = Arc>>>>; +type ErrorObject = Arc; +type Waiter = Arc>>>; -pub(crate) enum InitResult { +pub(crate) enum InitResult { Initialized(V), ReadExisting(V), - InitErr(Arc), + InitErr(Arc), } pub(crate) struct ValueInitializer { - waiters: moka_cht::SegmentedHashMap, Waiter, S>, + // TypeId is the type ID of the concrete error type of generic type E in + // try_init_or_read(). We use the type ID as a part of the key to ensure that + // we can always downcast the trait object ErrorObject (in Waiter) into + // its concrete type. + waiters: moka_cht::SegmentedHashMap<(Arc, TypeId), Waiter, S>, } impl ValueInitializer @@ -30,16 +36,17 @@ where } } - pub(crate) async fn init_or_read(&self, key: Arc, init: F) -> InitResult + pub(crate) async fn init_or_read(&self, key: Arc, init: F) -> InitResult where F: Future, { use InitResult::*; + let type_id = TypeId::of::<()>(); let waiter = Arc::new(RwLock::new(None)); let mut lock = waiter.write().await; - match self.try_insert_waiter(&key, &waiter) { + match self.try_insert_waiter(&key, type_id, &waiter) { None => { // Inserted. Resolve the init future. let value = init.await; @@ -58,16 +65,18 @@ where } } - pub(crate) async fn try_init_or_read(&self, key: Arc, init: F) -> InitResult + pub(crate) async fn try_init_or_read(&self, key: Arc, init: F) -> InitResult where - F: Future>>, + F: Future>, + E: Error + Send + Sync + 'static, { use InitResult::*; + let type_id = TypeId::of::(); let waiter = Arc::new(RwLock::new(None)); let mut lock = waiter.write().await; - match self.try_insert_waiter(&key, &waiter) { + match self.try_insert_waiter(&key, type_id, &waiter) { None => { // Inserted. Resolve the init future. match init.await { @@ -76,10 +85,10 @@ where Initialized(value) } Err(e) => { - let err = Arc::from(e); + let err: ErrorObject = Arc::new(e); *lock = Some(Err(Arc::clone(&err))); - self.remove_waiter(&key); - InitErr(err) + self.remove_waiter(&key, type_id); + InitErr(err.downcast().unwrap()) } } } @@ -89,7 +98,7 @@ where std::mem::drop(lock); match &*res.read().await { Some(Ok(value)) => ReadExisting(value.clone()), - Some(Err(e)) => InitErr(Arc::clone(e)), + Some(Err(e)) => InitErr(Arc::clone(e).downcast().unwrap()), None => unreachable!(), } } @@ -97,15 +106,21 @@ where } #[inline] - pub(crate) fn remove_waiter(&self, key: &Arc) { - self.waiters.remove(key); + pub(crate) fn remove_waiter(&self, key: &Arc, type_id: TypeId) { + let key = Arc::clone(key); + self.waiters.remove(&(key, type_id)); } - fn try_insert_waiter(&self, key: &Arc, waiter: &Waiter) -> Option> { + fn try_insert_waiter( + &self, + key: &Arc, + type_id: TypeId, + waiter: &Waiter, + ) -> Option> { let key = Arc::clone(key); let waiter = Arc::clone(waiter); self.waiters - .insert_with_or_modify(key, || waiter, |_, w| Arc::clone(w)) + .insert_with_or_modify((key, type_id), || waiter, |_, w| Arc::clone(w)) } } diff --git a/src/sync/cache.rs b/src/sync/cache.rs index 270fca11..ed3516e2 100644 --- a/src/sync/cache.rs +++ b/src/sync/cache.rs @@ -8,6 +8,7 @@ use crate::{sync::value_initializer::InitResult, PredicateError}; use crossbeam_channel::{Sender, TrySendError}; use std::{ + any::TypeId, borrow::Borrow, collections::hash_map::RandomState, error::Error, @@ -277,7 +278,8 @@ where match self.value_initializer.init_or_read(Arc::clone(&key), init) { InitResult::Initialized(v) => { self.insert_with_hash(Arc::clone(&key), hash, v.clone()); - self.value_initializer.remove_waiter(&key); + self.value_initializer + .remove_waiter(&key, TypeId::of::<()>()); v } InitResult::ReadExisting(v) => v, @@ -293,27 +295,25 @@ where /// key even if the method is concurrently called by many threads; only one of /// the calls evaluates its function, and other calls wait for that function to /// complete. - pub fn get_or_try_insert_with( - &self, - key: K, - init: F, - ) -> Result> + pub fn get_or_try_insert_with(&self, key: K, init: F) -> Result> where - F: FnOnce() -> Result>, + F: FnOnce() -> Result, + E: Error + Send + Sync + 'static, { let hash = self.base.hash(&key); let key = Arc::new(key); self.get_or_try_insert_with_hash_and_fun(key, hash, init) } - pub(crate) fn get_or_try_insert_with_hash_and_fun( + pub(crate) fn get_or_try_insert_with_hash_and_fun( &self, key: Arc, hash: u64, init: F, - ) -> Result> + ) -> Result> where - F: FnOnce() -> Result>, + F: FnOnce() -> Result, + E: Error + Send + Sync + 'static, { if let Some(v) = self.get_with_hash(&key, hash) { return Ok(v); @@ -325,7 +325,8 @@ where { InitResult::Initialized(v) => { self.insert_with_hash(Arc::clone(&key), hash, v.clone()); - self.value_initializer.remove_waiter(&key); + self.value_initializer + .remove_waiter(&key, TypeId::of::()); Ok(v) } InitResult::ReadExisting(v) => Ok(v), @@ -878,7 +879,16 @@ mod tests { #[test] fn get_or_try_insert_with() { - use std::thread::{sleep, spawn}; + use std::{ + sync::Arc, + thread::{sleep, spawn}, + }; + + #[derive(thiserror::Error, Debug)] + #[error("{}", _0)] + pub struct MyError(String); + + type MyResult = Result>; let cache = Cache::new(100); const KEY: u32 = 0; @@ -892,10 +902,10 @@ mod tests { let cache1 = cache.clone(); spawn(move || { // Call `get_or_try_insert_with` immediately. - let v = cache1.get_or_try_insert_with(KEY, || { + let v: MyResult<_> = cache1.get_or_try_insert_with(KEY, || { // Wait for 300 ms and return an error. sleep(Duration::from_millis(300)); - Err("thread1 error".into()) + Err(MyError("thread1 error".into())) }); assert!(v.is_err()); }) @@ -910,7 +920,7 @@ mod tests { spawn(move || { // Wait for 100 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(100)); - let v = cache2.get_or_try_insert_with(KEY, || unreachable!()); + let v: MyResult<_> = cache2.get_or_try_insert_with(KEY, || unreachable!()); assert!(v.is_err()); }) }; @@ -925,7 +935,7 @@ mod tests { spawn(move || { // Wait for 400 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(400)); - let v = cache3.get_or_try_insert_with(KEY, || { + let v: MyResult<_> = cache3.get_or_try_insert_with(KEY, || { // Wait for 300 ms and return an Ok(&str) value. sleep(Duration::from_millis(300)); Ok("thread3") @@ -942,7 +952,7 @@ mod tests { spawn(move || { // Wait for 500 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(500)); - let v = cache4.get_or_try_insert_with(KEY, || unreachable!()); + let v: MyResult<_> = cache4.get_or_try_insert_with(KEY, || unreachable!()); assert_eq!(v.unwrap(), "thread3"); }) }; @@ -957,7 +967,7 @@ mod tests { spawn(move || { // Wait for 800 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(800)); - let v = cache5.get_or_try_insert_with(KEY, || unreachable!()); + let v: MyResult<_> = cache5.get_or_try_insert_with(KEY, || unreachable!()); assert_eq!(v.unwrap(), "thread3"); }) }; diff --git a/src/sync/segment.rs b/src/sync/segment.rs index 3a913c3d..a75203d1 100644 --- a/src/sync/segment.rs +++ b/src/sync/segment.rs @@ -155,13 +155,10 @@ where /// key even if the method is concurrently called by many threads; only one of /// the calls evaluates its function, and other calls wait for that function to /// complete. - pub fn get_or_try_insert_with( - &self, - key: K, - init: F, - ) -> Result> + pub fn get_or_try_insert_with(&self, key: K, init: F) -> Result> where - F: FnOnce() -> Result>, + F: FnOnce() -> Result, + E: Error + Send + Sync + 'static, { let hash = self.inner.hash(&key); let key = Arc::new(key); @@ -692,7 +689,16 @@ mod tests { #[test] fn get_or_try_insert_with() { - use std::thread::{sleep, spawn}; + use std::{ + sync::Arc, + thread::{sleep, spawn}, + }; + + #[derive(thiserror::Error, Debug)] + #[error("{}", _0)] + pub struct MyError(String); + + type MyResult = Result>; let cache = SegmentedCache::new(100, 4); const KEY: u32 = 0; @@ -706,10 +712,10 @@ mod tests { let cache1 = cache.clone(); spawn(move || { // Call `get_or_try_insert_with` immediately. - let v = cache1.get_or_try_insert_with(KEY, || { + let v: MyResult<_> = cache1.get_or_try_insert_with(KEY, || { // Wait for 300 ms and return an error. sleep(Duration::from_millis(300)); - Err("thread1 error".into()) + Err(MyError("thread1 error".into())) }); assert!(v.is_err()); }) @@ -724,7 +730,7 @@ mod tests { spawn(move || { // Wait for 100 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(100)); - let v = cache2.get_or_try_insert_with(KEY, || unreachable!()); + let v: MyResult<_> = cache2.get_or_try_insert_with(KEY, || unreachable!()); assert!(v.is_err()); }) }; @@ -739,7 +745,7 @@ mod tests { spawn(move || { // Wait for 400 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(400)); - let v = cache3.get_or_try_insert_with(KEY, || { + let v: MyResult<_> = cache3.get_or_try_insert_with(KEY, || { // Wait for 300 ms and return an Ok(&str) value. sleep(Duration::from_millis(300)); Ok("thread3") @@ -756,7 +762,7 @@ mod tests { spawn(move || { // Wait for 500 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(500)); - let v = cache4.get_or_try_insert_with(KEY, || unreachable!()); + let v: MyResult<_> = cache4.get_or_try_insert_with(KEY, || unreachable!()); assert_eq!(v.unwrap(), "thread3"); }) }; @@ -771,7 +777,7 @@ mod tests { spawn(move || { // Wait for 800 ms before calling `get_or_try_insert_with`. sleep(Duration::from_millis(800)); - let v = cache5.get_or_try_insert_with(KEY, || unreachable!()); + let v: MyResult<_> = cache5.get_or_try_insert_with(KEY, || unreachable!()); assert_eq!(v.unwrap(), "thread3"); }) }; diff --git a/src/sync/value_initializer.rs b/src/sync/value_initializer.rs index dc178a0a..3ede11fd 100644 --- a/src/sync/value_initializer.rs +++ b/src/sync/value_initializer.rs @@ -1,20 +1,26 @@ use parking_lot::RwLock; use std::{ + any::{Any, TypeId}, error::Error, hash::{BuildHasher, Hash}, sync::Arc, }; -type Waiter = Arc>>>>; +type ErrorObject = Arc; +type Waiter = Arc>>>; -pub(crate) enum InitResult { +pub(crate) enum InitResult { Initialized(V), ReadExisting(V), - InitErr(Arc), + InitErr(Arc), } pub(crate) struct ValueInitializer { - waiters: moka_cht::SegmentedHashMap, Waiter, S>, + // TypeId is the type ID of the concrete error type of generic type E in + // try_init_or_read(). We use the type ID as a part of the key to ensure that + // we can always downcast the trait object ErrorObject (in Waiter) into + // its concrete type. + waiters: moka_cht::SegmentedHashMap<(Arc, TypeId), Waiter, S>, } impl ValueInitializer @@ -29,13 +35,13 @@ where } } - pub(crate) fn init_or_read(&self, key: Arc, init: impl FnOnce() -> V) -> InitResult { + pub(crate) fn init_or_read(&self, key: Arc, init: impl FnOnce() -> V) -> InitResult { use InitResult::*; let waiter = Arc::new(RwLock::new(None)); let mut lock = waiter.write(); - match self.try_insert_waiter(&key, &waiter) { + match self.try_insert_waiter(&key, TypeId::of::<()>(), &waiter) { None => { // Inserted. Evaluate the init closure. let value = init(); @@ -54,16 +60,18 @@ where } } - pub(crate) fn try_init_or_read(&self, key: Arc, init: F) -> InitResult + pub(crate) fn try_init_or_read(&self, key: Arc, init: F) -> InitResult where - F: FnOnce() -> Result>, + F: FnOnce() -> Result, + E: Error + Send + Sync + 'static, { use InitResult::*; + let type_id = TypeId::of::(); let waiter = Arc::new(RwLock::new(None)); let mut lock = waiter.write(); - match self.try_insert_waiter(&key, &waiter) { + match self.try_insert_waiter(&key, type_id, &waiter) { None => { // Inserted. Evaluate the init closure. match init() { @@ -72,10 +80,10 @@ where Initialized(value) } Err(e) => { - let err = Arc::from(e); + let err: ErrorObject = Arc::new(e); *lock = Some(Err(Arc::clone(&err))); - self.remove_waiter(&key); - InitErr(err) + self.remove_waiter(&key, type_id); + InitErr(err.downcast().unwrap()) } } } @@ -85,7 +93,7 @@ where std::mem::drop(lock); match &*res.read() { Some(Ok(value)) => ReadExisting(value.clone()), - Some(Err(e)) => InitErr(Arc::clone(e)), + Some(Err(e)) => InitErr(Arc::clone(e).downcast().unwrap()), None => unreachable!(), } } @@ -93,15 +101,21 @@ where } #[inline] - pub(crate) fn remove_waiter(&self, key: &Arc) { - self.waiters.remove(key); + pub(crate) fn remove_waiter(&self, key: &Arc, type_id: TypeId) { + let key = Arc::clone(key); + self.waiters.remove(&(key, type_id)); } - fn try_insert_waiter(&self, key: &Arc, waiter: &Waiter) -> Option> { + fn try_insert_waiter( + &self, + key: &Arc, + type_id: TypeId, + waiter: &Waiter, + ) -> Option> { let key = Arc::clone(key); let waiter = Arc::clone(waiter); self.waiters - .insert_with_or_modify(key, || waiter, |_, w| Arc::clone(w)) + .insert_with_or_modify((key, type_id), || waiter, |_, w| Arc::clone(w)) } } From ede5857178d8816ddf28feca75cade435037d0f0 Mon Sep 17 00:00:00 2001 From: Tatsuya Kawano Date: Sun, 8 Aug 2021 13:46:05 +0800 Subject: [PATCH 3/4] Tweak the unit test for get_or_try_insert_with --- src/future/cache.rs | 2 +- src/sync/cache.rs | 2 +- src/sync/segment.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/future/cache.rs b/src/future/cache.rs index ad9417c5..74a12b9f 100644 --- a/src/future/cache.rs +++ b/src/future/cache.rs @@ -1041,7 +1041,7 @@ mod tests { let cache1 = cache.clone(); async move { // Call `get_or_try_insert_with` immediately. - let v: MyResult<_> = cache1 + let v = cache1 .get_or_try_insert_with(KEY, async { // Wait for 300 ms and return an error. Timer::after(Duration::from_millis(300)).await; diff --git a/src/sync/cache.rs b/src/sync/cache.rs index ed3516e2..6ae07579 100644 --- a/src/sync/cache.rs +++ b/src/sync/cache.rs @@ -902,7 +902,7 @@ mod tests { let cache1 = cache.clone(); spawn(move || { // Call `get_or_try_insert_with` immediately. - let v: MyResult<_> = cache1.get_or_try_insert_with(KEY, || { + let v = cache1.get_or_try_insert_with(KEY, || { // Wait for 300 ms and return an error. sleep(Duration::from_millis(300)); Err(MyError("thread1 error".into())) diff --git a/src/sync/segment.rs b/src/sync/segment.rs index a75203d1..e591d428 100644 --- a/src/sync/segment.rs +++ b/src/sync/segment.rs @@ -712,7 +712,7 @@ mod tests { let cache1 = cache.clone(); spawn(move || { // Call `get_or_try_insert_with` immediately. - let v: MyResult<_> = cache1.get_or_try_insert_with(KEY, || { + let v = cache1.get_or_try_insert_with(KEY, || { // Wait for 300 ms and return an error. sleep(Duration::from_millis(300)); Err(MyError("thread1 error".into())) From 226edad72f780c05287804c28e007ba73375a213 Mon Sep 17 00:00:00 2001 From: Tatsuya Kawano Date: Sun, 8 Aug 2021 13:54:18 +0800 Subject: [PATCH 4/4] Change `get_or_try_insert_with` to return a concrete error type rather than a trait object Update the doc comments. --- src/future/cache.rs | 4 ++-- src/sync/cache.rs | 14 +++++++------- src/sync/segment.rs | 16 ++++++++-------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/future/cache.rs b/src/future/cache.rs index 74a12b9f..6947199d 100644 --- a/src/future/cache.rs +++ b/src/future/cache.rs @@ -291,8 +291,8 @@ where /// /// This method prevents to resolve the init future multiple times on the same /// key even if the method is concurrently called by many async tasks; only one - /// of the calls resolves its future, and other calls wait for that future to - /// complete. + /// of the calls resolves its future (as long as these futures return the same + /// error type), and other calls wait for that future to complete. pub async fn get_or_try_insert_with(&self, key: K, init: F) -> Result> where F: Future>, diff --git a/src/sync/cache.rs b/src/sync/cache.rs index 6ae07579..7f1a9546 100644 --- a/src/sync/cache.rs +++ b/src/sync/cache.rs @@ -255,9 +255,9 @@ where /// Ensures the value of the key exists by inserting the result of the init /// function if not exist, and returns a _clone_ of the value. /// - /// This method prevents to evaluate the init function multiple times on the same + /// This method prevents to evaluate the init closure multiple times on the same /// key even if the method is concurrently called by many threads; only one of - /// the calls evaluates its function, and other calls wait for that function to + /// the calls evaluates its closure, and other calls wait for that closure to /// complete. pub fn get_or_insert_with(&self, key: K, init: impl FnOnce() -> V) -> V { let hash = self.base.hash(&key); @@ -288,13 +288,13 @@ where } /// Try to ensure the value of the key exists by inserting an `Ok` result of the - /// init function if not exist, and returns a _clone_ of the value or the `Err` - /// returned by the function. + /// init closure if not exist, and returns a _clone_ of the value or the `Err` + /// returned by the closure. /// - /// This method prevents to evaluate the init function multiple times on the same + /// This method prevents to evaluate the init closure multiple times on the same /// key even if the method is concurrently called by many threads; only one of - /// the calls evaluates its function, and other calls wait for that function to - /// complete. + /// the calls evaluates its closure (as long as these closures return the same + /// error type), and other calls wait for that closure to complete. pub fn get_or_try_insert_with(&self, key: K, init: F) -> Result> where F: FnOnce() -> Result, diff --git a/src/sync/segment.rs b/src/sync/segment.rs index e591d428..fbf4d67b 100644 --- a/src/sync/segment.rs +++ b/src/sync/segment.rs @@ -133,11 +133,11 @@ where } /// Ensures the value of the key exists by inserting the result of the init - /// function if not exist, and returns a _clone_ of the value. + /// closure if not exist, and returns a _clone_ of the value. /// - /// This method prevents to evaluate the init function multiple times on the same + /// This method prevents to evaluate the init closure multiple times on the same /// key even if the method is concurrently called by many threads; only one of - /// the calls evaluates its function, and other calls wait for that function to + /// the calls evaluates its closure, and other calls wait for that closure to /// complete. pub fn get_or_insert_with(&self, key: K, init: impl FnOnce() -> V) -> V { let hash = self.inner.hash(&key); @@ -148,13 +148,13 @@ where } /// Try to ensure the value of the key exists by inserting an `Ok` result of the - /// init function if not exist, and returns a _clone_ of the value or the `Err` - /// returned by the function. + /// init closure if not exist, and returns a _clone_ of the value or the `Err` + /// returned by the closure. /// - /// This method prevents to evaluate the init function multiple times on the same + /// This method prevents to evaluate the init closure multiple times on the same /// key even if the method is concurrently called by many threads; only one of - /// the calls evaluates its function, and other calls wait for that function to - /// complete. + /// the calls evaluates its closure (as long as these closures return the same + /// error type), and other calls wait for that closure to complete. pub fn get_or_try_insert_with(&self, key: K, init: F) -> Result> where F: FnOnce() -> Result,