Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: removal of RW locks for dashmaps #74

Merged
merged 2 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions server/src/auth/token_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct TokenValidator {

impl TokenValidator {
async fn get_unknown_and_known_tokens(
&mut self,
&self,
tokens: Vec<String>,
) -> EdgeResult<(Vec<EdgeToken>, Vec<EdgeToken>)> {
let tokens_with_valid_format: Vec<EdgeToken> = tokens
Expand All @@ -32,7 +32,7 @@ impl TokenValidator {
}
}

pub async fn register_token(&mut self, token: String) -> EdgeResult<EdgeToken> {
pub async fn register_token(&self, token: String) -> EdgeResult<EdgeToken> {
Ok(self
.register_tokens(vec![token])
.await?
Expand All @@ -41,7 +41,7 @@ impl TokenValidator {
.clone())
}

pub async fn register_tokens(&mut self, tokens: Vec<String>) -> EdgeResult<Vec<EdgeToken>> {
pub async fn register_tokens(&self, tokens: Vec<String>) -> EdgeResult<Vec<EdgeToken>> {
let (unknown_tokens, known_tokens) = self.get_unknown_and_known_tokens(tokens).await?;
if unknown_tokens.is_empty() {
Ok(known_tokens)
Expand Down Expand Up @@ -94,7 +94,6 @@ mod tests {
use chrono::Duration;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EdgeTokens {
Expand Down Expand Up @@ -132,7 +131,7 @@ mod tests {

#[tokio::test]
pub async fn can_validate_tokens() {
let test_provider = Arc::new(RwLock::new(MemoryProvider::default()));
let test_provider = Arc::new(MemoryProvider::default());
let facade = Arc::new(DataSourceFacade {
features_refresh_interval: Some(Duration::seconds(1)),
token_source: test_provider.clone(),
Expand All @@ -149,7 +148,7 @@ mod tests {
crate::http::unleash_client::UnleashClient::new(srv.url("/").as_str(), None)
.expect("Couldn't build client");

let mut validation_holder = super::TokenValidator {
let validation_holder = super::TokenValidator {
unleash_client: Arc::new(unleash_client),
edge_source: source,
edge_sink: sink,
Expand All @@ -163,8 +162,6 @@ mod tests {
.await
.expect("Couldn't register tokens");
let known_tokens = test_provider
.read()
.await
.get_tokens()
.await
.expect("Couldn't get tokens");
Expand All @@ -180,7 +177,7 @@ mod tests {

#[tokio::test]
pub async fn tokens_with_wrong_format_is_not_included() {
let test_provider = Arc::new(RwLock::new(MemoryProvider::default()));
let test_provider = Arc::new(MemoryProvider::default());
let facade = Arc::new(DataSourceFacade {
features_refresh_interval: Some(Duration::seconds(1)),
feature_source: test_provider.clone(),
Expand All @@ -195,7 +192,7 @@ mod tests {
let unleash_client =
crate::http::unleash_client::UnleashClient::new(srv.url("/").as_str(), None)
.expect("Couldn't build client");
let mut validation_holder = super::TokenValidator {
let validation_holder = super::TokenValidator {
unleash_client: Arc::new(unleash_client),
edge_source: source,
edge_sink: sink,
Expand Down
6 changes: 3 additions & 3 deletions server/src/data_sources/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn build_offline(offline_args: OfflineArgs) -> EdgeResult<Arc<dyn EdgeSource>> {
}

fn build_memory(features_refresh_interval_seconds: Duration) -> EdgeResult<DataProviderPair> {
let data_source = Arc::new(RwLock::new(MemoryProvider::new()));
let data_source = Arc::new(MemoryProvider::new());
let facade = Arc::new(DataSourceFacade {
features_refresh_interval: Some(features_refresh_interval_seconds),
token_source: data_source.clone(),
Expand All @@ -58,7 +58,7 @@ fn build_redis(
redis_url: String,
features_refresh_interval_seconds: Duration,
) -> EdgeResult<DataProviderPair> {
let data_source = Arc::new(RwLock::new(RedisProvider::new(&redis_url)?));
let data_source = Arc::new(RedisProvider::new(&redis_url)?);
let facade = Arc::new(DataSourceFacade {
token_source: data_source.clone(),
feature_source: data_source.clone(),
Expand Down Expand Up @@ -93,7 +93,7 @@ pub async fn build_source_and_sink(args: CliArgs) -> EdgeResult<RepositoryInfo>
EdgeArg::InMemory => build_memory(refresh_interval),
}?;

let mut token_validator = TokenValidator {
let token_validator = TokenValidator {
unleash_client: Arc::new(unleash_client.clone()),
edge_source: source.clone(),
edge_sink: sink.clone(),
Expand Down
52 changes: 25 additions & 27 deletions server/src/data_sources/memory_provider.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashMap;

use crate::types::TokenRefresh;
use crate::types::{EdgeResult, EdgeToken};
use actix_web::http::header::EntityTag;
Expand All @@ -13,8 +11,8 @@ use super::repository::{DataSink, DataSource};
#[derive(Debug, Clone)]
pub struct MemoryProvider {
data_store: DashMap<String, ClientFeatures>,
token_store: HashMap<String, EdgeToken>,
tokens_to_refresh: HashMap<String, TokenRefresh>,
token_store: DashMap<String, EdgeToken>,
tokens_to_refresh: DashMap<String, TokenRefresh>,
}

fn key(token: &EdgeToken) -> String {
Expand All @@ -31,24 +29,28 @@ impl MemoryProvider {
pub fn new() -> Self {
Self {
data_store: DashMap::new(),
token_store: HashMap::new(),
tokens_to_refresh: HashMap::new(),
token_store: DashMap::new(),
tokens_to_refresh: DashMap::new(),
}
}
}

#[async_trait]
impl DataSource for MemoryProvider {
async fn get_tokens(&self) -> EdgeResult<Vec<EdgeToken>> {
Ok(self.token_store.values().into_iter().cloned().collect())
Ok(self.token_store.iter().map(|x| x.value().clone()).collect())
}

async fn get_token(&self, secret: &str) -> EdgeResult<Option<EdgeToken>> {
Ok(self.token_store.get(secret).cloned())
Ok(self.token_store.get(secret).map(|x| x.clone()))
}

async fn get_refresh_tokens(&self) -> EdgeResult<Vec<TokenRefresh>> {
Ok(self.tokens_to_refresh.values().cloned().collect())
Ok(self
.tokens_to_refresh
.iter()
.map(|x| x.value().clone())
.collect())
}

async fn get_client_features(&self, token: &EdgeToken) -> EdgeResult<Option<ClientFeatures>> {
Expand All @@ -58,27 +60,23 @@ impl DataSource for MemoryProvider {

#[async_trait]
impl DataSink for MemoryProvider {
async fn sink_tokens(&mut self, tokens: Vec<EdgeToken>) -> EdgeResult<()> {
async fn sink_tokens(&self, tokens: Vec<EdgeToken>) -> EdgeResult<()> {
for token in tokens {
self.token_store.insert(token.token.clone(), token.clone());
}
Ok(())
}

async fn set_refresh_tokens(&mut self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> {
let new_tokens = tokens
.into_iter()
.map(|token| (token.token.token.clone(), token.clone()))
.collect();
self.tokens_to_refresh = new_tokens;
async fn set_refresh_tokens(&self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> {
self.tokens_to_refresh.clear();
tokens.into_iter().for_each(|refresh| {
self.tokens_to_refresh
.insert(refresh.token.token.clone(), refresh.clone());
});
Ok(())
}

async fn sink_features(
&mut self,
token: &EdgeToken,
features: ClientFeatures,
) -> EdgeResult<()> {
async fn sink_features(&self, token: &EdgeToken, features: ClientFeatures) -> EdgeResult<()> {
self.data_store
.entry(key(token))
.and_modify(|data| {
Expand All @@ -88,19 +86,19 @@ impl DataSink for MemoryProvider {
Ok(())
}

async fn update_last_check(&mut self, token: &EdgeToken) -> EdgeResult<()> {
if let Some(token) = self.tokens_to_refresh.get_mut(&token.token) {
async fn update_last_check(&self, token: &EdgeToken) -> EdgeResult<()> {
if let Some(mut token) = self.tokens_to_refresh.get_mut(&token.token) {
token.last_check = Some(chrono::Utc::now());
}
Ok(())
}

async fn update_last_refresh(
&mut self,
&self,
token: &EdgeToken,
etag: Option<EntityTag>,
) -> EdgeResult<()> {
if let Some(token) = self.tokens_to_refresh.get_mut(&token.token) {
if let Some(mut token) = self.tokens_to_refresh.get_mut(&token.token) {
token.last_check = Some(chrono::Utc::now());
token.last_refreshed = Some(chrono::Utc::now());
token.etag = etag;
Expand All @@ -117,7 +115,7 @@ mod tests {

#[tokio::test]
async fn memory_provider_correctly_deduplicates_tokens() {
let mut provider = MemoryProvider::default();
let provider = MemoryProvider::default();
provider
.sink_tokens(vec![EdgeToken {
token: "*:development.1d38eefdd7bf72676122b008dcf330f2f2aa2f3031438e1b7e8f0d1f"
Expand All @@ -141,7 +139,7 @@ mod tests {

#[tokio::test]
async fn memory_provider_correctly_determines_token_to_be_valid() {
let mut provider = MemoryProvider::default();
let provider = MemoryProvider::default();
provider
.sink_tokens(vec![EdgeToken {
token: "*:development.1d38eefdd7bf72676122b008dcf330f2f2aa2f3031438e1b7e8f0d1f"
Expand Down
14 changes: 5 additions & 9 deletions server/src/data_sources/redis_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl DataSource for RedisProvider {

#[async_trait]
impl DataSink for RedisProvider {
async fn sink_tokens(&mut self, tokens: Vec<EdgeToken>) -> EdgeResult<()> {
async fn sink_tokens(&self, tokens: Vec<EdgeToken>) -> EdgeResult<()> {
let mut client = self.redis_client.write().await;
let raw_stored_tokens: Option<String> = client.get(TOKENS_KEY)?;

Expand All @@ -99,7 +99,7 @@ impl DataSink for RedisProvider {
Ok(())
}

async fn set_refresh_tokens(&mut self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> {
async fn set_refresh_tokens(&self, tokens: Vec<&TokenRefresh>) -> EdgeResult<()> {
let mut client = self.redis_client.write().await;

let serialized_refresh_tokens = serde_json::to_string(&tokens)?;
Expand All @@ -108,11 +108,7 @@ impl DataSink for RedisProvider {
Ok(())
}

async fn sink_features(
&mut self,
token: &EdgeToken,
features: ClientFeatures,
) -> EdgeResult<()> {
async fn sink_features(&self, token: &EdgeToken, features: ClientFeatures) -> EdgeResult<()> {
let mut client = self.redis_client.write().await;
let raw_stored_features: Option<String> = client.get(key(token))?;

Expand All @@ -130,7 +126,7 @@ impl DataSink for RedisProvider {
Ok(())
}

async fn update_last_check(&mut self, token: &EdgeToken) -> EdgeResult<()> {
async fn update_last_check(&self, token: &EdgeToken) -> EdgeResult<()> {
let mut client = self.redis_client.write().await;
let raw_refresh_tokens: Option<String> = client.get(REFRESH_TOKENS_KEY)?;

Expand All @@ -155,7 +151,7 @@ impl DataSink for RedisProvider {
}

async fn update_last_refresh(
&mut self,
&self,
token: &EdgeToken,
etag: Option<EntityTag>,
) -> EdgeResult<()> {
Expand Down
Loading