diff --git a/expr.go b/expr.go index 9ccf400..c8a76d5 100644 --- a/expr.go +++ b/expr.go @@ -35,10 +35,14 @@ type EvaluableLoader func(ctx context.Context, evaluableIDs ...uuid.UUID) ([]Eva // // An AggregateEvaluator instance exists for every event name being matched. type AggregateEvaluator interface { - // Add adds an expression to the tree evaluator. This returns true - // if the expression is aggregateable, or false if the expression will be - // evaluated each time an event is received. - Add(ctx context.Context, eval Evaluable) (bool, error) + // Add adds an expression to the tree evaluator. This returns the ratio + // of aggregate to slow parts in the expression, or an error if there was an + // issue. + // + // Purely aggregateable expressions have a ratio of 1. + // Mixed expressions return the ratio of fast:slow expressions, as a float. + // Slow, non-aggregateable expressions return 0. + Add(ctx context.Context, eval Evaluable) (float64, error) // Remove removes an expression from the aggregate evaluator Remove(ctx context.Context, eval Evaluable) error @@ -60,12 +64,15 @@ type AggregateEvaluator interface { // stored in the evaluator. Len() int - // AggregateableLen returns the number of expressions being matched by aggregated trees. - AggregateableLen() int + // FastLen returns the number of expressions being matched by aggregated trees. + FastLen() int + + // MixedLen returns the number of expressions being matched by aggregated trees. + MixedLen() int - // ConstantLen returns the total number of expressions that must constantly + // SlowLen returns the total number of expressions that must constantly // be matched due to non-aggregateable clauses in their expressions. - ConstantLen() int + SlowLen() int } func NewAggregateEvaluator( @@ -104,8 +111,15 @@ type aggregator struct { // lock prevents concurrent updates of data lock *sync.RWMutex - // len stores the current len of aggregable expressions. - len int32 + + // fastLen stores the current len of purely aggregable expressions. + fastLen int32 + + // mixedLen stores the current len of mixed aggregable expressions, + // eg "foo == '1' && bar != '1'". This is becasue != isn't aggregateable, + // but the first `==` is used as a prefilter. + mixedLen int32 + // constants tracks evaluable IDs that must always be evaluated, due to // the expression containing non-aggregateable clauses. constants []uuid.UUID @@ -114,17 +128,22 @@ type aggregator struct { // Len returns the total number of aggregateable and constantly matched expressions // stored in the evaluator. func (a aggregator) Len() int { - return int(a.len) + len(a.constants) + return int(a.fastLen) + int(a.mixedLen) + len(a.constants) +} + +// FastLen returns the number of expressions being matched by aggregated trees. +func (a aggregator) FastLen() int { + return int(a.fastLen) } -// AggregateableLen returns the number of expressions being matched by aggregated trees. -func (a aggregator) AggregateableLen() int { - return int(a.len) +// MixedLen returns the number of expressions being matched by aggregated trees. +func (a aggregator) MixedLen() int { + return int(a.mixedLen) } -// ConstantLen returns the total number of expressions that must constantly +// SlowLen returns the total number of expressions that must constantly // be matched due to non-aggregateable clauses in their expressions. -func (a aggregator) ConstantLen() int { +func (a aggregator) SlowLen() int { return len(a.constants) } @@ -318,14 +337,18 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([ return result, nil } -// Add adds an Evaluable to the aggregate tree engine for matching. It returns -// a boolean indicating whether the expression is suitable for aggregate tree -// matching, allowing rapid exclusion of non-matching expressions. -func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) { +// Add adds an expression to the tree evaluator. This returns the ratio +// of aggregate to slow parts in the expression, or an error if there was an +// issue. +// +// Purely aggregateable expressions have a ratio of 1. +// Mixed expressions return the ratio of fast:slow expressions, as a float. +// Slow, non-aggregateable expressions return 0. +func (a *aggregator) Add(ctx context.Context, eval Evaluable) (float64, error) { // parse the expression using our tree parser. parsed, err := a.parser.Parse(ctx, eval) if err != nil { - return false, err + return -1, err } if eval.GetExpression() == "" || parsed.HasMacros { @@ -333,25 +356,42 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) { a.lock.Lock() a.constants = append(a.constants, parsed.EvaluableID) a.lock.Unlock() - return false, nil + return -1, nil } + stats := &exprAggregateStats{} for _, g := range parsed.RootGroups() { - ok, err := a.iterGroup(ctx, g, parsed, a.addNode) + s, err := a.iterGroup(ctx, g, parsed, a.addNode) - if err != nil || !ok { + if err != nil { // This is the first time we're seeing a non-aggregateable // group, so add it to the constants list and don't do anything else. a.lock.Lock() a.constants = append(a.constants, parsed.EvaluableID) a.lock.Unlock() - return false, err + return -1, err } + + stats.Merge(s) } - // Track the number of added expressions correctly. - atomic.AddInt32(&a.len, 1) - return true, nil + if stats.Fast() == 0 { + // This is a non-aggregateable, slow expression. + // Add it to the constants list and don't do anything else. + a.lock.Lock() + a.constants = append(a.constants, parsed.EvaluableID) + a.lock.Unlock() + return stats.Ratio(), err + } + + if stats.Slow() == 0 { + // This is a purely aggregateable expression. + atomic.AddInt32(&a.fastLen, 1) + return stats.Ratio(), err + } + + atomic.AddInt32(&a.mixedLen, 1) + return stats.Ratio(), err } func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error { @@ -365,27 +405,36 @@ func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error { return err } - aggregateable := true + stats := &exprAggregateStats{} + for _, g := range parsed.RootGroups() { - ok, err := a.iterGroup(ctx, g, parsed, a.removeNode) + s, err := a.iterGroup(ctx, g, parsed, a.removeNode) if err == ErrExpressionPartNotFound { return ErrEvaluableNotFound } + if err != nil { + _ = a.removeConstantEvaluable(ctx, eval) return err } - if !ok && aggregateable { - if err := a.removeConstantEvaluable(ctx, eval); err != nil { - return err - } - aggregateable = false + stats.Merge(s) + } + + if stats.Fast() == 0 { + // This is a non-aggregateable, slow expression. + if err := a.removeConstantEvaluable(ctx, eval); err != nil { + return err } + return nil } - if aggregateable { - atomic.AddInt32(&a.len, -1) + if stats.Slow() == 0 { + // This is a purely aggregateable expression. + atomic.AddInt32(&a.fastLen, -1) + return nil } + atomic.AddInt32(&a.mixedLen, -1) return nil } @@ -408,17 +457,70 @@ func (a *aggregator) removeConstantEvaluable(ctx context.Context, eval Evaluable return nil } -func (a *aggregator) iterGroup(ctx context.Context, node *Node, parsed *ParsedExpression, op nodeOp) (bool, error) { +type exprAggregateStats [2]int + +// Fast returns the number of aggregateable predicates in the iterated expr +func (e exprAggregateStats) Fast() int { + return e[0] +} + +// Slow returns the number of non-aggregateable predicates in the iterated expr +func (e exprAggregateStats) Slow() int { + return e[1] +} + +func (e *exprAggregateStats) AddFast() { + e[0] += 1 +} + +func (e *exprAggregateStats) AddSlow() { + e[1] += 1 +} + +func (e *exprAggregateStats) Merge(other exprAggregateStats) { + e[0] += other[0] + e[1] += other[1] +} + +// Ratio returns the ratio of fast to slow expressions as a float, eg. 9 fast +// aggregateable parts and 1 slow part returns a ratio of 0.9. +func (e *exprAggregateStats) Ratio() float64 { + if e[0] == 0 && e[1] == 0 { + // Failure. + return -1 + } + + if e[1] == 0 { + // Always fast, return 1 + return 1 + } + + if e[0] == 0 { + // Always slow, return 0 + return 0 + } + + // return ratio of fast:slow + return float64(e[0]) / (float64(e[0]) + float64(e[1])) +} + +// iterGroup iterates the entire expression, returning statistics on how "aggregateable" the expression is +func (a *aggregator) iterGroup(ctx context.Context, node *Node, parsed *ParsedExpression, op nodeOp) (exprAggregateStats, error) { + stats := &exprAggregateStats{} + // It's possible that if there are additional branches, don't bother to add this to the aggregate tree. // Mark this as a non-exhaustive addition and skip immediately. if len(node.Ands) > 0 { for _, n := range node.Ands { if !n.HasPredicate() || len(n.Ors) > 0 { // Don't handle sub-branching for now. - return false, nil + // TODO: Recursively iterate. + stats.AddSlow() + continue } if !isAggregateable(n) { - return false, nil + stats.AddSlow() + continue } } } @@ -428,24 +530,36 @@ func (a *aggregator) iterGroup(ctx context.Context, node *Node, parsed *ParsedEx all := node.Ands if node.Predicate != nil { if !isAggregateable(node) { - return false, nil + stats.AddSlow() + } else { + // Merge all of the nodes together and check whether each node is aggregateable. + all = append(node.Ands, node) } - // Merge all of the nodes together and check whether each node is aggregateable. - all = append(node.Ands, node) } // Iterate through and add every predicate to each engine. for _, n := range all { err := op(ctx, n, parsed) - if err == errEngineUnimplemented { - return false, nil - } - if err != nil { - return false, err + + switch err { + case nil: + // This is okay. + stats.AddFast() + continue + + case errEngineUnimplemented: + // Not yet added to aggregator + stats.AddSlow() + continue + + default: + // Some other error. + stats.AddSlow() + continue } } - return true, nil + return *stats, nil } func engineType(p Predicate) EngineType { diff --git a/expr_test.go b/expr_test.go index 54dc8be..d53ba20 100644 --- a/expr_test.go +++ b/expr_test.go @@ -112,7 +112,7 @@ func TestAdd(t *testing.T) { _, err := e.Add(ctx, expr) require.NoError(t, err) - require.Equal(t, 1, e.ConstantLen()) + require.Equal(t, 1, e.SlowLen()) } @@ -404,7 +404,7 @@ func TestEvaluate_Compound(t *testing.T) { e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) ok, err := e.Add(ctx, expected) - require.True(t, ok) + require.Greater(t, ok, float64(0)) require.NoError(t, err) t.Run("It matches items", func(t *testing.T) { @@ -456,7 +456,7 @@ func TestAggregateMatch(t *testing.T) { eval := tex(fmt.Sprintf(`event.data.%s == "yes"`, k)) loader.AddEval(eval) ok, err := e.Add(ctx, eval) - require.True(t, ok) + require.Greater(t, ok, float64(0)) require.NoError(t, err) } @@ -532,7 +532,7 @@ func TestMacros(t *testing.T) { loader.AddEval(eval) ok, err := e.Add(ctx, eval) require.NoError(t, err) - require.False(t, ok) + require.Equal(t, ok, float64(-1)) // Not supported. t.Run("It doesn't evaluate macros", func(t *testing.T) { @@ -596,18 +596,18 @@ func TestAddRemove(t *testing.T) { ok, err := e.Add(ctx, firstExpr) require.NoError(t, err) - require.True(t, ok) + require.Greater(t, ok, float64(0)) require.Equal(t, 1, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 1, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 1, e.FastLen()) // Add the same expression again. ok, err = e.Add(ctx, loader.AddEval(tex(`event.data.foo == "yes"`, "second-id"))) require.NoError(t, err) - require.True(t, ok) + require.Greater(t, ok, float64(0)) require.Equal(t, 2, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 2, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 2, e.FastLen()) t.Run("It removes duplicate expressions with different IDs", func(t *testing.T) { // Matching this expr should work before removal. @@ -622,10 +622,11 @@ func TestAddRemove(t *testing.T) { err = e.Remove(ctx, tex(`event.data.foo == "yes"`, "second-id")) require.NoError(t, err) - require.True(t, ok) + require.Greater(t, ok, float64(0)) + require.Equal(t, 1, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 1, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 1, e.FastLen()) // Matching this expr should now fail. eval, count, err = e.Evaluate(ctx, map[string]any{ @@ -642,10 +643,11 @@ func TestAddRemove(t *testing.T) { // Add a new expression ok, err = e.Add(ctx, loader.AddEval(tex(`event.data.another == "no"`))) require.NoError(t, err) - require.True(t, ok) + require.Greater(t, ok, float64(0)) + require.Equal(t, 2, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 2, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 2, e.FastLen()) // Remove all expressions t.Run("It removes an aggregateable expression", func(t *testing.T) { @@ -661,10 +663,11 @@ func TestAddRemove(t *testing.T) { err = e.Remove(ctx, tex(`event.data.another == "no"`)) require.NoError(t, err) - require.True(t, ok) + require.Greater(t, ok, float64(0)) + require.Equal(t, 1, e.Len()) // The first expr is remaining. - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 1, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 1, e.FastLen()) // Matching this expr should now fail. eval, count, err = e.Evaluate(ctx, map[string]any{ @@ -681,8 +684,8 @@ func TestAddRemove(t *testing.T) { err = e.Remove(ctx, tex(`event.data.another == "i'm not here"`)) require.Error(t, ErrEvaluableNotFound, err) require.Equal(t, 1, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 1, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 1, e.FastLen()) }) t.Run("With a non-aggregateable expression due to inequality/GTE on strings", func(t *testing.T) { @@ -690,40 +693,40 @@ func TestAddRemove(t *testing.T) { ok, err := e.Add(ctx, loader.AddEval(tex(`event.data.foo != "no"`))) require.NoError(t, err) - require.False(t, ok) + require.Equal(t, ok, float64(0)) require.Equal(t, 1, e.Len()) - require.Equal(t, 1, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 1, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) // Add the same expression again. ok, err = e.Add(ctx, loader.AddEval(tex(`event.data.foo >= "no"`))) require.NoError(t, err) - require.False(t, ok) + require.Equal(t, ok, float64(0)) require.Equal(t, 2, e.Len()) - require.Equal(t, 2, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 2, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) // Add a new expression ok, err = e.Add(ctx, loader.AddEval(tex(`event.data.another < "no"`))) require.NoError(t, err) - require.False(t, ok) + require.Equal(t, ok, float64(0)) require.Equal(t, 3, e.Len()) - require.Equal(t, 3, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 3, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) // And remove. err = e.Remove(ctx, loader.AddEval(tex(`event.data.another < "no"`))) require.NoError(t, err) + require.Equal(t, 2, e.SlowLen()) require.Equal(t, 2, e.Len()) - require.Equal(t, 2, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 0, e.FastLen()) // And yeet out another non-existent expression err = e.Remove(ctx, loader.AddEval(tex(`event.data.another != "i'm not here" && a != "b"`))) require.Error(t, ErrEvaluableNotFound, err) require.Equal(t, 2, e.Len()) - require.Equal(t, 2, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 2, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) }) } @@ -741,10 +744,10 @@ func TestEmptyExpressions(t *testing.T) { t.Run("Adding an empty expression succeeds", func(t *testing.T) { ok, err := e.Add(ctx, empty) require.NoError(t, err) - require.False(t, ok) + require.Equal(t, ok, float64(-1)) // TODO Check this failing case require.Equal(t, 1, e.Len()) - require.Equal(t, 1, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 1, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) }) t.Run("Empty expressions always match", func(t *testing.T) { @@ -764,8 +767,8 @@ func TestEmptyExpressions(t *testing.T) { err := e.Remove(ctx, empty) require.NoError(t, err) require.Equal(t, 0, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) }) } @@ -783,15 +786,15 @@ func TestEvaluate_Null(t *testing.T) { t.Run("Adding a `null` check succeeds and is aggregateable", func(t *testing.T) { ok, err := e.Add(ctx, notNull) require.NoError(t, err) - require.True(t, ok) + require.Greater(t, ok, float64(0)) ok, err = e.Add(ctx, isNull) require.NoError(t, err) - require.True(t, ok) + require.Greater(t, ok, float64(0)) require.Equal(t, 2, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 2, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 2, e.FastLen()) }) t.Run("Not null checks succeed", func(t *testing.T) { @@ -825,8 +828,8 @@ func TestEvaluate_Null(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 1, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 1, e.FastLen()) // We should still match on `isNull` t.Run("Is null checks succeed", func(t *testing.T) { @@ -845,8 +848,8 @@ func TestEvaluate_Null(t *testing.T) { err = e.Remove(ctx, isNull) require.NoError(t, err) require.Equal(t, 0, e.Len()) - require.Equal(t, 0, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 0, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) // We should no longer match on `isNull` t.Run("Is null checks succeed", func(t *testing.T) { @@ -867,11 +870,11 @@ func TestEvaluate_Null(t *testing.T) { idents := loader.AddEval(tex("event.data.a == event.data.b")) ok, err := e.Add(ctx, idents) require.NoError(t, err) - require.False(t, ok) + require.Equal(t, ok, float64(0)) require.Equal(t, 1, e.Len()) - require.Equal(t, 1, e.ConstantLen()) - require.Equal(t, 0, e.AggregateableLen()) + require.Equal(t, 1, e.SlowLen()) + require.Equal(t, 0, e.FastLen()) eval, count, err := e.Evaluate(ctx, map[string]any{ "event": map[string]any{