diff --git a/server/src/auth/token_validator.rs b/server/src/auth/token_validator.rs index 1177997d..c102f66e 100644 --- a/server/src/auth/token_validator.rs +++ b/server/src/auth/token_validator.rs @@ -12,7 +12,7 @@ pub struct TokenValidator { impl TokenValidator { async fn get_unknown_and_known_tokens( - &mut self, + &self, tokens: Vec, ) -> EdgeResult<(Vec, Vec)> { let tokens_with_valid_format: Vec = tokens @@ -32,7 +32,7 @@ impl TokenValidator { } } - pub async fn register_token(&mut self, token: String) -> EdgeResult { + pub async fn register_token(&self, token: String) -> EdgeResult { Ok(self .register_tokens(vec![token]) .await? @@ -41,7 +41,7 @@ impl TokenValidator { .clone()) } - pub async fn register_tokens(&mut self, tokens: Vec) -> EdgeResult> { + pub async fn register_tokens(&self, tokens: Vec) -> EdgeResult> { let (unknown_tokens, known_tokens) = self.get_unknown_and_known_tokens(tokens).await?; if unknown_tokens.is_empty() { Ok(known_tokens) @@ -94,7 +94,6 @@ mod tests { use chrono::Duration; use serde::{Deserialize, Serialize}; use std::sync::Arc; - use tokio::sync::RwLock; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct EdgeTokens { @@ -132,7 +131,7 @@ mod tests { #[tokio::test] pub async fn can_validate_tokens() { - let test_provider = Arc::new(RwLock::new(MemoryProvider::default())); + let test_provider = Arc::new(MemoryProvider::default()); let facade = Arc::new(DataSourceFacade { features_refresh_interval: Some(Duration::seconds(1)), token_source: test_provider.clone(), @@ -149,7 +148,7 @@ mod tests { crate::http::unleash_client::UnleashClient::new(srv.url("/").as_str(), None) .expect("Couldn't build client"); - let mut validation_holder = super::TokenValidator { + let validation_holder = super::TokenValidator { unleash_client: Arc::new(unleash_client), edge_source: source, edge_sink: sink, @@ -163,8 +162,6 @@ mod tests { .await .expect("Couldn't register tokens"); let known_tokens = test_provider - .read() - .await .get_tokens() .await .expect("Couldn't get tokens"); @@ -180,7 +177,7 @@ mod tests { #[tokio::test] pub async fn tokens_with_wrong_format_is_not_included() { - let test_provider = Arc::new(RwLock::new(MemoryProvider::default())); + let test_provider = Arc::new(MemoryProvider::default()); let facade = Arc::new(DataSourceFacade { features_refresh_interval: Some(Duration::seconds(1)), feature_source: test_provider.clone(), @@ -195,7 +192,7 @@ mod tests { let unleash_client = crate::http::unleash_client::UnleashClient::new(srv.url("/").as_str(), None) .expect("Couldn't build client"); - let mut validation_holder = super::TokenValidator { + let validation_holder = super::TokenValidator { unleash_client: Arc::new(unleash_client), edge_source: source, edge_sink: sink, diff --git a/server/src/data_sources/builder.rs b/server/src/data_sources/builder.rs index 3909b7c2..a261307c 100644 --- a/server/src/data_sources/builder.rs +++ b/server/src/data_sources/builder.rs @@ -39,7 +39,7 @@ fn build_offline(offline_args: OfflineArgs) -> EdgeResult> { } fn build_memory(features_refresh_interval_seconds: Duration) -> EdgeResult { - let data_source = Arc::new(RwLock::new(MemoryProvider::new())); + let data_source = Arc::new(MemoryProvider::new()); let facade = Arc::new(DataSourceFacade { features_refresh_interval: Some(features_refresh_interval_seconds), token_source: data_source.clone(), @@ -58,7 +58,7 @@ fn build_redis( redis_url: String, features_refresh_interval_seconds: Duration, ) -> EdgeResult { - let data_source = Arc::new(RwLock::new(RedisProvider::new(&redis_url)?)); + let data_source = Arc::new(RedisProvider::new(&redis_url)?); let facade = Arc::new(DataSourceFacade { token_source: data_source.clone(), feature_source: data_source.clone(), @@ -93,7 +93,7 @@ pub async fn build_source_and_sink(args: CliArgs) -> EdgeResult EdgeArg::InMemory => build_memory(refresh_interval), }?; - let mut token_validator = TokenValidator { + let token_validator = TokenValidator { unleash_client: Arc::new(unleash_client.clone()), edge_source: source.clone(), edge_sink: sink.clone(), diff --git a/server/src/data_sources/memory_provider.rs b/server/src/data_sources/memory_provider.rs index 6104f3ff..87d24e7a 100644 --- a/server/src/data_sources/memory_provider.rs +++ b/server/src/data_sources/memory_provider.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use crate::types::TokenRefresh; use crate::types::{EdgeResult, EdgeToken}; use actix_web::http::header::EntityTag; @@ -13,8 +11,8 @@ use super::repository::{DataSink, DataSource}; #[derive(Debug, Clone)] pub struct MemoryProvider { data_store: DashMap, - token_store: HashMap, - tokens_to_refresh: HashMap, + token_store: DashMap, + tokens_to_refresh: DashMap, } fn key(token: &EdgeToken) -> String { @@ -31,8 +29,8 @@ impl MemoryProvider { pub fn new() -> Self { Self { data_store: DashMap::new(), - token_store: HashMap::new(), - tokens_to_refresh: HashMap::new(), + token_store: DashMap::new(), + tokens_to_refresh: DashMap::new(), } } } @@ -40,15 +38,19 @@ impl MemoryProvider { #[async_trait] impl DataSource for MemoryProvider { async fn get_tokens(&self) -> EdgeResult> { - Ok(self.token_store.values().into_iter().cloned().collect()) + Ok(self.token_store.iter().map(|x| x.value().clone()).collect()) } async fn get_token(&self, secret: &str) -> EdgeResult> { - Ok(self.token_store.get(secret).cloned()) + Ok(self.token_store.get(secret).map(|x| x.clone())) } async fn get_refresh_tokens(&self) -> EdgeResult> { - Ok(self.tokens_to_refresh.values().cloned().collect()) + Ok(self + .tokens_to_refresh + .iter() + .map(|x| x.value().clone()) + .collect()) } async fn get_client_features(&self, token: &EdgeToken) -> EdgeResult> { @@ -58,27 +60,23 @@ impl DataSource for MemoryProvider { #[async_trait] impl DataSink for MemoryProvider { - async fn sink_tokens(&mut self, tokens: Vec) -> EdgeResult<()> { + async fn sink_tokens(&self, tokens: Vec) -> EdgeResult<()> { for token in tokens { self.token_store.insert(token.token.clone(), token.clone()); } Ok(()) } - async fn set_refresh_tokens(&mut self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> { - let new_tokens = tokens - .into_iter() - .map(|token| (token.token.token.clone(), token.clone())) - .collect(); - self.tokens_to_refresh = new_tokens; + async fn set_refresh_tokens(&self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> { + self.tokens_to_refresh.clear(); + tokens.into_iter().for_each(|refresh| { + self.tokens_to_refresh + .insert(refresh.token.token.clone(), refresh.clone()); + }); Ok(()) } - async fn sink_features( - &mut self, - token: &EdgeToken, - features: ClientFeatures, - ) -> EdgeResult<()> { + async fn sink_features(&self, token: &EdgeToken, features: ClientFeatures) -> EdgeResult<()> { self.data_store .entry(key(token)) .and_modify(|data| { @@ -88,19 +86,19 @@ impl DataSink for MemoryProvider { Ok(()) } - async fn update_last_check(&mut self, token: &EdgeToken) -> EdgeResult<()> { - if let Some(token) = self.tokens_to_refresh.get_mut(&token.token) { + async fn update_last_check(&self, token: &EdgeToken) -> EdgeResult<()> { + if let Some(mut token) = self.tokens_to_refresh.get_mut(&token.token) { token.last_check = Some(chrono::Utc::now()); } Ok(()) } async fn update_last_refresh( - &mut self, + &self, token: &EdgeToken, etag: Option, ) -> EdgeResult<()> { - if let Some(token) = self.tokens_to_refresh.get_mut(&token.token) { + if let Some(mut token) = self.tokens_to_refresh.get_mut(&token.token) { token.last_check = Some(chrono::Utc::now()); token.last_refreshed = Some(chrono::Utc::now()); token.etag = etag; @@ -117,7 +115,7 @@ mod tests { #[tokio::test] async fn memory_provider_correctly_deduplicates_tokens() { - let mut provider = MemoryProvider::default(); + let provider = MemoryProvider::default(); provider .sink_tokens(vec![EdgeToken { token: "*:development.1d38eefdd7bf72676122b008dcf330f2f2aa2f3031438e1b7e8f0d1f" @@ -141,7 +139,7 @@ mod tests { #[tokio::test] async fn memory_provider_correctly_determines_token_to_be_valid() { - let mut provider = MemoryProvider::default(); + let provider = MemoryProvider::default(); provider .sink_tokens(vec![EdgeToken { token: "*:development.1d38eefdd7bf72676122b008dcf330f2f2aa2f3031438e1b7e8f0d1f" diff --git a/server/src/data_sources/redis_provider.rs b/server/src/data_sources/redis_provider.rs index b6390c8e..fe2744c6 100644 --- a/server/src/data_sources/redis_provider.rs +++ b/server/src/data_sources/redis_provider.rs @@ -80,7 +80,7 @@ impl DataSource for RedisProvider { #[async_trait] impl DataSink for RedisProvider { - async fn sink_tokens(&mut self, tokens: Vec) -> EdgeResult<()> { + async fn sink_tokens(&self, tokens: Vec) -> EdgeResult<()> { let mut client = self.redis_client.write().await; let raw_stored_tokens: Option = client.get(TOKENS_KEY)?; @@ -99,7 +99,7 @@ impl DataSink for RedisProvider { Ok(()) } - async fn set_refresh_tokens(&mut self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> { + async fn set_refresh_tokens(&self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> { let mut client = self.redis_client.write().await; let serialized_refresh_tokens = serde_json::to_string(&tokens)?; @@ -108,11 +108,7 @@ impl DataSink for RedisProvider { Ok(()) } - async fn sink_features( - &mut self, - token: &EdgeToken, - features: ClientFeatures, - ) -> EdgeResult<()> { + async fn sink_features(&self, token: &EdgeToken, features: ClientFeatures) -> EdgeResult<()> { let mut client = self.redis_client.write().await; let raw_stored_features: Option = client.get(key(token))?; @@ -130,7 +126,7 @@ impl DataSink for RedisProvider { Ok(()) } - async fn update_last_check(&mut self, token: &EdgeToken) -> EdgeResult<()> { + async fn update_last_check(&self, token: &EdgeToken) -> EdgeResult<()> { let mut client = self.redis_client.write().await; let raw_refresh_tokens: Option = client.get(REFRESH_TOKENS_KEY)?; @@ -155,7 +151,7 @@ impl DataSink for RedisProvider { } async fn update_last_refresh( - &mut self, + &self, token: &EdgeToken, etag: Option, ) -> EdgeResult<()> { diff --git a/server/src/data_sources/repository.rs b/server/src/data_sources/repository.rs index c5dbb87f..91d4ea6a 100644 --- a/server/src/data_sources/repository.rs +++ b/server/src/data_sources/repository.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use actix_web::http::header::EntityTag; use async_trait::async_trait; use chrono::{Duration, Utc}; -use tokio::sync::RwLock; use unleash_types::client_features::{ClientFeature, ClientFeatures}; use crate::{ @@ -37,10 +36,10 @@ impl ProjectFilter for Vec { #[derive(Clone)] pub struct DataSourceFacade { pub(crate) features_refresh_interval: Option, - pub(crate) token_source: Arc>, - pub(crate) feature_source: Arc>, - pub token_sink: Arc>, - pub feature_sink: Arc>, + pub(crate) token_source: Arc, + pub(crate) feature_source: Arc, + pub token_sink: Arc, + pub feature_sink: Arc, } impl EdgeSource for DataSourceFacade {} @@ -56,16 +55,12 @@ pub trait DataSource: Send + Sync { #[async_trait] pub trait DataSink: Send + Sync { - async fn sink_tokens(&mut self, tokens: Vec) -> EdgeResult<()>; - async fn set_refresh_tokens(&mut self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()>; - async fn sink_features( - &mut self, - token: &EdgeToken, - features: ClientFeatures, - ) -> EdgeResult<()>; - async fn update_last_check(&mut self, token: &EdgeToken) -> EdgeResult<()>; + async fn sink_tokens(&self, tokens: Vec) -> EdgeResult<()>; + async fn set_refresh_tokens(&self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()>; + async fn sink_features(&self, token: &EdgeToken, features: ClientFeatures) -> EdgeResult<()>; + async fn update_last_check(&self, token: &EdgeToken) -> EdgeResult<()>; async fn update_last_refresh( - &mut self, + &self, token: &EdgeToken, etag: Option, ) -> EdgeResult<()>; @@ -74,13 +69,11 @@ pub trait DataSink: Send + Sync { #[async_trait] impl TokenSource for DataSourceFacade { async fn get_tokens(&self) -> EdgeResult> { - let lock = self.token_source.read().await; - lock.get_tokens().await + self.token_source.get_tokens().await } async fn get_valid_tokens(&self) -> EdgeResult> { - let lock = self.token_source.read().await; - lock.get_tokens().await.map(|result| { + self.token_source.get_tokens().await.map(|result| { result .iter() .filter(|t| t.status == TokenValidationStatus::Validated) @@ -90,23 +83,21 @@ impl TokenSource for DataSourceFacade { } async fn get_token(&self, secret: String) -> EdgeResult> { - let lock = self.token_source.read().await; - lock.get_token(secret.as_str()).await + self.token_source.get_token(secret.as_str()).await } async fn filter_valid_tokens(&self, tokens: Vec) -> EdgeResult> { - let mut known_tokens = self.token_source.read().await.get_tokens().await?; + let mut known_tokens = self.token_source.get_tokens().await?; known_tokens.retain(|t| tokens.contains(&t.token)); Ok(known_tokens) } async fn get_tokens_due_for_refresh(&self) -> EdgeResult> { - let lock = self.token_source.read().await; - let refresh_tokens = lock.get_refresh_tokens().await?; + let refresh_tokens = self.token_source.get_refresh_tokens().await?; let refresh_interval = self .features_refresh_interval - .ok_or(EdgeError::DataSourceError("No refresh interval set".into()))?; + .ok_or_else(|| EdgeError::DataSourceError("No refresh interval set".into()))?; Ok(refresh_tokens .iter() @@ -127,65 +118,54 @@ impl FeatureSource for DataSourceFacade { let token = self .get_token(token.token.clone()) .await? - .unwrap_or(token.clone()); + .unwrap_or_else(|| token.clone()); - let environment_features = self - .feature_source - .read() - .await - .get_client_features(&token) - .await?; + let environment_features = self.feature_source.get_client_features(&token).await?; Ok(environment_features .map(|client_features| ClientFeatures { features: client_features.features.filter_by_projects(&token), ..client_features }) - .ok_or(EdgeError::DataSourceError("No features found".into()))?) + .ok_or_else(|| EdgeError::DataSourceError("No features found".into()))?) } } #[async_trait] impl TokenSink for DataSourceFacade { async fn sink_tokens(&self, tokens: Vec) -> EdgeResult<()> { - let mut lock = self.token_sink.write().await; - lock.sink_tokens(tokens.clone()).await?; - drop(lock); + self.token_sink.sink_tokens(tokens.clone()).await?; - let refresh_tokens: Vec = tokens + let refresh_tokens = tokens .into_iter() .filter(|t| t.token_type == Some(TokenType::Client)) - .map(TokenRefresh::new) - .collect(); + .map(TokenRefresh::new); - let lock = self.token_source.write().await; - let current_refresh_tokens: Vec = lock + let current_refresh_tokens: Vec = self + .token_source .get_refresh_tokens() .await? .into_iter() - .chain(refresh_tokens.into_iter()) + .chain(refresh_tokens) .collect(); - drop(lock); - let mut lock = self.token_sink.write().await; + let reduced_refresh_tokens = crate::tokens::simplify(¤t_refresh_tokens); - lock.set_refresh_tokens(reduced_refresh_tokens).await + self.token_sink + .set_refresh_tokens(reduced_refresh_tokens) + .await } } #[async_trait] impl FeatureSink for DataSourceFacade { async fn sink_features(&self, token: &EdgeToken, features: ClientFeatures) -> EdgeResult<()> { - let mut lock = self.feature_sink.write().await; - - lock.sink_features(token, features).await?; - + self.feature_sink.sink_features(token, features).await?; Ok(()) } async fn update_last_check(&self, token: &EdgeToken) -> EdgeResult<()> { - let mut lock = self.feature_sink.write().await; - lock.update_last_check(token).await?; + self.feature_sink.update_last_check(token).await?; Ok(()) } @@ -194,8 +174,7 @@ impl FeatureSink for DataSourceFacade { token: &EdgeToken, etag: Option, ) -> EdgeResult<()> { - let mut lock = self.feature_sink.write().await; - lock.update_last_refresh(token, etag).await?; + self.feature_sink.update_last_refresh(token, etag).await?; Ok(()) } } @@ -205,7 +184,6 @@ mod tests { use std::{str::FromStr, sync::Arc}; use chrono::Duration; - use tokio::sync::RwLock; use unleash_types::client_features::{ClientFeature, ClientFeatures}; use crate::{ @@ -216,7 +194,7 @@ mod tests { use super::DataSourceFacade; fn build_data_source() -> EdgeResult<(Arc, Arc)> { - let data_store = Arc::new(RwLock::new(MemoryProvider::new())); + let data_store = Arc::new(MemoryProvider::new()); let facade = Arc::new(DataSourceFacade { token_source: data_store.clone(), feature_source: data_store.clone(), diff --git a/server/src/metrics/client_metrics.rs b/server/src/metrics/client_metrics.rs index cd63e195..df0b8b49 100644 --- a/server/src/metrics/client_metrics.rs +++ b/server/src/metrics/client_metrics.rs @@ -92,7 +92,7 @@ impl MetricsCache { .or_insert(*added_count); }); }) - .or_insert(metric.clone()); + .or_insert_with(|| metric.clone()); } } } diff --git a/server/src/middleware/validate_token.rs b/server/src/middleware/validate_token.rs index e0909ceb..dc250ab9 100644 --- a/server/src/middleware/validate_token.rs +++ b/server/src/middleware/validate_token.rs @@ -6,14 +6,13 @@ use actix_web::{ web::Data, HttpResponse, }; -use tokio::sync::RwLock; pub async fn validate_token( token: EdgeToken, req: ServiceRequest, srv: crate::middleware::as_async_middleware::Next, ) -> Result, actix_web::Error> { - let maybe_validator = req.app_data::>>(); + let maybe_validator = req.app_data::>(); let source = req .app_data::>() .unwrap() @@ -21,11 +20,7 @@ pub async fn validate_token( .into_inner(); match maybe_validator { Some(validator) => { - let known_token = validator - .write() - .await - .register_token(token.token.clone()) - .await?; + let known_token = validator.register_token(token.token.clone()).await?; let res = match known_token.status { TokenValidationStatus::Validated => match known_token.token_type { Some(TokenType::Frontend) => { diff --git a/server/tests/redis_test.rs b/server/tests/redis_test.rs index cfc2a2ac..9ee0836c 100644 --- a/server/tests/redis_test.rs +++ b/server/tests/redis_test.rs @@ -28,7 +28,7 @@ async fn redis_stores_and_returns_data_correctly() { let docker = Cli::default(); let (_client, url, _node) = setup_redis(&docker); - let mut redis: RedisProvider = RedisProvider::new(&url).unwrap(); + let redis: RedisProvider = RedisProvider::new(&url).unwrap(); let token = EdgeToken { status: TokenValidationStatus::Validated, @@ -59,7 +59,7 @@ async fn redis_stores_and_returns_tokens_correctly() { let docker = Cli::default(); let (_client, url, _node) = setup_redis(&docker); - let mut redis: RedisProvider = RedisProvider::new(&url).unwrap(); + let redis: RedisProvider = RedisProvider::new(&url).unwrap(); let token = EdgeToken { status: TokenValidationStatus::Validated, @@ -80,7 +80,7 @@ async fn redis_stores_and_returns_refresh_tokens_correctly() { let docker = Cli::default(); let (_client, url, _node) = setup_redis(&docker); - let mut redis: RedisProvider = RedisProvider::new(&url).unwrap(); + let redis: RedisProvider = RedisProvider::new(&url).unwrap(); let tokens = vec![TokenRefresh { etag: None, @@ -107,7 +107,7 @@ async fn redis_store_marks_update_correctly() { let docker = Cli::default(); let (_client, url, _node) = setup_redis(&docker); - let mut redis: RedisProvider = RedisProvider::new(&url).unwrap(); + let redis: RedisProvider = RedisProvider::new(&url).unwrap(); let token = EdgeToken { status: TokenValidationStatus::Validated,