Skip to content

Commit

Permalink
Null aggregation (#12)
Browse files Browse the repository at this point in the history
* Parse and check null equality in expressions

* Add aggregateable null checks

* Make tree lookups concurrent

* Add null removal
  • Loading branch information
tonyhb authored Jan 8, 2024
1 parent 83deef3 commit 446dbfb
Show file tree
Hide file tree
Showing 15 changed files with 600 additions and 76 deletions.
165 changes: 131 additions & 34 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/google/cel-go/common/operators"
"github.com/ohler55/ojg/jp"
"golang.org/x/sync/errgroup"
)

var (
Expand Down Expand Up @@ -66,10 +67,11 @@ func NewAggregateEvaluator(
eval ExpressionEvaluator,
) AggregateEvaluator {
return &aggregator{
eval: eval,
parser: parser,
artIdents: map[string]PredicateTree{},
lock: &sync.RWMutex{},
eval: eval,
parser: parser,
artIdents: map[string]PredicateTree{},
nullLookups: map[string]PredicateTree{},
lock: &sync.RWMutex{},
}
}

Expand All @@ -94,8 +96,9 @@ type aggregator struct {
eval ExpressionEvaluator
parser TreeParser

artIdents map[string]PredicateTree
lock *sync.RWMutex
artIdents map[string]PredicateTree
nullLookups map[string]PredicateTree
lock *sync.RWMutex

len int32

Expand Down Expand Up @@ -155,7 +158,7 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
err = errors.Join(err, merr)
}

// TODO: Each match here is a potential success. When other trees and operators which are walkable
// Each match here is a potential success. When other trees and operators which are walkable
// are added (eg. >= operators on strings), ensure that we find the correct number of matches
// for each group ID and then skip evaluating expressions if the number of matches is <= the group
// ID's length.
Expand Down Expand Up @@ -186,51 +189,96 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
return result, matched, nil
}

// AggregateMatch attempts to match incoming data to all PredicateTrees, resulting in a selection
// of parts of an expression that have matched.
func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([]ExpressionPart, error) {
result := []ExpressionPart{}

a.lock.RLock()
defer a.lock.RUnlock()

// Store the number of times each GroupID has found a match. We need at least
// as many matches as stored in the group ID to consider the match.
// Each match here is a potential success. Ensure that we find the correct number of matches
// for each group ID and then skip evaluating expressions if the number of matches is <= the group
// ID's length. For example, (A && B && C) is a single group ID and must have a count >= 3,
// else we know a required comparason did not match.
//
// Note that having a count >= the group ID value does not guarantee that the expression is valid.
counts := map[groupID]int{}
// Store all expression parts per group ID for returning.
found := map[groupID][]ExpressionPart{}
// protect the above locks with a map.
lock := &sync.Mutex{}
// run lookups concurrently.
eg := errgroup.Group{}

add := func(all []ExpressionPart) {
// This is called concurrently, so don't mess with maps in goroutines
lock.Lock()
defer lock.Unlock()

for _, eval := range all {
counts[eval.GroupID] += 1
if _, ok := found[eval.GroupID]; !ok {
found[eval.GroupID] = []ExpressionPart{}
}
found[eval.GroupID] = append(found[eval.GroupID], eval)
}
}

// Iterate through all known variables/idents in the aggregate tree to see if
// the data has those keys set. If so, we can immediately evaluate the data with
// the tree.
//
// TODO: we should iterate through the expression in a top-down order, ensuring that if
// any of the top groups fail to match we quit early.
for k, tree := range a.artIdents {
x, err := jp.ParseString(k)
if err != nil {
return nil, err
}
res := x.Get(data)
if len(res) != 1 {
continue
}
for n, item := range a.artIdents {
tree := item
path := n
eg.Go(func() error {
x, err := jp.ParseString(path)
if err != nil {
return err
}
res := x.Get(data)
if len(res) != 1 {
return nil
}

switch cast := res[0].(type) {
case string:
all, ok := tree.Search(ctx, cast)
cast, ok := res[0].(string)
if !ok {
continue
// This isn't a string, so we can't compare within the radix tree.
return nil
}

for _, eval := range all.Evals {
counts[eval.GroupID] += 1
if _, ok := found[eval.GroupID]; !ok {
found[eval.GroupID] = []ExpressionPart{}
}
found[eval.GroupID] = append(found[eval.GroupID], eval)
add(tree.Search(ctx, path, cast))
return nil
})
}

// Match on nulls.
for n, item := range a.nullLookups {
tree := item
path := n
eg.Go(func() error {
x, err := jp.ParseString(path)
if err != nil {
return err
}
default:
continue
}

res := x.Get(data)
if len(res) == 0 {
// This isn't present, which matches null in our overloads. Set the
// value to nil.
res = []any{nil}
}
// This matches null, nil (as null), and any non-null items.
add(tree.Search(ctx, path, res[0]))
return nil
})
}

if err := eg.Wait(); err != nil {
return nil, err
}

for k, count := range counts {
Expand Down Expand Up @@ -393,6 +441,21 @@ func (a *aggregator) iterGroup(ctx context.Context, node *Node, parsed *ParsedEx
return true, nil
}

func treeType(p Predicate) TreeType {
// switch on type of literal AND operator type. int64/float64 literals require
// btrees, texts require ARTs.
switch p.Literal.(type) {
case string:
return TreeTypeART
case int64, float64:
return TreeTypeBTree
case nil:
return TreeTypeNullMatch
default:
return TreeTypeNone
}
}

// nodeOp represents an op eg. addNode or removeNode
type nodeOp func(ctx context.Context, n *Node, parsed *ParsedExpression) error

Expand All @@ -403,7 +466,7 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpress
defer a.lock.Unlock()

// Each node is aggregateable, so add this to the map for fast filtering.
switch n.Predicate.TreeType() {
switch treeType(*n.Predicate) {
case TreeTypeART:
tree, ok := a.artIdents[n.Predicate.Ident]
if !ok {
Expand All @@ -419,6 +482,21 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpress
}
a.artIdents[n.Predicate.Ident] = tree
return nil
case TreeTypeNullMatch:
tree, ok := a.nullLookups[n.Predicate.Ident]
if !ok {
tree = newNullMatcher()
}
err := tree.Add(ctx, ExpressionPart{
GroupID: n.GroupID,
Predicate: *n.Predicate,
Parsed: parsed,
})
if err != nil {
return err
}
a.nullLookups[n.Predicate.Ident] = tree
return nil
}
return errTreeUnimplemented
}
Expand All @@ -430,11 +508,11 @@ func (a *aggregator) removeNode(ctx context.Context, n *Node, parsed *ParsedExpr
defer a.lock.Unlock()

// Each node is aggregateable, so add this to the map for fast filtering.
switch n.Predicate.TreeType() {
switch treeType(*n.Predicate) {
case TreeTypeART:
tree, ok := a.artIdents[n.Predicate.Ident]
if !ok {
tree = newArtTree()
return ErrExpressionPartNotFound
}
err := tree.Remove(ctx, ExpressionPart{
GroupID: n.GroupID,
Expand All @@ -446,6 +524,21 @@ func (a *aggregator) removeNode(ctx context.Context, n *Node, parsed *ParsedExpr
}
a.artIdents[n.Predicate.Ident] = tree
return nil
case TreeTypeNullMatch:
tree, ok := a.nullLookups[n.Predicate.Ident]
if !ok {
return ErrExpressionPartNotFound
}
err := tree.Remove(ctx, ExpressionPart{
GroupID: n.GroupID,
Predicate: *n.Predicate,
Parsed: parsed,
})
if err != nil {
return err
}
a.nullLookups[n.Predicate.Ident] = tree
return nil
}
return errTreeUnimplemented
}
Expand Down Expand Up @@ -479,6 +572,10 @@ func isAggregateable(n *Node) bool {
case int64, float64:
// TODO: Add binary tree matching for ints/floats
return false
case nil:
// This is null, which is supported and a simple lookup to check
// if the event's key in question is present and is not nil.
return true
default:
return false
}
Expand Down
94 changes: 93 additions & 1 deletion expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func evaluate(b *testing.B, i int, parser TreeParser) error {
return nil
}

func TestEvaluate(t *testing.T) {
func TestEvaluate_Strings(t *testing.T) {
ctx := context.Background()
parser := NewTreeParser(NewCachingParser(newEnv(), nil))
e := NewAggregateEvaluator(parser, testBoolEvaluator)
Expand Down Expand Up @@ -479,7 +479,99 @@ func TestEmptyExpressions(t *testing.T) {
require.Equal(t, 0, e.ConstantLen())
require.Equal(t, 0, e.AggregateableLen())
})
}

func TestEvaluate_Null(t *testing.T) {
ctx := context.Background()
parser, err := newParser()
require.NoError(t, err)

e := NewAggregateEvaluator(parser, testBoolEvaluator)

notNull := tex(`event.ts != null`, "id-1")
isNull := tex(`event.ts == null`, "id-2")

t.Run("Adding a `null` check succeeds and is aggregateable", func(t *testing.T) {
ok, err := e.Add(ctx, notNull)
require.NoError(t, err)
require.True(t, ok)

ok, err = e.Add(ctx, isNull)
require.NoError(t, err)
require.True(t, ok)

require.Equal(t, 2, e.Len())
require.Equal(t, 0, e.ConstantLen())
require.Equal(t, 2, e.AggregateableLen())
})

t.Run("Not null checks succeed", func(t *testing.T) {
// Matching this expr should now fail.
eval, count, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"ts": time.Now().UnixMilli(),
},
})
require.NoError(t, err)
require.EqualValues(t, 1, len(eval))
require.EqualValues(t, 1, count)
require.EqualValues(t, notNull, eval[0])
})

t.Run("Is null checks succeed", func(t *testing.T) {
// Matching this expr should work, as "ts" is nil
eval, count, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"ts": nil,
},
})
require.NoError(t, err)
require.EqualValues(t, 1, len(eval))
require.EqualValues(t, 1, count)
require.EqualValues(t, isNull, eval[0])
})

t.Run("It removes null checks", func(t *testing.T) {
err := e.Remove(ctx, notNull)
require.NoError(t, err)

require.Equal(t, 1, e.Len())
require.Equal(t, 0, e.ConstantLen())
require.Equal(t, 1, e.AggregateableLen())

// We should still match on `isNull`
t.Run("Is null checks succeed", func(t *testing.T) {
// Matching this expr should work, as "ts" is nil
eval, count, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"ts": nil,
},
})
require.NoError(t, err)
require.EqualValues(t, 1, len(eval))
require.EqualValues(t, 1, count)
require.EqualValues(t, isNull, eval[0])
})

err = e.Remove(ctx, isNull)
require.NoError(t, err)
require.Equal(t, 0, e.Len())
require.Equal(t, 0, e.ConstantLen())
require.Equal(t, 0, e.AggregateableLen())

// We should no longer match on `isNull`
t.Run("Is null checks succeed", func(t *testing.T) {
// Matching this expr should work, as "ts" is nil
eval, count, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"ts": nil,
},
})
require.NoError(t, err)
require.EqualValues(t, 0, len(eval))
require.EqualValues(t, 0, count)
})
})
}

// tex represents a test Evaluable expression
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ require (
github.com/ohler55/ojg v1.21.0
github.com/plar/go-adaptive-radix-tree v1.0.5
github.com/stretchr/testify v1.8.4
golang.org/x/sync v0.6.0
google.golang.org/protobuf v1.31.0
)

require (
Expand All @@ -19,6 +21,5 @@ require (
golang.org/x/text v0.9.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0 h1:3UeQBvD0TFrlV
github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0/go.mod h1:IXCdmsXIht47RaVFLEdVnh1t+pgYtTAhQGj73kz+2DM=
golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU=
golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
Loading

0 comments on commit 446dbfb

Please sign in to comment.