Skip to content

Commit

Permalink
feat: switch to backing with HashMap<TokenString, EdgeToken> (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
Christopher Kolstad authored Feb 6, 2023
1 parent 286dfd5 commit 3a8cd76
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 63 deletions.
67 changes: 35 additions & 32 deletions server/src/data_sources/memory_provider.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
use std::{collections::HashMap, sync::Arc};

use crate::types::TokenValidationStatus;
use crate::{
error::EdgeError,
types::{
EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken, FeatureSink, FeaturesSource,
TokenSink, TokenSource, ValidateTokensRequest,
},
};
use crate::{
http::unleash_client::UnleashClient,
types::{ClientFeaturesRequest, ClientFeaturesResponse},
};
use async_trait::async_trait;
use dashmap::DashMap;
use std::str::FromStr;
use tokio::sync::mpsc::Sender;
use unleash_types::client_features::ClientFeatures;

use crate::{
error::EdgeError,
types::{
EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken, FeatureSink, FeaturesSource,
TokenSink, TokenSource, ValidateTokensRequest,
},
};

#[derive(Debug, Clone, Default)]
pub struct MemoryProvider {
data_store: DashMap<String, ClientFeatures>,
Expand Down Expand Up @@ -69,35 +68,29 @@ impl TokenSource for MemoryProvider {
Ok(self.token_store.values().into_iter().cloned().collect())
}

async fn secret_is_valid(
async fn get_token_validation_status(
&self,
secret: &str,
sender: Arc<Sender<EdgeToken>>,
) -> EdgeResult<bool> {
if self
.get_known_tokens()
.await?
.iter()
.any(|t| t.token == secret)
{
Ok(true)
) -> EdgeResult<TokenValidationStatus> {
if let Some(token) = self.token_store.get(secret) {
Ok(token.clone().status)
} else {
let _ = sender.send(EdgeToken::try_from(secret.to_string())?).await;
Ok(false)
Ok(TokenValidationStatus::Unknown)
}
}

async fn token_details(&self, secret: String) -> EdgeResult<Option<EdgeToken>> {
let tokens = self.get_known_tokens().await?;
Ok(tokens.into_iter().find(|t| t.token == secret))
Ok(self.token_store.get(&secret).cloned())
}

async fn get_valid_tokens(&self, secrets: Vec<String>) -> EdgeResult<Vec<EdgeToken>> {
let tokens = self.get_known_tokens().await?;
Ok(secrets
.into_iter()
.map(|s| EdgeToken::from_str(s.as_str()).unwrap())
.filter(|t| tokens.iter().any(|valid| valid.token == t.token))
.iter()
.filter_map(|s| self.token_store.get(s))
.filter(|s| s.status == TokenValidationStatus::Validated)
.cloned()
.collect())
}
}
Expand Down Expand Up @@ -125,8 +118,8 @@ impl FeatureSink for MemoryProvider {

#[cfg(test)]
mod test {
use std::str::FromStr;
use std::sync::Arc;

use tokio::sync::mpsc;
use unleash_types::client_features::ClientFeature;

Expand Down Expand Up @@ -158,16 +151,20 @@ mod test {
let _ = provider
.sink_tokens(vec![EdgeToken {
token: "some_secret".into(),
status: TokenValidationStatus::Validated,
..EdgeToken::default()
}])
.await;

let (send, _) = mpsc::channel::<EdgeToken>(32);

assert!(provider
.secret_is_valid("some_secret", Arc::new(send))
.await
.unwrap())
assert_eq!(
provider
.get_token_validation_status("some_secret", Arc::new(send))
.await
.unwrap(),
TokenValidationStatus::Validated
)
}

#[tokio::test]
Expand Down Expand Up @@ -196,8 +193,14 @@ mod test {

#[tokio::test]
async fn memory_provider_can_yield_list_of_validated_tokens() {
let james_bond = EdgeToken::from_str("jamesbond").unwrap();
let frank_drebin = EdgeToken::from_str("frankdrebin").unwrap();
let james_bond = EdgeToken {
status: TokenValidationStatus::Validated,
..EdgeToken::from_str("jamesbond").unwrap()
};
let frank_drebin = EdgeToken {
status: TokenValidationStatus::Validated,
..EdgeToken::from_str("frankdrebin").unwrap()
};

let mut provider = MemoryProvider::default();
let _ = provider
Expand Down
35 changes: 25 additions & 10 deletions server/src/data_sources/offline_provider.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::error::EdgeError;
use crate::types::{
ClientFeaturesResponse, EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken, FeatureSink,
FeaturesSource, TokenSink, TokenSource,
FeaturesSource, TokenSink, TokenSource, TokenValidationStatus,
};
use actix_web::http::header::EntityTag;
use async_trait::async_trait;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
Expand All @@ -14,7 +16,7 @@ use unleash_types::client_features::ClientFeatures;
#[derive(Debug, Clone)]
pub struct OfflineProvider {
pub features: ClientFeatures,
pub valid_tokens: Vec<EdgeToken>,
pub valid_tokens: HashMap<String, EdgeToken>,
}

#[async_trait]
Expand All @@ -27,22 +29,32 @@ impl FeaturesSource for OfflineProvider {
#[async_trait]
impl TokenSource for OfflineProvider {
async fn get_known_tokens(&self) -> EdgeResult<Vec<EdgeToken>> {
Ok(self.valid_tokens.clone())
Ok(self.valid_tokens.values().cloned().collect())
}

async fn secret_is_valid(&self, secret: &str, _: Arc<Sender<EdgeToken>>) -> EdgeResult<bool> {
Ok(self.valid_tokens.iter().any(|t| t.token == secret))
async fn get_token_validation_status(
&self,
secret: &str,
_: Arc<Sender<EdgeToken>>,
) -> EdgeResult<TokenValidationStatus> {
Ok(if self.valid_tokens.contains_key(secret) {
TokenValidationStatus::Validated
} else {
TokenValidationStatus::Invalid
})
}

async fn token_details(&self, secret: String) -> EdgeResult<Option<EdgeToken>> {
Ok(self.valid_tokens.get(&secret).cloned())
}
async fn get_valid_tokens(&self, secrets: Vec<String>) -> EdgeResult<Vec<EdgeToken>> {
Ok(self
.valid_tokens
.clone()
.into_iter()
.find(|t| t.token == secret))
}
async fn get_valid_tokens(&self, _secrets: Vec<String>) -> EdgeResult<Vec<EdgeToken>> {
todo!()
.filter(|(k, t)| t.status == TokenValidationStatus::Validated && secrets.contains(k))
.map(|(_k, t)| t)
.collect())
}
}

Expand All @@ -60,7 +72,9 @@ impl FeatureSink for OfflineProvider {
todo!()
}
async fn fetch_features(&mut self, _token: &EdgeToken) -> EdgeResult<ClientFeaturesResponse> {
todo!()
Ok(ClientFeaturesResponse::NoUpdate(EntityTag::new_weak(
"this_provider_does_not_support_refreshing_features".into(),
)))
}
}
#[async_trait]
Expand Down Expand Up @@ -98,6 +112,7 @@ impl OfflineProvider {
.into_iter()
.map(EdgeToken::try_from)
.filter_map(|t| t.ok())
.map(|t| (t.token.clone(), t))
.collect(),
}
}
Expand Down
13 changes: 7 additions & 6 deletions server/src/data_sources/redis_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use unleash_types::client_features::ClientFeatures;
pub const FEATURE_KEY: &str = "features";
pub const TOKENS_KEY: &str = "tokens";

use crate::types::TokenValidationStatus;
use crate::{
error::EdgeError,
types::{
Expand Down Expand Up @@ -89,21 +90,21 @@ impl TokenSource for RedisProvider {
.collect())
}

async fn secret_is_valid(
async fn get_token_validation_status(
&self,
secret: &str,
sender: Arc<Sender<EdgeToken>>,
) -> EdgeResult<bool> {
if self
) -> EdgeResult<TokenValidationStatus> {
if let Some(t) = self
.get_known_tokens()
.await?
.iter()
.any(|t| t.token == secret)
.find(|t| t.token == secret)
{
Ok(true)
Ok(t.clone().status)
} else {
let _ = sender.send(EdgeToken::try_from(secret.to_string())?).await;
Ok(false)
Ok(TokenValidationStatus::Unknown)
}
}

Expand Down
8 changes: 4 additions & 4 deletions server/src/frontend_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ mod tests {
use crate::data_sources::builder::DataProviderPair;
use crate::types::{
ClientFeaturesResponse, EdgeProvider, EdgeResult, EdgeSink, EdgeSource, EdgeToken,
FeatureSink, FeaturesSource, TokenSink, TokenSource,
FeatureSink, FeaturesSource, TokenSink, TokenSource, TokenValidationStatus,
};
use actix_web::{
http::header::ContentType,
Expand Down Expand Up @@ -177,12 +177,12 @@ mod tests {
todo!()
}

async fn secret_is_valid(
async fn get_token_validation_status(
&self,
_secret: &str,
_: Arc<Sender<EdgeToken>>,
) -> EdgeResult<bool> {
Ok(true)
) -> EdgeResult<TokenValidationStatus> {
Ok(TokenValidationStatus::Validated)
}

async fn token_details(&self, _secret: String) -> EdgeResult<Option<EdgeToken>> {
Expand Down
41 changes: 33 additions & 8 deletions server/src/middleware/validate_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use actix_web::{
HttpResponse,
};

use crate::types::{EdgeSource, EdgeToken};
use crate::types::{EdgeSource, EdgeToken, TokenType, TokenValidationStatus};
use tokio::sync::{mpsc::Sender, RwLock};

pub async fn validate_token(
Expand All @@ -15,16 +15,41 @@ pub async fn validate_token(
req: ServiceRequest,
srv: crate::middleware::as_async_middleware::Next<impl MessageBody + 'static>,
) -> Result<ServiceResponse<impl MessageBody>, actix_web::Error> {
let res = if provider
let res = match provider
.read()
.await
.secret_is_valid(token.token.as_str(), sender.into_inner())
.await?
.get_token_validation_status(token.token.as_str(), sender.into_inner())
.await
{
srv.call(req).await?.map_into_left_body()
} else {
req.into_response(HttpResponse::Forbidden().finish())
.map_into_right_body()
Ok(TokenValidationStatus::Validated) => {
if req.path().contains("/api/frontend") || req.path().contains("/api/proxy") {
if token.token_type == Some(TokenType::Frontend) {
srv.call(req).await?.map_into_left_body()
} else {
req.into_response(HttpResponse::Forbidden().finish())
.map_into_right_body()
}
} else if req.path().contains("/api/client") {
if token.token_type == Some(TokenType::Client) {
srv.call(req).await?.map_into_left_body()
} else {
req.into_response(HttpResponse::Forbidden().finish())
.map_into_right_body()
}
} else {
req.into_response(HttpResponse::NotFound().finish())
.map_into_right_body()
}
}
Ok(TokenValidationStatus::Unknown) => req
.into_response(HttpResponse::Unauthorized().finish())
.map_into_right_body(),
Ok(TokenValidationStatus::Invalid) => req
.into_response(HttpResponse::Forbidden().finish())
.map_into_right_body(),
Err(_e) => req
.into_response(HttpResponse::Unauthorized().finish())
.map_into_right_body(),
};
Ok(res)
}
6 changes: 5 additions & 1 deletion server/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,11 @@ pub trait FeaturesSource {
#[async_trait]
pub trait TokenSource {
async fn get_known_tokens(&self) -> EdgeResult<Vec<EdgeToken>>;
async fn secret_is_valid(&self, secret: &str, job: Arc<Sender<EdgeToken>>) -> EdgeResult<bool>;
async fn get_token_validation_status(
&self,
secret: &str,
job: Arc<Sender<EdgeToken>>,
) -> EdgeResult<TokenValidationStatus>;
async fn token_details(&self, secret: String) -> EdgeResult<Option<EdgeToken>>;
async fn get_valid_tokens(&self, tokens: Vec<String>) -> EdgeResult<Vec<EdgeToken>>;
}
Expand Down
5 changes: 3 additions & 2 deletions server/tests/redis_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{fs, sync::Arc};
use redis::{Client, Commands};
use testcontainers::{clients::Cli, images::redis::Redis, Container};
use tokio::sync::mpsc;
use unleash_edge::types::TokenValidationStatus;
use unleash_edge::{
data_sources::redis_provider::{RedisProvider, FEATURE_KEY, TOKENS_KEY},
types::{EdgeProvider, EdgeToken},
Expand Down Expand Up @@ -66,8 +67,8 @@ async fn redis_provider_correctly_determines_secret_to_be_valid() {
let provider: Box<dyn EdgeProvider> = Box::new(RedisProvider::new(&url).unwrap());

let is_valid_token = provider
.secret_is_valid(TOKEN, Arc::new(send))
.get_token_validation_status(TOKEN, Arc::new(send))
.await
.unwrap();
assert!(is_valid_token)
assert_eq!(is_valid_token, TokenValidationStatus::Validated)
}

0 comments on commit 3a8cd76

Please sign in to comment.