diff --git a/expr.go b/expr.go
index ded3f27..6492201 100644
--- a/expr.go
+++ b/expr.go
@@ -11,6 +11,11 @@ import (
 	"github.com/ohler55/ojg/jp"
 )
 
+var (
+	// ErrEvaluableNotFound is returned by Remove when the evaluable is not stored in the aggregator.
+	ErrEvaluableNotFound = fmt.Errorf("Evaluable instance not found in aggregator")
+)
+
 // errTreeUnimplemented is used while we develop the aggregate tree library when trees
 // are not yet implemented.
 var errTreeUnimplemented = fmt.Errorf("tree type unimplemented")
@@ -238,12 +243,9 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) {
 		return false, err
 	}
 
-	// NOTE: When modifying, ensure that Remove() is updated. We should reconcile
-	// the core loops to use the same code.
-
 	aggregateable := true
 	for _, g := range parsed.RootGroups() {
-		ok, err := a.addGroup(ctx, g, parsed)
+		ok, err := a.iterGroup(ctx, g, parsed, a.addNode)
 		if err != nil {
 			return false, err
 		}
@@ -264,7 +266,54 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) {
 	return aggregateable, nil
 }
 
-func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExpression) (bool, error) {
+// Remove removes an evaluable from the aggregator, taking it out of the aggregate
+// trees or, for non-aggregateable expressions, out of the constants list. It
+// returns ErrEvaluableNotFound if the evaluable could not be found.
+func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error {
+	// parse the expression using our tree parser.
+	parsed, err := a.parser.Parse(ctx, eval)
+	if err != nil {
+		return err
+	}
+
+	aggregateable := true
+	for _, g := range parsed.RootGroups() {
+		ok, err := a.iterGroup(ctx, g, parsed, a.removeNode)
+		if err == ErrExpressionPartNotFound {
+			return ErrEvaluableNotFound
+		}
+		if err != nil {
+			return err
+		}
+		if !ok && aggregateable {
+			// Find the index of the evaluable in constants and yank out, holding
+			// the lock for the search so the index cannot go stale before the splice.
+			a.lock.Lock()
+			idx := -1
+			for n, item := range a.constants {
+				if item.Evaluable.Identifier() == eval.Identifier() {
+					idx = n
+					break
+				}
+			}
+			if idx == -1 {
+				a.lock.Unlock()
+				return ErrEvaluableNotFound
+			}
+			a.constants = append(a.constants[:idx], a.constants[idx+1:]...)
+			a.lock.Unlock()
+			aggregateable = false
+		}
+	}
+
+	if aggregateable {
+		atomic.AddInt32(&a.len, -1)
+	}
+
+	return nil
+}
+
+func (a *aggregator) iterGroup(ctx context.Context, node *Node, parsed *ParsedExpression, op nodeOp) (bool, error) {
 	if len(node.Ors) > 0 {
 		// If there are additional branches, don't bother to add this to the aggregate tree.
 		// Mark this as a non-exhaustive addition and skip immediately.
@@ -305,7 +354,7 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
 	// ident/variable. Using the group ID, we can see if we've matched N necessary
 	// items from the same identifier. If so, the evaluation is true.
 	for _, n := range all {
-		err := a.addNode(ctx, n, parsed)
+		err := op(ctx, n, parsed)
 		if err == errTreeUnimplemented {
 			return false, nil
 		}
@@ -317,6 +366,9 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
 	return true, nil
 }
 
+// nodeOp is the per-node operation that iterGroup applies to each node, e.g. addNode or removeNode.
+type nodeOp func(ctx context.Context, n *Node, parsed *ParsedExpression) error
+
 func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpression) error {
 	// Don't allow anything to update in parallel. This enrues that Add() can be called
 	// concurrently.
@@ -344,15 +396,33 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpress
 	return errTreeUnimplemented
 }
 
-func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error {
-	// parse the expression using our tree parser.
-	parsed, err := a.parser.Parse(ctx, eval)
-	_ = parsed
-	if err != nil {
-		return err
-	}
+// removeNode removes a single predicate node from the aggregate tree stored for
+// its identifier. It is the Remove() counterpart to addNode.
+func (a *aggregator) removeNode(ctx context.Context, n *Node, parsed *ParsedExpression) error {
+	// Don't allow anything to update in parallel. This ensures that Remove() can be
+	// called concurrently.
+	a.lock.Lock()
+	defer a.lock.Unlock()
 
-	return fmt.Errorf("not implemented")
+	// Each node is aggregateable, so remove it from the tree for its identifier.
+	switch n.Predicate.TreeType() {
+	case TreeTypeART:
+		tree, ok := a.artIdents[n.Predicate.Ident]
+		if !ok {
+			return ErrExpressionPartNotFound
+		}
+		err := tree.Remove(ctx, ExpressionPart{
+			GroupID:   n.GroupID,
+			Predicate: *n.Predicate,
+			Parsed:    parsed,
+		})
+		if err != nil {
+			return err
+		}
+		a.artIdents[n.Predicate.Ident] = tree
+		return nil
+	}
+	return errTreeUnimplemented
 }
 
 func isAggregateable(n *Node) bool {
diff --git a/expr_test.go b/expr_test.go
index a62a6f5..a2c3f00 100644
--- a/expr_test.go
+++ b/expr_test.go
@@ -5,6 +5,7 @@ import (
 	"encoding/hex"
 	"fmt"
 	"math/rand"
+	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -39,20 +40,7 @@ func evaluate(b *testing.B, i int, parser TreeParser) error {
 	expected := tex(`event.data.account_id == "yes" && event.data.match == "true"`)
 	_, _ = e.Add(ctx, expected)
 
-	wg := sync.WaitGroup{}
-	// Insert N random matches.
-	for n := 0; n < i; n++ {
-		wg.Add(1)
-		//nolint:all
-		go func() {
-			defer wg.Done()
-			byt := make([]byte, 8)
-			_, _ = rand.Read(byt)
-			str := hex.EncodeToString(byt)
-			_, _ = e.Add(ctx, tex(fmt.Sprintf(`event.data.account_id == "%s"`, str)))
-		}()
-	}
-	wg.Wait()
+	addOtherExpressions(i, e)
 
 	b.StartTimer()
 
@@ -83,22 +71,7 @@ func TestEvaluate(t *testing.T) {
 
 	n := 100_000
 
-	wg := sync.WaitGroup{}
-	for i := 0; i < n; i++ {
-		wg.Add(1)
-		//nolint:all
-		go func() {
-			defer wg.Done()
-			byt := make([]byte, 8)
-			_, err := rand.Read(byt)
-			require.NoError(t, err)
-			str := hex.EncodeToString(byt)
-
-			_, err = e.Add(ctx, tex(fmt.Sprintf(`event.data.account_id == "%s"`, str)))
-			require.NoError(t, err)
-		}()
-	}
-	wg.Wait()
+	addOtherExpressions(n, e)
 
 	require.EqualValues(t, n+1, e.Len())
 
@@ -151,19 +124,7 @@ func TestEvaluate_Concurrently(t *testing.T) {
 	_, err = e.Add(ctx, expected)
 	require.NoError(t, err)
 
-	go func() {
-		for i := 0; i < 100_000; i++ {
-			//nolint:all
-			go func() {
-				byt := make([]byte, 8)
-				_, err := rand.Read(byt)
-				require.NoError(t, err)
-				str := hex.EncodeToString(byt)
-				_, err = e.Add(ctx, tex(fmt.Sprintf(`event.data.account_id == "%s"`, str)))
-				require.NoError(t, err)
-			}()
-		}
-	}()
+	addOtherExpressions(100_000, e)
 
 	t.Run("It matches items", func(t *testing.T) {
 		wg := sync.WaitGroup{}
@@ -199,24 +160,6 @@ func TestEvaluate_ArrayIndexes(t *testing.T) {
 	_, err = e.Add(ctx, expected)
 	require.NoError(t, err)
 
-	n := 100_000
-	wg := sync.WaitGroup{}
-	for i := 0; i < n; i++ {
-		wg.Add(1)
-		//nolint:all
-		go func() {
-			defer wg.Done()
-			byt := make([]byte, 8)
-			_, err := rand.Read(byt)
-			require.NoError(t, err)
-			str := hex.EncodeToString(byt)
-
-			_, err = e.Add(ctx, tex(fmt.Sprintf(`event.data.account_id == "%s"`, str)))
-			require.NoError(t, err)
-		}()
-	}
-	wg.Wait()
-
 	t.Run("It doesn't return if arrays contain non-matching data", func(t *testing.T) {
 		pre := time.Now()
 		evals, matched, err := e.Evaluate(ctx, map[string]any{
@@ -363,7 +306,7 @@ func TestAggregateMatch(t *testing.T) {
 	})
 }
 
-func TestAdd(t *testing.T) {
+func TestAddRemove(t *testing.T) {
 	ctx := context.Background()
 	parser, err := newParser()
 	require.NoError(t, err)
@@ -371,7 +314,9 @@ func TestAdd(t *testing.T) {
 	t.Run("With a basic aggregateable expression", func(t *testing.T) {
 		e := NewAggregateEvaluator(parser, testBoolEvaluator)
 
-		ok, err := e.Add(ctx, tex(`event.data.foo == "yes"`))
+		firstExpr := tex(`event.data.foo == "yes"`, "first-id")
+
+		ok, err := e.Add(ctx, firstExpr)
 		require.NoError(t, err)
 		require.True(t, ok)
 		require.Equal(t, 1, e.Len())
@@ -379,20 +324,87 @@ func TestAdd(t *testing.T) {
 		require.Equal(t, 1, e.AggregateableLen())
 
 		// Add the same expression again.
-		ok, err = e.Add(ctx, tex(`event.data.foo == "yes"`))
+		ok, err = e.Add(ctx, tex(`event.data.foo == "yes"`, "second-id"))
 		require.NoError(t, err)
 		require.True(t, ok)
 		require.Equal(t, 2, e.Len())
 		require.Equal(t, 0, e.ConstantLen())
 		require.Equal(t, 2, e.AggregateableLen())
 
+		t.Run("It removes duplicate expressions with different IDs", func(t *testing.T) {
+			// Matching this expr should work before removal.
+			eval, count, err := e.Evaluate(ctx, map[string]any{
+				"event": map[string]any{
+					"data": map[string]any{"foo": "yes"},
+				},
+			})
+			require.NoError(t, err)
+			require.EqualValues(t, 2, len(eval))
+			require.EqualValues(t, 2, count)
+
+			err = e.Remove(ctx, tex(`event.data.foo == "yes"`, "second-id"))
+			require.NoError(t, err)
+			require.True(t, ok)
+			require.Equal(t, 1, e.Len())
+			require.Equal(t, 0, e.ConstantLen())
+			require.Equal(t, 1, e.AggregateableLen())
+
+			// Only the first expr should match now that the duplicate is removed.
+			eval, count, err = e.Evaluate(ctx, map[string]any{
+				"event": map[string]any{
+					"data": map[string]any{"foo": "yes"},
+				},
+			})
+			require.NoError(t, err)
+			require.EqualValues(t, 1, len(eval))
+			require.EqualValues(t, 1, count)
+			require.EqualValues(t, firstExpr.Identifier(), eval[0].Identifier())
+		})
+
 		// Add a new expression
 		ok, err = e.Add(ctx, tex(`event.data.another == "no"`))
 		require.NoError(t, err)
 		require.True(t, ok)
-		require.Equal(t, 3, e.Len())
+		require.Equal(t, 2, e.Len())
+		require.Equal(t, 0, e.ConstantLen())
+		require.Equal(t, 2, e.AggregateableLen())
+
+		// Remove all expressions
+		t.Run("It removes an aggregateable expression", func(t *testing.T) {
+			// Matching this expr should work before removal.
+			eval, count, err := e.Evaluate(ctx, map[string]any{
+				"event": map[string]any{
+					"data": map[string]any{"another": "no"},
+				},
+			})
+			require.NoError(t, err)
+			require.EqualValues(t, 1, len(eval))
+			require.EqualValues(t, 1, count)
+
+			err = e.Remove(ctx, tex(`event.data.another == "no"`))
+			require.NoError(t, err)
+			require.True(t, ok)
+			require.Equal(t, 1, e.Len()) // The first expr is remaining.
+			require.Equal(t, 0, e.ConstantLen())
+			require.Equal(t, 1, e.AggregateableLen())
+
+			// Matching this expr should now fail.
+			eval, count, err = e.Evaluate(ctx, map[string]any{
+				"event": map[string]any{
+					"data": map[string]any{"another": "no"},
+				},
+			})
+			require.NoError(t, err)
+			require.Empty(t, eval)
+			require.EqualValues(t, 0, count)
+		})
+
+		// And yeet a non-existent aggregateable expr.
+		err = e.Remove(ctx, tex(`event.data.another == "i'm not here"`))
+		require.ErrorIs(t, err, ErrEvaluableNotFound)
+		require.Equal(t, 1, e.Len())
 		require.Equal(t, 0, e.ConstantLen())
-		require.Equal(t, 3, e.AggregateableLen())
+		require.Equal(t, 1, e.AggregateableLen())
 	})
 
 	t.Run("With a non-aggregateable expression due to inequality/GTE on strings", func(t *testing.T) {
@@ -420,14 +432,38 @@ func TestAdd(t *testing.T) {
 		require.Equal(t, 3, e.Len())
 		require.Equal(t, 3, e.ConstantLen())
 		require.Equal(t, 0, e.AggregateableLen())
+
+		// And remove.
+		err = e.Remove(ctx, tex(`event.data.another < "no"`))
+		require.NoError(t, err)
+		require.Equal(t, 2, e.Len())
+		require.Equal(t, 2, e.ConstantLen())
+		require.Equal(t, 0, e.AggregateableLen())
+
+		// And yeet out another non-existent expression
+		err = e.Remove(ctx, tex(`event.data.another != "i'm not here" && a != "b"`))
+		require.ErrorIs(t, err, ErrEvaluableNotFound)
+		require.Equal(t, 2, e.Len())
+		require.Equal(t, 2, e.ConstantLen())
+		require.Equal(t, 0, e.AggregateableLen())
 	})
 }
 
 // tex represents a test Evaluable expression
-type tex string
+func tex(expr string, ids ...string) Evaluable {
+	return testEvaluable{
+		expr: expr,
+		id:   strings.Join(ids, ","),
+	}
+}
 
-func (e tex) Expression() string { return string(e) }
-func (e tex) Identifier() string { return string(e) }
+type testEvaluable struct {
+	expr string
+	id   string
+}
+
+func (e testEvaluable) Expression() string { return e.expr }
+func (e testEvaluable) Identifier() string { return e.expr + e.id }
 
 func testBoolEvaluator(ctx context.Context, e Evaluable, input map[string]any) (bool, error) {
 	env, _ := cel.NewEnv(
@@ -465,3 +501,29 @@ func testBoolEvaluator(ctx context.Context, e Evaluable, input map[string]any) (
 	}
 	return result.Value().(bool), nil
 }
+
+// addOtherExpressions concurrently adds n random account_id equality expressions
+// to e, blocking until every Add has completed. Any error panics.
+func addOtherExpressions(n int, e AggregateEvaluator) {
+	ctx := context.Background()
+	wg := sync.WaitGroup{}
+	for i := 0; i < n; i++ {
+		wg.Add(1)
+		//nolint:all
+		go func() {
+			defer wg.Done()
+			byt := make([]byte, 8)
+			_, err := rand.Read(byt)
+			if err != nil {
+				panic(err)
+			}
+			str := hex.EncodeToString(byt)
+
+			_, err = e.Add(ctx, tex(fmt.Sprintf(`event.data.account_id == "%s"`, str)))
+			if err != nil {
+				panic(err)
+			}
+		}()
+	}
+	wg.Wait()
+}
diff --git a/lift.go b/lift.go
index cf15d9d..14bde98 100644
--- a/lift.go
+++ b/lift.go
@@ -15,7 +15,12 @@ const (
 var (
 	// replace is truly hack city. these are 20 variable names for values that are
 	// lifted out of expressions via liftLiterals.
-	replace = []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t"}
+	replace = []string{
+		"a", "b", "c", "d", "e",
+		"f", "g", "h", "i", "j",
+		"k", "l", "m", "n", "o",
+		"p", "q", "r", "s", "t",
+	}
 )
 
 // LiftedArgs represents a set of variables that have been lifted from expressions and
diff --git a/parser.go b/parser.go
index e5f0ff6..eba2418 100644
--- a/parser.go
+++ b/parser.go
@@ -56,6 +56,9 @@ func NewTreeParser(ep CELParser) (TreeParser, error) {
 type parser struct {
 	ep CELParser
 
+	// rander is a random reader set during testing. it is never used outside
+	// of the test package during Parse. Instead, a new deterministic random
+	// reader is generated from the Evaluable identifier.
 	rander RandomReader
 }
 
@@ -65,7 +68,9 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression,
 		return nil, issues.Err()
 	}
 
-	if p.rander == nil {
+	r := p.rander
+
+	if r == nil {
 		// Create a new deterministic random reader based off of the evaluable's identifier.
 		// This means that every time we parse an expression with the given identifier, the
 		// group IDs will be deterministic as the randomness is sourced from the ID.
@@ -73,7 +78,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression,
 		// We only overwrite this if rander is not nil so that we can inject rander during tests.
 		digest := sha256.Sum256([]byte(eval.Identifier()))
 		seed := int64(binary.NativeEndian.Uint64(digest[:8]))
-		p.rander = rand.New(rand.NewSource(seed)).Read
+		r = rand.New(rand.NewSource(seed)).Read
 	}
 
 	node := newNode()
@@ -83,7 +88,7 @@ func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression,
 		},
 		node,
 		vars,
-		p.rander,
+		r,
 	)
 	if err != nil {
 		return nil, err
@@ -446,7 +451,12 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs, rand RandomReader) ([]
 	}
 
 	parent.GroupID = newGroupIDWithReader(uint16(total), rand)
+
 	// For each sub-group, add the same group IDs to children if there's no nesting.
+	//
+	// We do this so that the parent node which contains all ANDs can correctly set
+	// the same group ID for all child predicates. This is necessary; if you compare
+	// A && B && C, we want all of A/B/C to share the same group ID
 	for n, item := range parent.Ands {
 		if len(item.Ands) == 0 && len(item.Ors) == 0 && item.Predicate != nil {
 			item.GroupID = parent.GroupID