From 3a8cd761a8cd92696c9229df1a6c3614aae261fa Mon Sep 17 00:00:00 2001 From: Christopher Kolstad Date: Mon, 6 Feb 2023 11:36:43 +0100 Subject: [PATCH] feat: switch to backing with HashMap (#40) --- server/src/data_sources/memory_provider.rs | 67 +++++++++++---------- server/src/data_sources/offline_provider.rs | 35 ++++++++--- server/src/data_sources/redis_provider.rs | 13 ++-- server/src/frontend_api.rs | 8 +-- server/src/middleware/validate_token.rs | 41 ++++++++++--- server/src/types.rs | 6 +- server/tests/redis_test.rs | 5 +- 7 files changed, 112 insertions(+), 63 deletions(-) diff --git a/server/src/data_sources/memory_provider.rs b/server/src/data_sources/memory_provider.rs index 461a3458..d5c096a4 100644 --- a/server/src/data_sources/memory_provider.rs +++ b/server/src/data_sources/memory_provider.rs @@ -1,23 +1,22 @@ use std::{collections::HashMap, sync::Arc}; +use crate::types::TokenValidationStatus; +use crate::{ + error::EdgeError, + types::{ + EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken, FeatureSink, FeaturesSource, + TokenSink, TokenSource, ValidateTokensRequest, + }, +}; use crate::{ http::unleash_client::UnleashClient, types::{ClientFeaturesRequest, ClientFeaturesResponse}, }; use async_trait::async_trait; use dashmap::DashMap; -use std::str::FromStr; use tokio::sync::mpsc::Sender; use unleash_types::client_features::ClientFeatures; -use crate::{ - error::EdgeError, - types::{ - EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken, FeatureSink, FeaturesSource, - TokenSink, TokenSource, ValidateTokensRequest, - }, -}; - #[derive(Debug, Clone, Default)] pub struct MemoryProvider { data_store: DashMap, @@ -69,35 +68,29 @@ impl TokenSource for MemoryProvider { Ok(self.token_store.values().into_iter().cloned().collect()) } - async fn secret_is_valid( + async fn get_token_validation_status( &self, secret: &str, sender: Arc>, - ) -> EdgeResult { - if self - .get_known_tokens() - .await? - .iter() - .any(|t| t.token == secret) - { - Ok(true) + ) -> EdgeResult { + if let Some(token) = self.token_store.get(secret) { + Ok(token.clone().status) } else { let _ = sender.send(EdgeToken::try_from(secret.to_string())?).await; - Ok(false) + Ok(TokenValidationStatus::Unknown) } } async fn token_details(&self, secret: String) -> EdgeResult> { - let tokens = self.get_known_tokens().await?; - Ok(tokens.into_iter().find(|t| t.token == secret)) + Ok(self.token_store.get(&secret).cloned()) } async fn get_valid_tokens(&self, secrets: Vec) -> EdgeResult> { - let tokens = self.get_known_tokens().await?; Ok(secrets - .into_iter() - .map(|s| EdgeToken::from_str(s.as_str()).unwrap()) - .filter(|t| tokens.iter().any(|valid| valid.token == t.token)) + .iter() + .filter_map(|s| self.token_store.get(s)) + .filter(|s| s.status == TokenValidationStatus::Validated) + .cloned() .collect()) } } @@ -125,8 +118,8 @@ impl FeatureSink for MemoryProvider { #[cfg(test)] mod test { + use std::str::FromStr; use std::sync::Arc; - use tokio::sync::mpsc; use unleash_types::client_features::ClientFeature; @@ -158,16 +151,20 @@ mod test { let _ = provider .sink_tokens(vec![EdgeToken { token: "some_secret".into(), + status: TokenValidationStatus::Validated, ..EdgeToken::default() }]) .await; let (send, _) = mpsc::channel::(32); - assert!(provider - .secret_is_valid("some_secret", Arc::new(send)) - .await - .unwrap()) + assert_eq!( + provider + .get_token_validation_status("some_secret", Arc::new(send)) + .await + .unwrap(), + TokenValidationStatus::Validated + ) } #[tokio::test] @@ -196,8 +193,14 @@ mod test { #[tokio::test] async fn memory_provider_can_yield_list_of_validated_tokens() { - let james_bond = EdgeToken::from_str("jamesbond").unwrap(); - let frank_drebin = EdgeToken::from_str("frankdrebin").unwrap(); + let james_bond = EdgeToken { + status: TokenValidationStatus::Validated, + ..EdgeToken::from_str("jamesbond").unwrap() + }; + let frank_drebin = EdgeToken { + status: TokenValidationStatus::Validated, + ..EdgeToken::from_str("frankdrebin").unwrap() + }; let mut provider = MemoryProvider::default(); let _ = provider diff --git a/server/src/data_sources/offline_provider.rs b/server/src/data_sources/offline_provider.rs index d35016d1..3948f4b8 100644 --- a/server/src/data_sources/offline_provider.rs +++ b/server/src/data_sources/offline_provider.rs @@ -1,9 +1,11 @@ use crate::error::EdgeError; use crate::types::{ ClientFeaturesResponse, EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken, FeatureSink, - FeaturesSource, TokenSink, TokenSource, + FeaturesSource, TokenSink, TokenSource, TokenValidationStatus, }; +use actix_web::http::header::EntityTag; use async_trait::async_trait; +use std::collections::HashMap; use std::fs::File; use std::io::BufReader; use std::path::PathBuf; @@ -14,7 +16,7 @@ use unleash_types::client_features::ClientFeatures; #[derive(Debug, Clone)] pub struct OfflineProvider { pub features: ClientFeatures, - pub valid_tokens: Vec, + pub valid_tokens: HashMap, } #[async_trait] @@ -27,22 +29,32 @@ impl FeaturesSource for OfflineProvider { #[async_trait] impl TokenSource for OfflineProvider { async fn get_known_tokens(&self) -> EdgeResult> { - Ok(self.valid_tokens.clone()) + Ok(self.valid_tokens.values().cloned().collect()) } - async fn secret_is_valid(&self, secret: &str, _: Arc>) -> EdgeResult { - Ok(self.valid_tokens.iter().any(|t| t.token == secret)) + async fn get_token_validation_status( + &self, + secret: &str, + _: Arc>, + ) -> EdgeResult { + Ok(if self.valid_tokens.contains_key(secret) { + TokenValidationStatus::Validated + } else { + TokenValidationStatus::Invalid + }) } async fn token_details(&self, secret: String) -> EdgeResult> { + Ok(self.valid_tokens.get(&secret).cloned()) + } + async fn get_valid_tokens(&self, secrets: Vec) -> EdgeResult> { Ok(self .valid_tokens .clone() .into_iter() - .find(|t| t.token == secret)) - } - async fn get_valid_tokens(&self, _secrets: Vec) -> EdgeResult> { - todo!() + .filter(|(k, t)| t.status == TokenValidationStatus::Validated && secrets.contains(k)) + .map(|(_k, t)| t) + .collect()) } } @@ -60,7 +72,9 @@ impl FeatureSink for OfflineProvider { todo!() } async fn fetch_features(&mut self, _token: &EdgeToken) -> EdgeResult { - todo!() + Ok(ClientFeaturesResponse::NoUpdate(EntityTag::new_weak( + "this_provider_does_not_support_refreshing_features".into(), + ))) } } #[async_trait] @@ -98,6 +112,7 @@ impl OfflineProvider { .into_iter() .map(EdgeToken::try_from) .filter_map(|t| t.ok()) + .map(|t| (t.token.clone(), t)) .collect(), } } diff --git a/server/src/data_sources/redis_provider.rs b/server/src/data_sources/redis_provider.rs index 95a2f715..676027fc 100644 --- a/server/src/data_sources/redis_provider.rs +++ b/server/src/data_sources/redis_provider.rs @@ -7,6 +7,7 @@ use unleash_types::client_features::ClientFeatures; pub const FEATURE_KEY: &str = "features"; pub const TOKENS_KEY: &str = "tokens"; +use crate::types::TokenValidationStatus; use crate::{ error::EdgeError, types::{ @@ -89,21 +90,21 @@ impl TokenSource for RedisProvider { .collect()) } - async fn secret_is_valid( + async fn get_token_validation_status( &self, secret: &str, sender: Arc>, - ) -> EdgeResult { - if self + ) -> EdgeResult { + if let Some(t) = self .get_known_tokens() .await? .iter() - .any(|t| t.token == secret) + .find(|t| t.token == secret) { - Ok(true) + Ok(t.clone().status) } else { let _ = sender.send(EdgeToken::try_from(secret.to_string())?).await; - Ok(false) + Ok(TokenValidationStatus::Unknown) } } diff --git a/server/src/frontend_api.rs b/server/src/frontend_api.rs index 46824d35..bf43b501 100644 --- a/server/src/frontend_api.rs +++ b/server/src/frontend_api.rs @@ -126,7 +126,7 @@ mod tests { use crate::data_sources::builder::DataProviderPair; use crate::types::{ ClientFeaturesResponse, EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken, - FeatureSink, FeaturesSource, TokenSink, TokenSource, + FeatureSink, FeaturesSource, TokenSink, TokenSource, TokenValidationStatus, }; use actix_web::{ http::header::ContentType, @@ -177,12 +177,12 @@ mod tests { todo!() } - async fn secret_is_valid( + async fn get_token_validation_status( &self, _secret: &str, _: Arc>, - ) -> EdgeResult { - Ok(true) + ) -> EdgeResult { + Ok(TokenValidationStatus::Validated) } async fn token_details(&self, _secret: String) -> EdgeResult> { diff --git a/server/src/middleware/validate_token.rs b/server/src/middleware/validate_token.rs index 3d0a36e8..39390c9c 100644 --- a/server/src/middleware/validate_token.rs +++ b/server/src/middleware/validate_token.rs @@ -5,7 +5,7 @@ use actix_web::{ HttpResponse, }; -use crate::types::{EdgeSource, EdgeToken}; +use crate::types::{EdgeSource, EdgeToken, TokenType, TokenValidationStatus}; use tokio::sync::{mpsc::Sender, RwLock}; pub async fn validate_token( @@ -15,16 +15,41 @@ pub async fn validate_token( req: ServiceRequest, srv: crate::middleware::as_async_middleware::Next, ) -> Result, actix_web::Error> { - let res = if provider + let res = match provider .read() .await - .secret_is_valid(token.token.as_str(), sender.into_inner()) - .await? + .get_token_validation_status(token.token.as_str(), sender.into_inner()) + .await { - srv.call(req).await?.map_into_left_body() - } else { - req.into_response(HttpResponse::Forbidden().finish()) - .map_into_right_body() + Ok(TokenValidationStatus::Validated) => { + if req.path().contains("/api/frontend") || req.path().contains("/api/proxy") { + if token.token_type == Some(TokenType::Frontend) { + srv.call(req).await?.map_into_left_body() + } else { + req.into_response(HttpResponse::Forbidden().finish()) + .map_into_right_body() + } + } else if req.path().contains("/api/client") { + if token.token_type == Some(TokenType::Client) { + srv.call(req).await?.map_into_left_body() + } else { + req.into_response(HttpResponse::Forbidden().finish()) + .map_into_right_body() + } + } else { + req.into_response(HttpResponse::NotFound().finish()) + .map_into_right_body() + } + } + Ok(TokenValidationStatus::Unknown) => req + .into_response(HttpResponse::Unauthorized().finish()) + .map_into_right_body(), + Ok(TokenValidationStatus::Invalid) => req + .into_response(HttpResponse::Forbidden().finish()) + .map_into_right_body(), + Err(_e) => req + .into_response(HttpResponse::Unauthorized().finish()) + .map_into_right_body(), }; Ok(res) } diff --git a/server/src/types.rs b/server/src/types.rs index 27da456b..26b1ef17 100644 --- a/server/src/types.rs +++ b/server/src/types.rs @@ -208,7 +208,11 @@ pub trait FeaturesSource { #[async_trait] pub trait TokenSource { async fn get_known_tokens(&self) -> EdgeResult>; - async fn secret_is_valid(&self, secret: &str, job: Arc>) -> EdgeResult; + async fn get_token_validation_status( + &self, + secret: &str, + job: Arc>, + ) -> EdgeResult; async fn token_details(&self, secret: String) -> EdgeResult>; async fn get_valid_tokens(&self, tokens: Vec) -> EdgeResult>; } diff --git a/server/tests/redis_test.rs b/server/tests/redis_test.rs index 630dd87f..54df389b 100644 --- a/server/tests/redis_test.rs +++ b/server/tests/redis_test.rs @@ -3,6 +3,7 @@ use std::{fs, sync::Arc}; use redis::{Client, Commands}; use testcontainers::{clients::Cli, images::redis::Redis, Container}; use tokio::sync::mpsc; +use unleash_edge::types::TokenValidationStatus; use unleash_edge::{ data_sources::redis_provider::{RedisProvider, FEATURE_KEY, TOKENS_KEY}, types::{EdgeProvider, EdgeToken}, @@ -66,8 +67,8 @@ async fn redis_provider_correctly_determines_secret_to_be_valid() { let provider: Box = Box::new(RedisProvider::new(&url).unwrap()); let is_valid_token = provider - .secret_is_valid(TOKEN, Arc::new(send)) + .get_token_validation_status(TOKEN, Arc::new(send)) .await .unwrap(); - assert!(is_valid_token) + assert_eq!(is_valid_token, TokenValidationStatus::Validated) }