diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index c7a652b4..72507ea9 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -192,14 +192,13 @@ // TODO this needs review to reduce the bloat pulled in by dependencies #![allow(clippy::multiple_crate_versions)] -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - use crate::counter::Counter; use crate::errors::LimitadorError; use crate::limit::{Limit, Namespace}; use crate::storage::in_memory::InMemoryStorage; use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; #[macro_use] extern crate core; @@ -480,10 +479,10 @@ impl RateLimiter { values: &HashMap, ) -> LimitadorResult> { let limits = self.storage.get_limits(namespace); - + let ctx = values.into(); limits .iter() - .filter(|lim| lim.applies(values)) + .filter(|lim| lim.applies(&ctx)) .map(|lim| Counter::new(Arc::clone(lim), values.clone())) .collect() } @@ -657,10 +656,10 @@ impl AsyncRateLimiter { values: &HashMap, ) -> LimitadorResult> { let limits = self.storage.get_limits(namespace); - + let ctx = values.into(); limits .iter() - .filter(|lim| lim.applies(values)) + .filter(|lim| lim.applies(&ctx)) .map(|lim| Counter::new(Arc::clone(lim), values.clone())) .collect() } diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index 5ad0c56e..3b0bc25c 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -135,7 +135,7 @@ impl Limit { &self, vars: HashMap, ) -> Result, EvaluationError> { - let ctx = Context::new(self, String::default(), &vars); + let ctx = Context::new(String::default(), &vars); let mut map = BTreeMap::new(); for variable in &self.variables { let name = variable.source().into(); @@ -162,17 +162,17 @@ impl Limit { .any(|v| v.as_str() == var) } - pub fn applies(&self, values: &HashMap) -> bool { - let ctx = Context::new(self, String::default(), values); + pub fn applies(&self, ctx: &Context) -> bool { + let ctx = ctx.for_limit(self); let all_conditions_apply = self .conditions .iter() - .all(|predicate| predicate.test(&ctx).unwrap()); + .all(|predicate| predicate.test(&ctx.for_limit(self)).unwrap()); let all_vars_are_set = self .variables .iter() - .all(|var| values.contains_key(var.source())); + .all(|var| ctx.has_variable(var.source())); all_conditions_apply && all_vars_are_set } @@ -252,7 +252,7 @@ mod tests { values.insert("x".into(), "5".into()); values.insert("y".into(), "1".into()); - assert!(limit.applies(&values)) + assert!(limit.applies(&(&values).into())) } #[test] @@ -269,7 +269,7 @@ mod tests { values.insert("x".into(), "1".into()); values.insert("y".into(), "1".into()); - assert!(!limit.applies(&values)) + assert!(!limit.applies(&(&values).into())) } #[test] @@ -287,7 +287,7 @@ mod tests { values.insert("a".into(), "1".into()); values.insert("y".into(), "1".into()); - assert!(!limit.applies(&values)) + assert!(!limit.applies(&(&values).into())) } #[test] @@ -304,7 +304,7 @@ mod tests { let mut values: HashMap = HashMap::new(); values.insert("x".into(), "5".into()); - assert!(!limit.applies(&values)) + assert!(!limit.applies(&(&values).into())) } #[test] @@ -325,7 +325,7 @@ mod tests { values.insert("y".into(), "2".into()); values.insert("z".into(), "1".into()); - assert!(limit.applies(&values)) + assert!(limit.applies(&(&values).into())) } #[test] @@ -346,7 +346,7 @@ mod tests { values.insert("y".into(), "2".into()); values.insert("z".into(), "1".into()); - assert!(!limit.applies(&values)) + assert!(!limit.applies(&(&values).into())) } #[test] @@ -410,10 +410,10 @@ mod tests { .expect("failed parsing!")], Vec::default(), ); - assert!(!limit.applies(&HashMap::default())); + assert!(!limit.applies(&Context::default())); limit.set_name("named_limit".to_string()); - assert!(limit.applies(&HashMap::default())); + assert!(limit.applies(&Context::default())); let limit = Limit::with_id( "my_id", @@ -426,7 +426,7 @@ mod tests { ], Vec::default(), ); - assert!(limit.applies(&HashMap::default())); + assert!(limit.applies(&(&HashMap::default()).into())); let limit = Limit::with_id( "my_id", @@ -438,6 +438,6 @@ mod tests { .expect("failed parsing!")], Vec::default(), ); - assert!(!limit.applies(&HashMap::default())); + assert!(!limit.applies(&Context::default())); } } diff --git a/limitador/src/limit/cel.rs b/limitador/src/limit/cel.rs index df0f77f2..fbc2ebb4 100644 --- a/limitador/src/limit/cel.rs +++ b/limitador/src/limit/cel.rs @@ -1,13 +1,12 @@ use crate::limit::Limit; use cel_interpreter::{ExecutionError, Value}; +pub use errors::{EvaluationError, ParseError}; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -pub use errors::{EvaluationError, ParseError}; - pub(super) mod errors { use cel_interpreter::ExecutionError; use std::error::Error; @@ -78,8 +77,8 @@ pub struct Context<'a> { ctx: cel_interpreter::Context<'a>, } -impl Context<'_> { - pub(crate) fn new(limit: &Limit, root: String, values: &HashMap) -> Self { +impl<'a> Context<'a> { + pub(crate) fn new(root: String, values: &HashMap) -> Self { let mut ctx = cel_interpreter::Context::default(); if root.is_empty() { @@ -91,6 +90,17 @@ impl Context<'_> { ctx.add_variable_from_value(root, Value::Map(map)); } + Self { + variables: values.keys().cloned().collect(), + ctx, + } + } + + pub(crate) fn for_limit<'b>(&'b self, limit: &Limit) -> Self + where + 'b: 'a, + { + let mut inner = self.ctx.new_inner_scope(); let limit_data = cel_interpreter::objects::Map::from(HashMap::from([ ( "name", @@ -109,13 +119,28 @@ impl Context<'_> { .unwrap_or(Value::Null), ), ])); - ctx.add_variable_from_value("limit", Value::Map(limit_data)); - + inner.add_variable_from_value("limit", Value::Map(limit_data)); Self { - variables: values.keys().cloned().collect(), - ctx, + variables: self.variables.clone(), + ctx: inner, } } + + pub(crate) fn has_variable(&self, name: &str) -> bool { + self.variables.contains(name) + } +} + +impl Default for Context<'_> { + fn default() -> Self { + Self::new(String::default(), &HashMap::default()) + } +} + +impl From<&HashMap> for Context<'_> { + fn from(value: &HashMap) -> Self { + Self::new(String::default(), value) + } } #[derive(Clone, Debug, Serialize, Deserialize)]