diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml new file mode 100644 index 0000000..aebd6ff --- /dev/null +++ b/.github/workflows/go.yaml @@ -0,0 +1,37 @@ +name: Go + +on: + push: + branches: [main] + pull_request: + +jobs: + golangci: + name: lint + strategy: + matrix: + os: [ubuntu-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: Lint + run: | + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.55.1 + ./bin/golangci-lint run --verbose + test-linux-race: + strategy: + matrix: + os: [ubuntu-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: '1.21' + - name: Test + run: go test ./... -v -count=1 diff --git a/caching_parser.go b/caching_parser.go new file mode 100644 index 0000000..fbc7537 --- /dev/null +++ b/caching_parser.go @@ -0,0 +1,113 @@ +package expr + +import ( + "regexp" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/google/cel-go/cel" + // "github.com/karlseguin/ccache/v2" +) + +var ( + doubleQuoteMatch *regexp.Regexp + replace = []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} +) + +func init() { + doubleQuoteMatch = regexp.MustCompile(`"[^"]*"`) +} + +// NewCachingParser returns a CELParser which lifts quoted literals out of the expression +// as variables and uses caching to cache expression parsing, resulting in improved +// performance when parsing expressions. +func NewCachingParser(env *cel.Env) CELParser { + return &cachingParser{ + env: env, + } +} + +type cachingParser struct { + // cache is a global cache of precompiled expressions. 
+ // cache *ccache.Cache + stupidNoInternetCache sync.Map + + env *cel.Env + + hits int64 + misses int64 +} + +// liftLiterals lifts double-quoted literals into variables, allowing us to normalize +// expressions to increase cache hit rates. At most len(replace) literals are lifted. +func liftLiterals(expr string) (string, map[string]any) { + // TODO: Optimize this please. Use strconv.Unquote as the basis, and perform + // searches across each index quotes. + + // If this contains an escape sequence (eg. `\"` or `\'`), skip the lifting + // of literals out of the expression. + if strings.Contains(expr, `\"`) || strings.Contains(expr, `\'`) { + return expr, nil + } + + var ( + counter int + vars = map[string]any{} + ) + + rewrite := func(str string) string { + if counter >= len(replace) { + return str + } + + idx := replace[counter] + if val, err := strconv.Unquote(str); err == nil { + str = val + } + vars[idx] = str + + counter++ + return VarPrefix + idx + } + + expr = doubleQuoteMatch.ReplaceAllStringFunc(expr, rewrite) + return expr, vars +} + +func (c *cachingParser) Parse(expr string) (*cel.Ast, *cel.Issues, map[string]any) { + expr, vars := liftLiterals(expr) + + // TODO: ccache, when I have internet. 
+ if cached, ok := c.stupidNoInternetCache.Load(expr); ok { + p := cached.(ParsedCelExpr) + atomic.AddInt64(&c.hits, 1) + return p.AST, p.Issues, vars + } + + ast, issues := c.env.Parse(expr) + + c.stupidNoInternetCache.Store(expr, ParsedCelExpr{ + Expr: expr, + AST: ast, + Issues: issues, + }) + + atomic.AddInt64(&c.misses, 1) + return ast, issues, vars +} + +func (c *cachingParser) Hits() int64 { + return atomic.LoadInt64(&c.hits) +} + +func (c *cachingParser) Misses() int64 { + return atomic.LoadInt64(&c.misses) +} + +type ParsedCelExpr struct { + Expr string + AST *cel.Ast + Issues *cel.Issues +} diff --git a/caching_parser_test.go b/caching_parser_test.go new file mode 100644 index 0000000..eb8528a --- /dev/null +++ b/caching_parser_test.go @@ -0,0 +1,140 @@ +package expr + +import ( + "testing" + + "github.com/google/cel-go/cel" + "github.com/stretchr/testify/require" +) + +func TestCachingParser_CachesSame(t *testing.T) { + c := cachingParser{env: newEnv()} + + a := `event.data.a == "cache"` + b := `event.data.b == "cache"` + + var ( + prevAST *cel.Ast + prevIssues *cel.Issues + prevVars map[string]any + ) + + t.Run("With an uncached expression", func(t *testing.T) { + prevAST, prevIssues, prevVars = c.Parse(a) + require.NotNil(t, prevAST) + require.Nil(t, prevIssues) + require.NotNil(t, prevVars) + require.EqualValues(t, 0, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) + + t.Run("With a cached expression", func(t *testing.T) { + ast, issues, vars := c.Parse(a) + require.NotNil(t, ast) + require.Nil(t, issues) + + require.Equal(t, prevAST, ast) + require.Equal(t, prevIssues, issues) + require.Equal(t, prevVars, vars) + + require.EqualValues(t, 1, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) + + t.Run("With another uncached expression", func(t *testing.T) { + prevAST, prevIssues, prevVars = c.Parse(b) + require.NotNil(t, prevAST) + require.Nil(t, prevIssues) + // This misses the cache, as the vars have changed - not the + // literals. 
+ require.EqualValues(t, 1, c.Hits()) + require.EqualValues(t, 2, c.Misses()) + }) +} + +func TestCachingParser_CacheIgnoreLiterals_Unescaped(t *testing.T) { + c := cachingParser{env: newEnv()} + + a := `event.data.a == "literal-a" && event.data.b == "yes-1"` + b := `event.data.a == "literal-b" && event.data.b == "yes-2"` + + var ( + prevAST *cel.Ast + prevIssues *cel.Issues + prevVars map[string]any + ) + + t.Run("With an uncached expression", func(t *testing.T) { + prevAST, prevIssues, prevVars = c.Parse(a) + require.NotNil(t, prevAST) + require.Nil(t, prevIssues) + require.EqualValues(t, 0, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) + + t.Run("With a cached expression", func(t *testing.T) { + ast, issues, vars := c.Parse(a) + require.NotNil(t, ast) + require.Nil(t, issues) + + require.Equal(t, prevAST, ast) + require.Equal(t, prevIssues, issues) + require.Equal(t, prevVars, vars) + + require.EqualValues(t, 1, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) + + t.Run("With a cached expression having different literals ONLY", func(t *testing.T) { + prevAST, prevIssues, _ = c.Parse(b) + require.NotNil(t, prevAST) + require.Nil(t, prevIssues) + // This misses the cache. 
+ require.EqualValues(t, 2, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) +} + +/* +func TestCachingParser_CacheIgnoreLiterals_Escaped(t *testing.T) { + return + c := cachingParser{env: newEnv()} + + a := `event.data.a == "literal\"-a" && event.data.b == "yes"` + b := `event.data.a == "literal\"-b" && event.data.b == "yes"` + + var ( + prevAST *cel.Ast + prevIssues *cel.Issues + ) + + t.Run("With an uncached expression", func(t *testing.T) { + prevAST, prevIssues = c.Parse(a) + require.NotNil(t, prevAST) + require.Nil(t, prevIssues) + require.EqualValues(t, 0, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) + + t.Run("With a cached expression", func(t *testing.T) { + ast, issues := c.Parse(a) + require.NotNil(t, ast) + require.Nil(t, issues) + + require.Equal(t, prevAST, ast) + require.Equal(t, prevIssues, issues) + + require.EqualValues(t, 1, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) + + t.Run("With a cached expression having different literals ONLY", func(t *testing.T) { + prevAST, prevIssues = c.Parse(b) + require.NotNil(t, prevAST) + require.Nil(t, prevIssues) + // This misses the cache. + require.EqualValues(t, 2, c.Hits()) + require.EqualValues(t, 1, c.Misses()) + }) +} +*/ diff --git a/expr.go b/expr.go index 1f74c7f..2b0b8b8 100644 --- a/expr.go +++ b/expr.go @@ -2,27 +2,71 @@ package expr import ( "context" + "errors" "fmt" + "sync" + "sync/atomic" + + "github.com/google/cel-go/common/operators" ) +// errTreeUnimplemented is used while we develop the aggregate tree library when trees +// are not yet implemented. +var errTreeUnimplemented = fmt.Errorf("tree type unimplemented") + +// ExpressionEvaluator is a function which evaluates an expression given input data, returning +// a boolean and error. +type ExpressionEvaluator func(ctx context.Context, e Evaluable, input map[string]any) (bool, error) + // AggregateEvaluator represents a group of expressions that must be evaluated for a single // event received. 
+// +// An AggregateEvaluator instance exists for every event name being matched. type AggregateEvaluator interface { - // Add adds an expression to the tree evaluator - Add(ctx context.Context, eval Evaluable) error + // Add adds an expression to the tree evaluator. This returns true + // if the expression is aggregateable, or false if the expression will be + // evaluated each time an event is received. + Add(ctx context.Context, eval Evaluable) (bool, error) + // Remove removes an expression from the aggregate evaluator Remove(ctx context.Context, eval Evaluable) error + // AggregateMatch returns all expression parts which are evaluable given the input data. + // + // It does this by iterating through the data, searching the tree stored for each string field's ident. + AggregateMatch(ctx context.Context, data map[string]any) ([]ExpressionPart, error) + // Evaluate checks input data against all exrpesssions in the aggregate in an optimal // manner, only evaluating expressions when necessary (based off of tree matching). // - // This returns a list of evaluable expressions that match the given input. - Evaluate(ctx context.Context, data map[string]any) ([]Evaluable, error) + // Note that any expressions added that cannot be evaluated optimally by trees + // are evaluated every time this function is called. + // + // Evaluate returns all matching Evaluables, plus the total number of evaluations + // executed. + Evaluate(ctx context.Context, data map[string]any) ([]Evaluable, int32, error) + + // Len returns the total number of aggregateable and constantly matched expressions + // stored in the evaluator. + Len() int + + // AggregateableLen returns the number of expressions being matched by aggregated trees. + AggregateableLen() int + + // ConstantLen returns the total number of expressions that must constantly + // be matched due to non-aggregateable clauses in their expressions. 
+ ConstantLen() int } -func NewAggregateEvaluator(parser TreeParser) AggregateEvaluator { +func NewAggregateEvaluator( + parser TreeParser, + eval ExpressionEvaluator, +) AggregateEvaluator { return &aggregator{ - parser: parser, + eval: eval, + parser: parser, + artIdents: map[string]PredicateTree{}, + lock: &sync.RWMutex{}, } } @@ -32,29 +76,37 @@ type Evaluable interface { } type aggregator struct { + eval ExpressionEvaluator parser TreeParser -} -func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evaluable, error) { - return nil, nil -} + artIdents map[string]PredicateTree + lock *sync.RWMutex -func (a *aggregator) Add(ctx context.Context, eval Evaluable) error { - parsed, err := a.parser.Parse(ctx, eval.Expression()) - if err != nil { - return err - } + len int32 - _ = parsed + // constants tracks evaluable instances that must always be evaluated, due to + // the expression containing non-aggregateable clauses. + constants []*ParsedExpression +} - // TODO: Iterate through each group and add the expression to tree - // types specified. +// Len returns the total number of aggregateable and constantly matched expressions +// stored in the evaluator. +func (a aggregator) Len() int { + return int(a.len) + len(a.constants) +} - // TODO: Add each group to a tree. The leaf node should point to the - // expressions that match this leaf node (pause?) +// AggregateableLen returns the number of expressions being matched by aggregated trees. +func (a aggregator) AggregateableLen() int { + return int(a.len) +} - // TODO: Pointer of checksums -> groups +// ConstantLen returns the total number of expressions that must constantly +// be matched due to non-aggregateable clauses in their expressions. +func (a aggregator) ConstantLen() int { + return len(a.constants) +} +func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evaluable, int32, error) { // on event entered: // // 1. 
load pauses @@ -62,10 +114,234 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) error { // 3. load nodes for pause, if none, run expression // 4. evaluate tree nodes for pause against data, if ok, run expression - fmt.Printf("%#v\n", parsed) - return fmt.Errorf("not implemented") + var ( + err error + matched = int32(0) + result = []Evaluable{} + ) + + // TODO: Concurrently match constant expressions using a semaphore for capacity. + for _, expr := range a.constants { + atomic.AddInt32(&matched, 1) + + // NOTE: We don't need to add lifted expression variables, + // because expr.Evaluable holds the original expression + // string. + ok, evalerr := a.eval(ctx, expr.Evaluable, data) + if evalerr != nil { + err = errors.Join(err, evalerr) + continue + } + if ok { + result = append(result, expr.Evaluable) + } + } + + matches, merr := a.AggregateMatch(ctx, data) + if merr != nil { + err = errors.Join(err, merr) + } + + // TODO: Each match here is a potential success. When other trees and operators which are walkable + // are added (eg. >= operators on strings), ensure that we find the correct number of matches + // for each group ID and then skip evaluating expressions if so. + for _, match := range matches { + atomic.AddInt32(&matched, 1) + // NOTE: We don't need to add lifted expression variables, + // because match.Parsed.Evaluable holds the original expression + // string. + ok, evalerr := a.eval(ctx, match.Parsed.Evaluable, data) + if evalerr != nil { + err = errors.Join(err, evalerr) + continue + } + if ok { + result = append(result, match.Parsed.Evaluable) + } + } + + return result, matched, err +} + +func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([]ExpressionPart, error) { + return a.aggregateMatch(ctx, data, "") +} + +func (a *aggregator) aggregateMatch(ctx context.Context, data map[string]any, prefix string) ([]ExpressionPart, error) { + // TODO: Flip this. 
Instead of iterating through all fields in a potentially large input + // array, iterate through all known variables/idents in the aggregate tree to see if + // the data has those keys set. + + result := []ExpressionPart{} + for k, v := range data { + switch cast := v.(type) { + case map[string]any: + // Recurse into the map to pluck out nested idents, eg. "event.data.account.id" + evals, err := a.aggregateMatch(ctx, cast, prefix+k+".") + if err != nil { + return nil, err + } + if len(evals) > 0 { + result = append(result, evals...) + } + case string: + a.lock.RLock() + tree, ok := a.artIdents[prefix+k] + a.lock.RUnlock() + if !ok { + continue + } + found, ok := tree.Search(ctx, cast) + if !ok { + continue + } + result = append(result, found.Evals...) + default: + continue + } + + } + return result, nil +} + +// Add adds an Evaluable to the aggregate tree engine for matching. It returns +// a boolean indicating whether the expression is suitable for aggregate tree +// matching, allowing rapid exclusion of non-matching expressions. +func (a *aggregator) Add(ctx context.Context, eval Evaluable) (bool, error) { + // parse the expression using our tree parser. + parsed, err := a.parser.Parse(ctx, eval) + if err != nil { + return false, err + } + + aggregateable := true + for _, g := range parsed.RootGroups() { + ok, err := a.addGroup(ctx, g, parsed) + if err != nil { + return false, err + } + if !ok && aggregateable { + // This is the first time we're seeing a non-aggregateable + // group, so add it to the constants list. + a.lock.Lock() + a.constants = append(a.constants, parsed) + a.lock.Unlock() + aggregateable = false + } + } + + // Track the number of added expressions correctly. 
+ if aggregateable { + atomic.AddInt32(&a.len, 1) + } + return aggregateable, nil +} + +func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExpression) (bool, error) { + if len(node.Ors) > 0 { + // If there are additional branches, don't bother to add this to the aggregate tree. + // Mark this as a non-exhaustive addition and skip immediately. + // + // TODO: Allow ORs _only if_ the ORs are not nested, eg. the ORs are basic predicate + // groups that themselves have no branches. + return false, nil + } + + // Merge all of the nodes together and check whether each node is aggregateable. + all := append(node.Ands, node) + for _, n := range all { + if !n.HasPredicate() || len(n.Ors) > 0 { + // Don't handle sub-branching for now. + return false, nil + } + if !isAggregateable(n) { + return false, nil + } + } + + // Create a new group ID which tracks the number of expressions that must match + // within this group in order for the group to pass. + // + // This includes ALL ands, plus at least one OR. + // + // When checking an incoming event, we match the event against each node's + // ident/variable. Using the group ID, we can see if we've matched N necessary + // items from the same identifier. If so, the evaluation is true. + groupID := newGroupID(uint16(len(all))) + for _, n := range all { + err := a.addNode(ctx, n, groupID, parsed) + if err == errTreeUnimplemented { + return false, nil + } + if err != nil { + return false, err + } + } + + return true, nil +} + +func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *ParsedExpression) error { + // Don't allow anything to update in parallel. + a.lock.Lock() + defer a.lock.Unlock() + + // Each node is aggregateable, so add this to the map for fast filtering. 
+ switch n.Predicate.TreeType() { + case TreeTypeART: + tree, ok := a.artIdents[n.Predicate.Ident] + if !ok { + tree = newArtTree() + } + err := tree.Add(ctx, ExpressionPart{ + GroupID: gid, + Predicate: *n.Predicate, + Parsed: parsed, + }) + if err != nil { + return err + } + a.artIdents[n.Predicate.Ident] = tree + return nil + } + return errTreeUnimplemented } func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error { + // TODO return fmt.Errorf("not implemented") } + +func isAggregateable(n *Node) bool { + if n.Predicate == nil { + // This is a parent node. We skip aggregateable checks and only + // return false based off of predicate information. + return true + } + if n.Predicate.LiteralIdent != nil { + // We're matching idents together, so this is not aggregateable. + return false + } + + switch v := n.Predicate.Literal.(type) { + case string: + if len(v) == 0 { + return false + } + if n.Predicate.Operator == operators.NotEquals { + // NOTE: NotEquals is _not_ supported. This requires selecting all leaf nodes _except_ + // a given leaf, iterating over a tree. We may as well execute every expression as the difference + // is negligible. + return false + } + // Right now, we only support equality checking. + // + // TODO: Add GT(e)/LT(e) matching with tree iteration. 
+ return n.Predicate.Operator == operators.Equals + case int64, float64: + // TODO: Add binary tree matching for ints/floats + return false + default: + return false + } +} diff --git a/expr_test.go b/expr_test.go index a4a2af3..eacf9e1 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1,6 +1,309 @@ package expr -import "testing" +import ( + "context" + "encoding/hex" + "fmt" + "math/rand" + "sync" + "testing" + "time" -func TestAggregateEvaluator(t *testing.T) { + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/stretchr/testify/require" +) + +func BenchmarkCachingEvaluate1_000(b *testing.B) { benchEval(1_000, NewCachingParser(newEnv()), b) } +func BenchmarkNonCachingEvaluate1_000(b *testing.B) { benchEval(1_000, EnvParser(newEnv()), b) } + +func benchEval(i int, p CELParser, b *testing.B) { + for n := 0; n < b.N; n++ { + parser, err := NewTreeParser(p) + if err != nil { + panic(err) + } + _ = evaluate(b, i, parser) + } +} + +func evaluate(b *testing.B, i int, parser TreeParser) error { + b.StopTimer() + ctx := context.Background() + e := NewAggregateEvaluator(parser, testBoolEvaluator) + + // Insert the match we want to see. + expected := tex(`event.data.account_id == "yes" && event.data.match == "true"`) + _, _ = e.Add(ctx, expected) + + wg := sync.WaitGroup{} + // Insert N random matches. 
+ for n := 0; n < i; n++ { + wg.Add(1) + //nolint:all + go func() { + defer wg.Done() + byt := make([]byte, 8) + _, _ = rand.Read(byt) + str := hex.EncodeToString(byt) + _, _ = e.Add(ctx, tex(fmt.Sprintf(`event.data.account_id == "%s"`, str))) + }() + } + wg.Wait() + + b.StartTimer() + + results, _, _ := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "account_id": "yes", + "match": "true", + }, + }, + }) + + if len(results) != 1 { + return fmt.Errorf("unexpected number of results: %d", len(results)) + } + return nil +} + +func TestEvaluate(t *testing.T) { + ctx := context.Background() + parser, err := newParser() + require.NoError(t, err) + e := NewAggregateEvaluator(parser, testBoolEvaluator) + + expected := tex(`event.data.account_id == "yes" && event.data.match == "true"`) + + _, err = e.Add(ctx, expected) + require.NoError(t, err) + + n := 100_000 + + wg := sync.WaitGroup{} + for i := 0; i < n; i++ { + wg.Add(1) + //nolint:all + go func() { + defer wg.Done() + byt := make([]byte, 8) + _, err := rand.Read(byt) + require.NoError(t, err) + str := hex.EncodeToString(byt) + + _, err = e.Add(ctx, tex(fmt.Sprintf(`event.data.account_id == "%s"`, str))) + require.NoError(t, err) + }() + } + wg.Wait() + + require.EqualValues(t, n+1, e.Len()) + + t.Run("It matches items", func(t *testing.T) { + pre := time.Now() + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "account_id": "yes", + "match": "true", + }, + }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms\n", total.Milliseconds()) + + require.NoError(t, err) + require.EqualValues(t, 1, matched) + require.EqualValues(t, []Evaluable{expected}, evals) + }) + + t.Run("It handles non-matching data", func(t *testing.T) { + fmt.Println("evaluating") + pre := time.Now() + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": 
map[string]any{ + "account_id": "yes", + "match": "no", + }, + }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms\n", total.Milliseconds()) + + require.NoError(t, err) + require.EqualValues(t, 0, len(evals)) + require.EqualValues(t, 1, matched) // We still ran one expression + }) +} + +func TestAggregateMatch(t *testing.T) { + ctx := context.Background() + parser, err := newParser() + require.NoError(t, err) + e := NewAggregateEvaluator(parser, testBoolEvaluator) + + // Add three expressions matching on "a", "b", "c" respectively. + keys := []string{"a", "b", "c"} + for _, k := range keys { + ok, err := e.Add(ctx, tex(fmt.Sprintf(`event.data.%s == "yes"`, k))) + require.True(t, ok) + require.NoError(t, err) + } + + // When passing input.data.a as "yes", we should find the match, + // as the expression's variable (event.data.a) matches the literal ("yes"). + t.Run("It matches when the ident and literal match", func(t *testing.T) { + input := map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "a": "yes", + "not-found": "no", + }, + }, + } + + matched, err := e.AggregateMatch(ctx, input) + require.NoError(t, err) + require.EqualValues(t, 1, len(matched)) + require.EqualValues(t, + `event.data.a == "yes"`, + matched[0].Parsed.Evaluable.Expression(), + ) + }) + + // When passing input.data.b, we should match only one expression. + t.Run("It doesn't match if the literal changes", func(t *testing.T) { + input := map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "b": "no", + }, + }, + } + + matched, err := e.AggregateMatch(ctx, input) + require.NoError(t, err) + require.EqualValues(t, 0, len(matched)) + }) + + // When passing input.data.a, we should match only one expression. 
+ t.Run("It skips data with no expressions in the tree", func(t *testing.T) { + input := map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "none": "yes", + }, + }, + } + + matched, err := e.AggregateMatch(ctx, input) + require.NoError(t, err) + require.EqualValues(t, 0, len(matched)) + }) +} + +func TestAdd(t *testing.T) { + ctx := context.Background() + parser, err := newParser() + require.NoError(t, err) + + t.Run("With a basic aggregateable expression", func(t *testing.T) { + e := NewAggregateEvaluator(parser, testBoolEvaluator) + + ok, err := e.Add(ctx, tex(`event.data.foo == "yes"`)) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, 1, e.Len()) + require.Equal(t, 0, e.ConstantLen()) + require.Equal(t, 1, e.AggregateableLen()) + + // Add the same expression again. + ok, err = e.Add(ctx, tex(`event.data.foo == "yes"`)) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, 2, e.Len()) + require.Equal(t, 0, e.ConstantLen()) + require.Equal(t, 2, e.AggregateableLen()) + + // Add a new expression + ok, err = e.Add(ctx, tex(`event.data.another == "no"`)) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, 3, e.Len()) + require.Equal(t, 0, e.ConstantLen()) + require.Equal(t, 3, e.AggregateableLen()) + }) + + t.Run("With a non-aggregateable expression due to inequality/GTE on strings", func(t *testing.T) { + e := NewAggregateEvaluator(parser, testBoolEvaluator) + + ok, err := e.Add(ctx, tex(`event.data.foo != "no"`)) + require.NoError(t, err) + require.False(t, ok) + require.Equal(t, 1, e.Len()) + require.Equal(t, 1, e.ConstantLen()) + require.Equal(t, 0, e.AggregateableLen()) + + // Add the same expression again. 
+ ok, err = e.Add(ctx, tex(`event.data.foo >= "no"`)) + require.NoError(t, err) + require.False(t, ok) + require.Equal(t, 2, e.Len()) + require.Equal(t, 2, e.ConstantLen()) + require.Equal(t, 0, e.AggregateableLen()) + + // Add a new expression + ok, err = e.Add(ctx, tex(`event.data.another < "no"`)) + require.NoError(t, err) + require.False(t, ok) + require.Equal(t, 3, e.Len()) + require.Equal(t, 3, e.ConstantLen()) + require.Equal(t, 0, e.AggregateableLen()) + }) +} + +// tex represents a test Evaluable expression +type tex string + +func (e tex) Expression() string { return string(e) } + +func testBoolEvaluator(ctx context.Context, e Evaluable, input map[string]any) (bool, error) { + env, _ := cel.NewEnv( + cel.Variable("event", cel.AnyType), + cel.Variable("async", cel.AnyType), + ) + ast, _ := env.Parse(e.Expression()) + + // Create the program, refusing to short circuit if a match is found. + // + // This will add all functions from functions.StandardOverloads as we + // created the environment with our custom library. + program, err := env.Program( + ast, + cel.EvalOptions(cel.OptExhaustiveEval, cel.OptTrackState, cel.OptPartialEval), // Exhaustive, always, right now. + ) + if err != nil { + return false, err + } + result, _, err := program.Eval(input) + if result == nil { + return false, nil + } + if types.IsUnknown(result) { + // When evaluating to a strict result this should never happen. We inject a decorator + // to handle unknowns as values similar to null, and should always get a value. + return false, nil + } + if types.IsError(result) { + return false, fmt.Errorf("invalid type comparison: %v", result) + } + if err != nil { + // This shouldn't be handled, as we should get an Error type in result above. 
+ return false, fmt.Errorf("error evaluating expression: %w", err) + } + return result.Value().(bool), nil } diff --git a/go.mod b/go.mod index e8bddfe..db66c8e 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21.0 require ( github.com/google/cel-go v0.18.2 + github.com/plar/go-adaptive-radix-tree v1.0.5 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 245304c..2b86762 100644 --- a/go.sum +++ b/go.sum @@ -9,12 +9,19 @@ github.com/google/cel-go v0.18.2/go.mod h1:kWcIzTsPX0zmQ+H3TirHstLLf9ep5QTsZBN9u github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/plar/go-adaptive-radix-tree v1.0.5 h1:rHR89qy/6c24TBAHullFMrJsU9hGlKmPibdBGU6/gbM= +github.com/plar/go-adaptive-radix-tree v1.0.5/go.mod h1:15VOUO7R9MhJL8HOJdpydR0rvanrtRE6fA6XSa/tqWE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 
h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= @@ -32,5 +39,6 @@ google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/groupid.go b/groupid.go index 8c61e26..ea33e26 100644 --- a/groupid.go +++ b/groupid.go @@ -15,7 +15,7 @@ func (g groupID) Size() uint16 { } func newGroupID(size uint16) groupID { - id := make([]byte, 8, 8) + id := make([]byte, 8) binary.NativeEndian.PutUint16(id, size) _, _ = rand.Read(id[2:]) return [8]byte(id[0:8]) diff --git a/iterate.go b/iterate.go deleted file mode 100644 index 3cfccb1..0000000 --- a/iterate.go +++ /dev/null @@ -1,9 +0,0 @@ -package expr - -/* -type ParsedExpressionIterator interface { - Walk(ctx context.Context, p ParsedExpression) -} - -type iter struct{} -*/ diff --git a/parser.go b/parser.go index 1bf202b..322228a 100644 --- a/parser.go +++ b/parser.go @@ -2,7 +2,6 @@ package expr import ( "context" - "crypto/md5" "fmt" "strconv" "strings" @@ -12,24 +11,55 @@ import ( "github.com/google/cel-go/common/operators" ) +const ( + VarPrefix = "vars." +) + // TreeParser parses an expression into a tree, with a root node and branches for // each subsequent OR or AND expression. 
type TreeParser interface { - Parse(ctx context.Context, expr string) (*ParsedExpression, error) + Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, error) +} + +// CELParser represents a CEL parser which takes an expression string +// and returns a CEL AST, any issues during parsing, and any lifted and replaced +// from the expression. +// +// By default, *cel.Env fulfils this interface. In production, it's common +// to provide a caching layer on top of *cel.Env to optimize parsing, as it's +// the slowest part of the expression process. +type CELParser interface { + Parse(expr string) (*cel.Ast, *cel.Issues, map[string]any) +} + +// EnvParser turns a *cel.Env into a CELParser. +func EnvParser(env *cel.Env) CELParser { + return envparser{env} +} + +type envparser struct { + env *cel.Env +} + +func (e envparser) Parse(txt string) (*cel.Ast, *cel.Issues, map[string]any) { + ast, iss := e.env.Parse(txt) + return ast, iss, nil } // NewTreeParser returns a new tree parser for a given *cel.Env -func NewTreeParser(env *cel.Env) (TreeParser, error) { +func NewTreeParser(ep CELParser) (TreeParser, error) { parser := &parser{ - env: env, + ep: ep, } return parser, nil } -type parser struct{ env *cel.Env } +type parser struct { + ep CELParser +} -func (p *parser) Parse(ctx context.Context, expression string) (*ParsedExpression, error) { - ast, issues := p.env.Parse(expression) +func (p *parser) Parse(ctx context.Context, eval Evaluable) (*ParsedExpression, error) { + ast, issues, vars := p.ep.Parse(eval.Expression()) if issues != nil { return nil, issues.Err() } @@ -40,12 +70,18 @@ func (p *parser) Parse(ctx context.Context, expression string) (*ParsedExpressio NavigableExpr: celast.NavigateAST(ast.NativeRep()), }, node, + vars, ) if err != nil { return nil, err } + node.normalize() - return &ParsedExpression{Root: *node}, nil + return &ParsedExpression{ + Root: *node, + Vars: vars, + Evaluable: eval, + }, nil } // ParsedExpression represents a parsed CEL 
expression into our higher-level AST. @@ -55,6 +91,19 @@ func (p *parser) Parse(ctx context.Context, expression string) (*ParsedExpressio type ParsedExpression struct { Root Node + // Vars represents rewritten literals within the expression. + // + // This allows us to rewrite eg. `event.data.id == "foo"` into + // `event.data.id == vars.a` such that multiple different literals + // share the same expression. Using the same expression allows us + // to cache and skip CEL parsing, which is the slowest aspect of + // expression matching. + // + Vars map[string]any + + // Evaluable stores the original evaluable interface that was parsed. + Evaluable Evaluable + // Exhaustive represents whether the parsing is exhaustive, or whether // specific CEL macros or functions were used which are not supported during // parsing. @@ -65,6 +114,15 @@ type ParsedExpression struct { // Exhaustive bool } +// RootGroups returns the top-level matching groups within an expression. This is a small +// utility to check the number of matching groups easily. +func (p ParsedExpression) RootGroups() []*Node { + if len(p.Root.Ands) == 0 && len(p.Root.Ors) > 1 { + return p.Root.Ors + } + return []*Node{&p.Root} +} + // PredicateGroup represents a group of predicates that must all pass in order to execute the // given expression. For example, this might contain two predicates representing an expression // with two operators combined with "&&". @@ -107,6 +165,9 @@ type Node struct { } func (n Node) HasPredicate() bool { + if n.Predicate == nil { + return false + } return n.Predicate.Operator != "" } @@ -204,10 +265,18 @@ func newNode() *Node { // // This is equivalent to a CEL overload/function/macro. type Predicate struct { - // Literal represents the literal value that the operator compares against. + // Literal represents the literal value that the operator compares against. 
If two + // variable are being compared, this is nil and LiteralIdent holds a pointer to the + // name of the second variable. Literal any + // Ident is the ident we're comparing to, eg. the variable. Ident string + + // LiteralIdent represents the second literal that we're comparing against, + // eg. in the expression "event.data.a == event.data.b this stores event.data.b + LiteralIdent *string + // Operator is the binary operator being used. NOTE: This always assumes that the // ident is to the left of the operator, eg "event.data.value > 100". If the value // is to the left of the operator, the operator will be switched @@ -216,18 +285,17 @@ type Predicate struct { } func (p Predicate) String() string { + lit := p.Literal + if p.LiteralIdent != nil { + lit = *p.LiteralIdent + } + switch str := p.Literal.(type) { case string: return fmt.Sprintf("%s %s %v", p.Ident, strings.ReplaceAll(p.Operator, "_", ""), strconv.Quote(str)) default: - return fmt.Sprintf("%s %s %v", p.Ident, strings.ReplaceAll(p.Operator, "_", ""), p.Literal) + return fmt.Sprintf("%s %s %v", p.Ident, strings.ReplaceAll(p.Operator, "_", ""), lit) } - -} - -func (p Predicate) hash() string { - sum := md5.Sum([]byte(fmt.Sprintf("%v", p))) - return string(sum[:]) } func (p Predicate) LiteralAsString() string { @@ -262,7 +330,7 @@ type expr struct { // It does this by iterating through the expression, amending the current `group` until // an or expression is found. When an or expression is found, we create another group which // is mutated by the iteration. -func navigateAST(nav expr, parent *Node) ([]*Node, error) { +func navigateAST(nav expr, parent *Node, vars map[string]any) ([]*Node, error) { // on the very first call to navigateAST, ensure that we set the first node // inside the nodemap. result := []*Node{} @@ -311,7 +379,7 @@ func navigateAST(nav expr, parent *Node) ([]*Node, error) { newParent := newNode() // For each item in the stack, recurse into that AST. 
- _, err := navigateAST(or, newParent) + _, err := navigateAST(or, newParent, vars) if err != nil { return nil, err } @@ -339,7 +407,7 @@ func navigateAST(nav expr, parent *Node) ([]*Node, error) { // We assume that this is being called with an ident as a comparator. // Dependign on the LHS/RHS type, we want to organize the kind into // a specific type of tree. - predicate := callToPredicate(item.NavigableExpr, item.negated) + predicate := callToPredicate(item.NavigableExpr, item.negated, vars) if predicate == nil { continue } @@ -396,52 +464,118 @@ func peek(nav expr, operator string) []expr { // callToPredicate transforms a function call within an expression (eg `>`) into // a Predicate struct for our matching engine. It ahandles normalization of // LHS/RHS plus inversions. -func callToPredicate(item celast.Expr, negated bool) *Predicate { +func callToPredicate(item celast.Expr, negated bool, vars map[string]any) *Predicate { fn := item.AsCall().FunctionName() if fn == operators.LogicalAnd || fn == operators.LogicalOr { // Quit early, as we descend into these while iterating through the tree when calling this. return nil } + // If this is in a negative expression (ie. `!(foo == bar)`), then invert the expression. + if negated { + fn = invert(fn) + } + args := item.AsCall().Args() if len(args) != 2 { return nil } var ( - ident string - literal any + identA, identB string + literal any ) for _, item := range args { switch item.Kind() { case celast.IdentKind: - ident = item.AsIdent() + if identA == "" { + identA = item.AsIdent() + } else { + identB = item.AsIdent() + } case celast.LiteralKind: literal = item.AsLiteral().Value() case celast.SelectKind: // This is an expression, ie. "event.data.foo" Iterate from the root field upwards // to get the full ident. + walked := "" for item.Kind() == celast.SelectKind { sel := item.AsSelect() - if ident == "" { - ident = sel.FieldName() + if walked == "" { + walked = sel.FieldName() } else { - ident = sel.FieldName() + "." 
+ ident + walked = sel.FieldName() + "." + walked } item = sel.Operand() } - ident = item.AsIdent() + "." + ident + walked = item.AsIdent() + "." + walked + + if identA == "" { + identA = walked + } else { + identB = walked + } } } - if ident == "" || literal == nil { - return nil + if identA != "" && identB != "" { + // We're matching two variables together. Check to see whether any + // of these idents have variable data being passed in above. + // + // This happens when we use a parser which "lifts" variables out of + // expressions to improve cache hits. + // + // Parsing can normalize `event.data.id == "1"` to + // `event.data.id == vars.a` && vars["a"] = "1". + // + // In this case, check to see if we're using a lifted var and, if so, + // use the variable as the ident directly. + aIsVar := strings.HasPrefix(identA, VarPrefix) + bIsVar := strings.HasPrefix(identB, VarPrefix) + + if aIsVar && bIsVar { + // Someone is matching two literals together, so.... this, + // is quite dumb. + // + // Do nothing but match on two vars. + return &Predicate{ + LiteralIdent: &identB, + Ident: identA, + Operator: fn, + } + } + + if aIsVar { + if val, ok := vars[strings.TrimPrefix(identA, VarPrefix)]; ok { + // Normalize. + literal = val + identA = identB + identB = "" + } + } + + if bIsVar { + if val, ok := vars[strings.TrimPrefix(identB, VarPrefix)]; ok { + // Normalize. + literal = val + identB = "" + } + } + + if identA != "" && identB != "" { + // THese are still idents, so handle them as + // variables being compared together. + return &Predicate{ + LiteralIdent: &identB, + Ident: identA, + Operator: fn, + } + } } - // If this is in a negative expression (ie. `!(foo == bar)`), then invert the expression. - if negated { - fn = invert(fn) + if identA == "" || literal == nil { + return nil } // We always assume that the ident is on the LHS. 
In the case of comparisons, @@ -483,7 +617,7 @@ func callToPredicate(item celast.Expr, negated bool) *Predicate { return &Predicate{ Literal: literal, - Ident: ident, + Ident: identA, Operator: fn, } } diff --git a/parser_test.go b/parser_test.go index 3f35daf..e12a48d 100644 --- a/parser_test.go +++ b/parser_test.go @@ -10,18 +10,17 @@ import ( "github.com/stretchr/testify/require" ) -func newTestAggregateEvaluator(t *testing.T) AggregateEvaluator { - t.Helper() - parser, _ := newParser() - return NewAggregateEvaluator(parser) -} - -func newParser() (TreeParser, error) { +func newEnv() *cel.Env { env, _ := cel.NewEnv( cel.Variable("event", cel.AnyType), cel.Variable("async", cel.AnyType), + cel.Variable("vars", cel.AnyType), ) - return NewTreeParser(env) + return env +} + +func newParser() (TreeParser, error) { + return NewTreeParser(EnvParser(newEnv())) } type parseTestInput struct { @@ -40,7 +39,14 @@ func TestParse(t *testing.T) { for _, test := range tests { parser, err := newParser() require.NoError(t, err) - actual, err := parser.Parse(ctx, test.input) + + eval := tex(test.input) + actual, err := parser.Parse(ctx, eval) + + // Shortcut to ensure the evaluable instance matches + if test.expected.Evaluable == nil { + test.expected.Evaluable = eval + } require.NoError(t, err) require.NotNil(t, actual) @@ -54,14 +60,38 @@ func TestParse(t *testing.T) { t, test.expected, *actual, - "Invalid strucutre:\n%s\nExpected: %s\n\nGot: %s", + "Invalid strucutre:\n%s\nExpected: %s\n\nGot: %s\nGroups: %d", test.input, string(a), string(b), + len(actual.RootGroups()), ) } } + t.Run("It handles ident matching", func(t *testing.T) { + ident := "vars.a" + _ = ident + + tests := []parseTestInput{ + { + input: "event == vars.a", + output: `event == vars.a`, + expected: ParsedExpression{ + Root: Node{ + Predicate: &Predicate{ + Ident: "event", + LiteralIdent: &ident, + Operator: operators.Equals, + }, + }, + }, + }, + } + + assert(t, tests) + }) + t.Run("It handles basic 
expressions", func(t *testing.T) { tests := []parseTestInput{ { @@ -756,15 +786,22 @@ func TestParse(t *testing.T) { }, { // Swapping the order of the expression - input: `c == 3 && (a == 1 || b == 2)`, - output: `c == 3 && (a == 1 || b == 2)`, + input: `a == 1 && b == 2 && (c == 3 || d == 4)`, + output: `a == 1 && b == 2 && (c == 3 || d == 4)`, expected: ParsedExpression{ Root: Node{ Ands: []*Node{ { Predicate: &Predicate{ - Literal: int64(3), - Ident: "c", + Literal: int64(1), + Ident: "a", + Operator: operators.Equals, + }, + }, + { + Predicate: &Predicate{ + Literal: int64(2), + Ident: "b", Operator: operators.Equals, }, }, @@ -772,15 +809,15 @@ func TestParse(t *testing.T) { Ors: []*Node{ { Predicate: &Predicate{ - Literal: int64(1), - Ident: "a", + Literal: int64(3), + Ident: "c", Operator: operators.Equals, }, }, { Predicate: &Predicate{ - Literal: int64(2), - Ident: "b", + Literal: int64(4), + Ident: "d", Operator: operators.Equals, }, }, @@ -907,19 +944,142 @@ func TestParse(t *testing.T) { } -/* -func TestParseGroupIDs(t *testing.T) { - t.Run("It creates new group IDs when parsing the same expression", func(t *testing.T) { - ctx := context.Background() - a, err := mustParser(t).Parse(ctx, "event == 'foo'") - require.NoError(t, err) - b, err := mustParser(t).Parse(ctx, "event == 'foo'") - require.NoError(t, err) - c, err := mustParser(t).Parse(ctx, "event == 'foo'") - - require.NotEqual(t, a[0].GroupID, b[0].GroupID) - require.NotEqual(t, b[0].GroupID, c[0].GroupID) - require.NotEqual(t, a[0].GroupID, c[0].GroupID) +func TestParse_LiftedVars(t *testing.T) { + ctx := context.Background() + + cachingCelParser := NewCachingParser(newEnv()) + + assert := func(t *testing.T, tests []parseTestInput) { + t.Helper() + + for _, test := range tests { + parser, err := NewTreeParser(cachingCelParser) + require.NoError(t, err) + eval := tex(test.input) + actual, err := parser.Parse(ctx, eval) + + // Shortcut to ensure the evaluable instance matches + if 
test.expected.Evaluable == nil { + test.expected.Evaluable = eval + } + + require.NoError(t, err) + require.NotNil(t, actual) + + require.EqualValues(t, test.output, actual.Root.String(), "String() does not match expected output") + + a, _ := json.MarshalIndent(test.expected, "", " ") + b, _ := json.MarshalIndent(actual, "", " ") + + require.EqualValues( + t, + test.expected, + *actual, + "Invalid strucutre:\n%s\nExpected: %s\n\nGot: %s\nGroups: %d", + test.input, + string(a), + string(b), + len(actual.RootGroups()), + ) + } + } + + t.Run("It handles basic expressions", func(t *testing.T) { + tests := []parseTestInput{ + { + input: `event == "foo"`, + output: `event == "foo"`, + expected: ParsedExpression{ + Root: Node{ + Predicate: &Predicate{ + Literal: "foo", + Ident: "event", + Operator: operators.Equals, + }, + }, + Vars: map[string]any{ + "a": "foo", + }, + }, + }, + { + input: `event == "bar"`, + output: `event == "bar"`, + expected: ParsedExpression{ + Root: Node{ + Predicate: &Predicate{ + Literal: "bar", + Ident: "event", + Operator: operators.Equals, + }, + }, + Vars: map[string]any{ + "a": "bar", + }, + }, + }, + { + input: `"bar" == event`, + output: `event == "bar"`, + expected: ParsedExpression{ + Root: Node{ + Predicate: &Predicate{ + Literal: "bar", + Ident: "event", + Operator: operators.Equals, + }, + }, + Vars: map[string]any{ + "a": "bar", + }, + }, + }, + } + + assert(t, tests) + + // We should have had one hit, as `event == "bar"` and `event == "foo"` + // were lifted into the same expression `event == vars.a` + require.EqualValues(t, 1, cachingCelParser.(*cachingParser).Hits()) }) } -*/ + +func TestRootGroups(t *testing.T) { + r := require.New(t) + ctx := context.Background() + parser, err := newParser() + + r.NoError(err) + + t.Run("With single groups", func(t *testing.T) { + actual, err := parser.Parse(ctx, tex("a == 1")) + r.NoError(err) + r.Equal(1, len(actual.RootGroups())) + r.Equal(&actual.Root, actual.RootGroups()[0]) + + actual, err 
= parser.Parse(ctx, tex("a == 1 && b == 2")) + r.NoError(err) + r.Equal(1, len(actual.RootGroups())) + r.Equal(&actual.Root, actual.RootGroups()[0]) + + actual, err = parser.Parse(ctx, tex("root == 'yes' && (a == 1 || b == 2)")) + r.NoError(err) + r.Equal(1, len(actual.RootGroups())) + r.Equal(&actual.Root, actual.RootGroups()[0]) + }) + + t.Run("With an or", func(t *testing.T) { + actual, err := parser.Parse(ctx, tex("a == 1 || b == 2")) + r.NoError(err) + r.Equal(2, len(actual.RootGroups())) + + actual, err = parser.Parse(ctx, tex("a == 1 || b == 2 || c == 3")) + r.NoError(err) + r.Equal(3, len(actual.RootGroups())) + + actual, err = parser.Parse(ctx, tex("a == 1 && b == 2 || c == 3")) + r.NoError(err) + r.Equal(2, len(actual.RootGroups())) + }) + +} diff --git a/tree.go b/tree.go index 375d96c..2858eb1 100644 --- a/tree.go +++ b/tree.go @@ -1,6 +1,8 @@ package expr -import "context" +import ( + "context" +) type TreeType int @@ -17,10 +19,11 @@ const ( // For example, an expression may check string equality using an // ART tree, while LTE operations may check against a b+-tree. type PredicateTree interface { - Add(ctx context.Context, p Predicate) error + Add(ctx context.Context, p ExpressionPart) error + Search(ctx context.Context, input any) (*Leaf, bool) } -// leaf represents the leaf within a tree. This stores all expressions +// Leaf represents the leaf within a tree. This stores all expressions // which match the given expression. // // For example, adding two expressions each matching "event.data == 'foo'" @@ -29,7 +32,6 @@ type PredicateTree interface { // // Note that there are many sub-clauses which need to be matched. Each // leaf is a subset of a full expression. Therefore, - type Leaf struct { Evals []ExpressionPart } @@ -46,5 +48,6 @@ type ExpressionPart struct { // // This lets us determine whether the entire group has been matched. 
GroupID groupID - Evaluable Evaluable + Predicate Predicate + Parsed *ParsedExpression } diff --git a/tree_art.go b/tree_art.go new file mode 100644 index 0000000..ccaf752 --- /dev/null +++ b/tree_art.go @@ -0,0 +1,86 @@ +package expr + +import ( + "context" + "fmt" + "sync" + "unsafe" + + art "github.com/plar/go-adaptive-radix-tree" +) + +var ( + ErrInvalidType = fmt.Errorf("invalid type for tree") +) + +func newArtTree() PredicateTree { + return &artTree{ + lock: &sync.RWMutex{}, + Tree: art.New(), + } +} + +type artTree struct { + lock *sync.RWMutex + art.Tree +} + +func (a *artTree) Search(ctx context.Context, input any) (*Leaf, bool) { + var key art.Key + + switch val := input.(type) { + case art.Key: + key = val + case []byte: + key = val + case string: + key = artKeyFromString(val) + } + + if len(key) == 0 { + return nil, false + } + + val, ok := a.Tree.Search(key) + if !ok { + return nil, false + } + return val.(*Leaf), true +} + +func (a *artTree) Add(ctx context.Context, p ExpressionPart) error { + str, ok := p.Predicate.Literal.(string) + if !ok { + return ErrInvalidType + } + + key := artKeyFromString(str) + + // Don't allow multiple gorutines to modify the tree simultaneously. + a.lock.Lock() + defer a.lock.Unlock() + + val, ok := a.Tree.Search(key) + if !ok { + // Insert the ExpressionPart as-is. + a.Insert(key, art.Value(&Leaf{ + Evals: []ExpressionPart{p}, + })) + return nil + } + + // Add the expressionpart as an expression matched by the already-existing + // value. Many expressions may match on the same string, eg. a user may set + // up 3 matches for order ID "abc". All 3 matches must be evaluated. + next := val.(*Leaf) + next.Evals = append(next.Evals, p) + a.Insert(key, next) + return nil +} + +func artKeyFromString(str string) art.Key { + // Zero-allocation string to byte conversion for speed. 
+ strd := unsafe.StringData(str) + return art.Key(unsafe.Slice(strd, len(str))) + +} diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/.gitignore b/vendor/github.com/plar/go-adaptive-radix-tree/.gitignore new file mode 100644 index 0000000..daf913b --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/.travis.yml b/vendor/github.com/plar/go-adaptive-radix-tree/.travis.yml new file mode 100644 index 0000000..7f49ccf --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/.travis.yml @@ -0,0 +1,18 @@ +language: go +sudo: false + +go: +- "1.13" +- "1.14" +- tip + +install: + - make bootstrap + +before_script: + - go vet ./... 
+ +script: + - make + - make test-cover-builder + - $GOPATH/bin/goveralls -service=travis-ci -coverprofile=/tmp/art_coverage.out diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/LICENSE b/vendor/github.com/plar/go-adaptive-radix-tree/LICENSE new file mode 100644 index 0000000..fa99eac --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Pavel Larkin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/README.md b/vendor/github.com/plar/go-adaptive-radix-tree/README.md new file mode 100644 index 0000000..55a0bb3 --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/README.md @@ -0,0 +1,98 @@ +An Adaptive Radix Tree Implementation in Go +==== + +[![Build Status](https://travis-ci.org/plar/go-adaptive-radix-tree.svg?branch=master)](https://travis-ci.org/plar/go-adaptive-radix-tree) [![Coverage Status](https://coveralls.io/repos/github/plar/go-adaptive-radix-tree/badge.svg?branch=master&v=1)](https://coveralls.io/github/plar/go-adaptive-radix-tree?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/plar/go-adaptive-radix-tree)](https://goreportcard.com/report/github.com/plar/go-adaptive-radix-tree) [![GoDoc](https://godoc.org/github.com/plar/go-adaptive-radix-tree?status.svg)](http://godoc.org/github.com/plar/go-adaptive-radix-tree) + +This library provides a Go implementation of the Adaptive Radix Tree (ART). 
+ +Features: +* Lookup performance surpasses highly tuned alternatives +* Support for highly efficient insertions and deletions +* Space efficient +* Performance is comparable to hash tables +* Maintains the data in sorted order, which enables additional operations like range scan and prefix lookup +* `O(k)` search/insert/delete operations, where `k` is the length of the key +* Minimum / Maximum value lookups +* Ordered iteration +* Prefix-based iteration +* Support for keys with null bytes, any byte array could be a key + +# Usage + +```go +package main + +import ( + "fmt" + "github.com/plar/go-adaptive-radix-tree" +) + +func main() { + + tree := art.New() + + tree.Insert(art.Key("Hi, I'm Key"), "Nice to meet you, I'm Value") + value, found := tree.Search(art.Key("Hi, I'm Key")) + if found { + fmt.Printf("Search value=%v\n", value) + } + + tree.ForEach(func(node art.Node) bool { + fmt.Printf("Callback value=%v\n", node.Value()) + return true + }) + + for it := tree.Iterator(); it.HasNext(); { + value, _ := it.Next() + fmt.Printf("Iterator value=%v\n", value.Value()) + } +} + +// Output: +// Search value=Nice to meet you, I'm Value +// Callback value=Nice to meet you, I'm Value +// Iterator value=Nice to meet you, I'm Value + +``` + +# Documentation + +Check out the documentation on [godoc.org](http://godoc.org/github.com/plar/go-adaptive-radix-tree) + +# Performance + +[plar/go-adaptive-radix-tree](https://github.com/plar/go-adaptive-radix-tree) outperforms [kellydunn/go-art](https://github.com/kellydunn/go-art) by avoiding memory allocations during search operations. +It also provides prefix based iteration over the tree. + +Benchmarks were performed on datasets extracted from different projects: +- The "Words" dataset contains a list of 235,886 english words. [2] +- The "UUIDs" dataset contains 100,000 uuids. [2] +- The "HSK Words" dataset contains 4,995 words. 
[4] + +|**go-adaptive-radix-tree**| # | Average time |Bytes per operation|Allocs per operation | +|:-------------------------|---:|------------------:|------------------:|--------------------:| +| Tree Insert Words | 9 | 117,888,698 ns/op | 37,942,744 B/op | 1,214,541 allocs/op | +| Tree Search Words | 26 | 44,555,608 ns/op | 0 B/op | 0 allocs/op | +| Tree Insert UUIDs | 18 | 59,360,135 ns/op | 18,375,723 B/op | 485,057 allocs/op | +| Tree Search UUIDs | 54 | 21,265,931 ns/op | 0 B/op | 0 allocs/op | +|**go-art** | | | | | +| Tree Insert Words | 5 | 272,047,975 ns/op | 81,628,987 B/op | 2,547,316 allocs/op | +| Tree Search Words | 10 | 129,011,177 ns/op | 13,272,278 B/op | 1,659,033 allocs/op | +| Tree Insert UUIDs | 10 | 140,309,246 ns/op | 33,678,160 B/op | 874,561 allocs/op | +| Tree Search UUIDs | 20 | 82,120,943 ns/op | 3,883,131 B/op | 485,391 allocs/op | + +To see more benchmarks just run + +``` +$ make benchmark +``` + +# References + +[1] [The Adaptive Radix Tree: ARTful Indexing for Main-Memory Databases (Specification)](http://www-db.in.tum.de/~leis/papers/ART.pdf) + +[2] [C99 implementation of the Adaptive Radix Tree](https://github.com/armon/libart) + +[3] [Another Adaptive Radix Tree implementation in Go](https://github.com/kellydunn/go-art) + +[4] [HSK Words](http://hskhsk.pythonanywhere.com/hskwords). HSK(Hanyu Shuiping Kaoshi) - Standardized test of Standard Mandarin Chinese proficiency. diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/api.go b/vendor/github.com/plar/go-adaptive-radix-tree/api.go new file mode 100644 index 0000000..ac3f1c3 --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/api.go @@ -0,0 +1,118 @@ +package art + +import "errors" + +// A constant exposing all node types. +const ( + Leaf Kind = 0 + Node4 Kind = 1 + Node16 Kind = 2 + Node48 Kind = 3 + Node256 Kind = 4 +) + +// Traverse Options. +const ( + // Iterate only over leaf nodes. + TraverseLeaf = 1 + + // Iterate only over non-leaf nodes. 
+ TraverseNode = 2 + + // Iterate over all nodes in the tree. + TraverseAll = TraverseLeaf | TraverseNode +) + +// These errors can be returned when iteration over the tree. +var ( + ErrConcurrentModification = errors.New("Concurrent modification has been detected") + ErrNoMoreNodes = errors.New("There are no more nodes in the tree") +) + +// Kind is a node type. +type Kind int + +// Key Type. +// Key can be a set of any characters include unicode chars with null bytes. +type Key []byte + +// Value type. +type Value interface{} + +// Callback function type for tree traversal. +// if the callback function returns false then iteration is terminated. +type Callback func(node Node) (cont bool) + +// Node interface. +type Node interface { + // Kind returns node type. + Kind() Kind + + // Key returns leaf's key. + // This method is only valid for leaf node, + // if its called on non-leaf node then returns nil. + Key() Key + + // Value returns leaf's value. + // This method is only valid for leaf node, + // if its called on non-leaf node then returns nil. + Value() Value +} + +// Iterator iterates over nodes in key order. +type Iterator interface { + // Returns true if the iteration has more nodes when traversing the tree. + HasNext() bool + + // Returns the next element in the tree and advances the iterator position. + // Returns ErrNoMoreNodes error if there are no more nodes in the tree. + // Check if there is a next node with HasNext method. + // Returns ErrConcurrentModification error if the tree has been structurally + // modified after the iterator was created. + Next() (Node, error) +} + +// Tree is an Adaptive Radix Tree interface. +type Tree interface { + // Insert a new key into the tree. + // If the key already in the tree then return oldValue, true and nil, false otherwise. + Insert(key Key, value Value) (oldValue Value, updated bool) + + // Delete removes a key from the tree and key's value, true is returned. 
+ // If the key does not exists then nothing is done and nil, false is returned. + Delete(key Key) (value Value, deleted bool) + + // Search returns the value of the specific key. + // If the key exists then return value, true and nil, false otherwise. + Search(key Key) (value Value, found bool) + + // ForEach executes a provided callback once per leaf node by default. + // The callback iteration is terminated if the callback function returns false. + // Pass TraverseXXX as an options to execute a provided callback + // once per NodeXXX type in the tree. + ForEach(callback Callback, options ...int) + + // ForEachPrefix executes a provided callback once per leaf node that + // leaf's key starts with the given keyPrefix. + // The callback iteration is terminated if the callback function returns false. + ForEachPrefix(keyPrefix Key, callback Callback) + + // Iterator returns an iterator for preorder traversal over leaf nodes by default. + // Pass TraverseXXX as an options to return an iterator for preorder traversal over all NodeXXX types. + Iterator(options ...int) Iterator + //IteratorPrefix(key Key) Iterator + + // Minimum returns the minimum valued leaf, true if leaf is found and nil, false otherwise. + Minimum() (min Value, found bool) + + // Maximum returns the maximum valued leaf, true if leaf is found and nil, false otherwise. 
+ Maximum() (max Value, found bool) + + // Returns size of the tree + Size() int +} + +// New creates a new adaptive radix tree +func New() Tree { + return newTree() +} diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/consts.go b/vendor/github.com/plar/go-adaptive-radix-tree/consts.go new file mode 100644 index 0000000..0e2919c --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/consts.go @@ -0,0 +1,21 @@ +package art + +// node constraints +const ( + node4Min = 2 + node4Max = 4 + + node16Min = node4Max + 1 + node16Max = 16 + + node48Min = node16Max + 1 + node48Max = 48 + + node256Min = node48Max + 1 + node256Max = 256 +) + +const ( + // MaxPrefixLen is maximum prefix length for internal nodes. + MaxPrefixLen = 10 +) diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/doc.go b/vendor/github.com/plar/go-adaptive-radix-tree/doc.go new file mode 100644 index 0000000..35ec70f --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/doc.go @@ -0,0 +1,49 @@ +// Package art implements an Adapative Radix Tree(ART) in pure Go. +// Note that this implementation is not thread-safe but it could be really easy to implement. +// +// The design of ART is based on "The Adaptive Radix Tree: ARTful Indexing for Main-Memory Databases" [1]. 
+// +// Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/plar/go-adaptive-radix-tree" +// ) +// +// func main() { +// +// tree := art.New() +// +// tree.Insert(art.Key("Hi, I'm Key"), "Nice to meet you, I'm Value") +// value, found := tree.Search(art.Key("Hi, I'm Key")) +// if found { +// fmt.Printf("Search value=%v\n", value) +// } +// +// tree.ForEach(func(node art.Node) bool { +// fmt.Printf("Callback value=%v\n", node.Value()) +// return true +// } +// +// for it := tree.Iterator(); it.HasNext(); { +// value, _ := it.Next() +// fmt.Printf("Iterator value=%v\n", value.Value()) +// } +// } +// +// // Output: +// // Search value=Nice to meet you, I'm Value +// // Callback value=Nice to meet you, I'm Value +// // Iterator value=Nice to meet you, I'm Value +// +// +// Also the current implementation was inspired by [2] and [3] +// +// [1] http://db.in.tum.de/~leis/papers/ART.pdf (Specification) +// +// [2] https://github.com/armon/libart (C99 implementation) +// +// [3] https://github.com/kellydunn/go-art (other Go implementation) +package art diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/factory.go b/vendor/github.com/plar/go-adaptive-radix-tree/factory.go new file mode 100644 index 0000000..a625374 --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/factory.go @@ -0,0 +1,54 @@ +package art + +import ( + "unsafe" +) + +type nodeFactory interface { + newNode4() *artNode + newNode16() *artNode + newNode48() *artNode + newNode256() *artNode + newLeaf(key Key, value interface{}) *artNode +} + +// make sure that objFactory implements all methods of nodeFactory interface +var _ nodeFactory = &objFactory{} + +var factory = newObjFactory() + +func newTree() *tree { + return &tree{} +} + +type objFactory struct{} + +func newObjFactory() nodeFactory { + return &objFactory{} +} + +// Simple obj factory implementation +func (f *objFactory) newNode4() *artNode { + return &artNode{kind: Node4, ref: unsafe.Pointer(new(node4))} +} + 
+func (f *objFactory) newNode16() *artNode { + return &artNode{kind: Node16, ref: unsafe.Pointer(&node16{})} +} + +func (f *objFactory) newNode48() *artNode { + return &artNode{kind: Node48, ref: unsafe.Pointer(&node48{})} +} + +func (f *objFactory) newNode256() *artNode { + return &artNode{kind: Node256, ref: unsafe.Pointer(&node256{})} +} + +func (f *objFactory) newLeaf(key Key, value interface{}) *artNode { + clonedKey := make(Key, len(key)) + copy(clonedKey, key) + return &artNode{ + kind: Leaf, + ref: unsafe.Pointer(&leaf{key: clonedKey, value: value}), + } +} diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/makefile b/vendor/github.com/plar/go-adaptive-radix-tree/makefile new file mode 100644 index 0000000..abf6e3f --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/makefile @@ -0,0 +1,53 @@ +EXTERNAL_TOOLS=\ + golang.org/x/tools/cmd/cover \ + golang.org/x/tools/cmd/vet \ + github.com/mattn/goveralls \ + github.com/stretchr/testify/assert + + +all: all-tests + @echo "*** Done!" + +get: + @echo "*** Resolve dependencies..." + @go get -v . + +all-tests: + @echo "*** Run tests..." + @go test . + +benchmark: + @echo "*** Run benchmarks..." + @go test -v -benchmem -bench=. -run=^a + +test-race: + @echo "*** Run tests with race condition..." + @go test --race -v . + +test-cover-builder: + @go test -covermode=count -coverprofile=/tmp/art.out . + + @rm -f /tmp/art_coverage.out + @echo "mode: count" > /tmp/art_coverage.out + @cat /tmp/art.out | tail -n +2 >> /tmp/art_coverage.out + @rm /tmp/art.out + +test-cover: test-cover-builder + @go tool cover -html=/tmp/art_coverage.out + +build: + @echo "*** Build project..." + @go build -v . + +build-asm: + @go build -a -work -v -gcflags="-S -B -C" . + +build-race: + @echo "*** Build project with race condition..." + @go build --race -v . 
+ +bootstrap: + @for tool in $(EXTERNAL_TOOLS) ; do \ + echo "Installing $$tool" ; \ + go get $$tool; \ + done diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/node.go b/vendor/github.com/plar/go-adaptive-radix-tree/node.go new file mode 100644 index 0000000..8f55bba --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/node.go @@ -0,0 +1,794 @@ +package art + +import ( + "bytes" + "math/bits" + "unsafe" +) + +type prefix [MaxPrefixLen]byte + +// ART node stores all available nodes, leaf and node type +type artNode struct { + ref unsafe.Pointer + kind Kind +} + +// a key with the null suffix will be stored as zeroChild +type node struct { + prefixLen uint32 + prefix prefix + numChildren uint16 + zeroChild *artNode +} + +// Node with 4 children +type node4 struct { + node + children [node4Max]*artNode + keys [node4Max]byte + present [node4Max]byte +} + +// Node with 16 children +type node16 struct { + node + children [node16Max]*artNode + keys [node16Max]byte + present uint16 // need 16 bits for keys +} + +// Node with 48 children +const ( + n48s = 6 // 2^n48s == n48m + n48m = 64 // it should be sizeof(node48.present[0]) +) + +type node48 struct { + node + children [node48Max]*artNode + keys [node256Max]byte + present [4]uint64 // need 256 bits for keys +} + +// Node with 256 children +type node256 struct { + node + children [node256Max]*artNode +} + +// Leaf node with variable key length +type leaf struct { + key Key + value interface{} +} + +// String returns string representation of the Kind value +func (k Kind) String() string { + return []string{"Leaf", "Node4", "Node16", "Node48", "Node256"}[k] +} + +func (k Key) charAt(pos int) byte { + if pos < 0 || pos >= len(k) { + return 0 + } + return k[pos] +} + +func (k Key) valid(pos int) bool { + return pos >= 0 && pos < len(k) +} + +// Node interface implementation +func (an *artNode) node() *node { + return (*node)(an.ref) +} + +func (an *artNode) Kind() Kind { + return an.kind +} + +func (an 
*artNode) Key() Key { + if an.isLeaf() { + return an.leaf().key + } + + return nil +} + +func (an *artNode) Value() Value { + if an.isLeaf() { + return an.leaf().value + } + + return nil +} + +func (an *artNode) isLeaf() bool { + return an.kind == Leaf +} + +func (an *artNode) setPrefix(key Key, prefixLen uint32) *artNode { + node := an.node() + node.prefixLen = prefixLen + for i := uint32(0); i < min(prefixLen, MaxPrefixLen); i++ { + node.prefix[i] = key[i] + } + + return an +} + +func (an *artNode) matchDeep(key Key, depth uint32) uint32 /* mismatch index*/ { + mismatchIdx := an.match(key, depth) + if mismatchIdx < MaxPrefixLen { + return mismatchIdx + } + + leaf := an.minimum() + limit := min(uint32(len(leaf.key)), uint32(len(key))) - depth + for ; mismatchIdx < limit; mismatchIdx++ { + if leaf.key[mismatchIdx+depth] != key[mismatchIdx+depth] { + break + } + } + + return mismatchIdx +} + +// Find the minimum leaf under a artNode +func (an *artNode) minimum() *leaf { + switch an.kind { + case Leaf: + return an.leaf() + + case Node4: + node := an.node4() + if node.zeroChild != nil { + return node.zeroChild.minimum() + } else if node.children[0] != nil { + return node.children[0].minimum() + } + + case Node16: + node := an.node16() + if node.zeroChild != nil { + return node.zeroChild.minimum() + } else if node.children[0] != nil { + return node.children[0].minimum() + } + + case Node48: + node := an.node48() + if node.zeroChild != nil { + return node.zeroChild.minimum() + } + + idx := uint8(0) + for node.present[idx>>n48s]&(1< 0 { + idx := 0 + for ; node.children[idx] == nil; idx++ { + // find 1st non empty + } + return node.children[idx].minimum() + } + } + + return nil // that should never happen in normal case +} + +func (an *artNode) maximum() *leaf { + switch an.kind { + case Leaf: + return an.leaf() + + case Node4: + node := an.node4() + return node.children[node.numChildren-1].maximum() + + case Node16: + node := an.node16() + return 
node.children[node.numChildren-1].maximum() + + case Node48: + idx := uint8(node256Max - 1) + node := an.node48() + for node.present[idx>>n48s]&(1<>n48s] & (1 << (c % n48m)); s > 0 { + if idx := int(node.keys[c]); idx >= 0 { + return idx + } + } + + case Node256: + return int(c) + } + + return -1 // not found +} + +var nodeNotFound *artNode + +func (an *artNode) findChild(c byte, valid bool) **artNode { + node := an.node() + + if !valid { + return &node.zeroChild + } + + idx := an.index(c) + if idx != -1 { + switch an.kind { + case Node4: + return &an.node4().children[idx] + + case Node16: + return &an.node16().children[idx] + + case Node48: + return &an.node48().children[idx] + + case Node256: + return &an.node256().children[idx] + } + } + + return &nodeNotFound +} + +func (an *artNode) node4() *node4 { + return (*node4)(an.ref) +} + +func (an *artNode) node16() *node16 { + return (*node16)(an.ref) +} + +func (an *artNode) node48() *node48 { + return (*node48)(an.ref) +} + +func (an *artNode) node256() *node256 { + return (*node256)(an.ref) +} + +func (an *artNode) leaf() *leaf { + return (*leaf)(an.ref) +} + +func (an *artNode) _addChild4(c byte, valid bool, child *artNode) bool { + node := an.node4() + + // grow to node16 + if node.numChildren >= node4Max { + newNode := an.grow() + newNode.addChild(c, valid, child) + replaceNode(an, newNode) + return true + } + + // zero byte in the key + if !valid { + node.zeroChild = child + return false + } + + // just add a new child + i := uint16(0) + for ; i < node.numChildren; i++ { + if c < node.keys[i] { + break + } + } + + limit := node.numChildren - i + for j := limit; limit > 0 && j > 0; j-- { + node.keys[i+j] = node.keys[i+j-1] + node.present[i+j] = node.present[i+j-1] + node.children[i+j] = node.children[i+j-1] + } + node.keys[i] = c + node.present[i] = 1 + node.children[i] = child + node.numChildren++ + return false +} + +func (an *artNode) _addChild16(c byte, valid bool, child *artNode) bool { + node := 
an.node16() + + if node.numChildren >= node16Max { + newNode := an.grow() + newNode.addChild(c, valid, child) + replaceNode(an, newNode) + return true + } + + if !valid { + node.zeroChild = child + return false + } + + idx := node.numChildren + bitfield := uint(0) + for i := uint(0); i < node16Max; i++ { + if node.keys[i] > c { + bitfield |= (1 << i) + } + } + mask := (1 << node.numChildren) - 1 + bitfield &= uint(mask) + if bitfield != 0 { + idx = uint16(bits.TrailingZeros(bitfield)) + } + + for i := node.numChildren; i > uint16(idx); i-- { + node.keys[i] = node.keys[i-1] + node.present = (node.present & ^(1 << i)) | ((node.present & (1 << (i - 1))) << 1) + node.children[i] = node.children[i-1] + } + + node.keys[idx] = c + node.present |= (1 << uint16(idx)) + node.children[idx] = child + node.numChildren++ + return false +} + +func (an *artNode) _addChild48(c byte, valid bool, child *artNode) bool { + node := an.node48() + if node.numChildren >= node48Max { + newNode := an.grow() + newNode.addChild(c, valid, child) + replaceNode(an, newNode) + return true + } + + if !valid { + node.zeroChild = child + return false + } + + index := byte(0) + for node.children[index] != nil { + index++ + } + + node.keys[c] = index + node.present[c>>n48s] |= (1 << (c % n48m)) + node.children[index] = child + node.numChildren++ + return false +} + +func (an *artNode) _addChild256(c byte, valid bool, child *artNode) bool { + node := an.node256() + if !valid { + node.zeroChild = child + } else { + node.numChildren++ + node.children[c] = child + } + + return false +} + +func (an *artNode) addChild(c byte, valid bool, child *artNode) bool { + switch an.kind { + case Node4: + return an._addChild4(c, valid, child) + + case Node16: + return an._addChild16(c, valid, child) + + case Node48: + return an._addChild48(c, valid, child) + + case Node256: + return an._addChild256(c, valid, child) + } + + return false +} + +func (an *artNode) _deleteChild4(c byte, valid bool) uint16 { + node := 
an.node4() + if !valid { + node.zeroChild = nil + } else if idx := an.index(c); idx >= 0 { + node.numChildren-- + + node.keys[idx] = 0 + node.present[idx] = 0 + node.children[idx] = nil + + for i := uint16(idx); i <= node.numChildren && i+1 < node4Max; i++ { + node.keys[i] = node.keys[i+1] + node.present[i] = node.present[i+1] + node.children[i] = node.children[i+1] + } + + node.keys[node.numChildren] = 0 + node.present[node.numChildren] = 0 + node.children[node.numChildren] = nil + } + + // we have to return the number of children for the current node(node4) as + // `node.numChildren` plus one if null node is not nil. + // `Shrink` method can be invoked after this method, + // `Shrink` can convert this node into a leaf node type. + // For all higher nodes(16/48/256) we simply copy null node to a smaller node + // see deleteChild() and shrink() methods for implementation details + numChildren := node.numChildren + if node.zeroChild != nil { + numChildren++ + } + + return numChildren +} + +func (an *artNode) _deleteChild16(c byte, valid bool) uint16 { + node := an.node16() + if !valid { + node.zeroChild = nil + } else if idx := an.index(c); idx >= 0 { + node.numChildren-- + node.keys[idx] = 0 + node.present &= ^(1 << uint16(idx)) + node.children[idx] = nil + + for i := uint16(idx); i <= node.numChildren && i+1 < node16Max; i++ { + node.keys[i] = node.keys[i+1] + node.present = (node.present & ^(1 << i)) | ((node.present & (1 << (i + 1))) >> 1) + node.children[i] = node.children[i+1] + } + + node.keys[node.numChildren] = 0 + node.present &= ^(1 << node.numChildren) + node.children[node.numChildren] = nil + } + + return node.numChildren +} + +func (an *artNode) _deleteChild48(c byte, valid bool) uint16 { + node := an.node48() + if !valid { + node.zeroChild = nil + } else if idx := an.index(c); idx >= 0 && node.children[idx] != nil { + node.children[idx] = nil + node.keys[c] = 0 + node.present[c>>n48s] &= ^(1 << (c % n48m)) + node.numChildren-- + } + + return 
node.numChildren +} + +func (an *artNode) _deleteChild256(c byte, valid bool) uint16 { + node := an.node256() + if !valid { + node.zeroChild = nil + return node.numChildren + } else if idx := an.index(c); node.children[idx] != nil { + node.children[idx] = nil + node.numChildren-- + } + + return node.numChildren +} + +func (an *artNode) deleteChild(c byte, valid bool) bool { + var ( + numChildren uint16 + minChildren uint16 + ) + + deleted := false + switch an.kind { + case Node4: + numChildren = an._deleteChild4(c, valid) + minChildren = node4Min + deleted = true + + case Node16: + numChildren = an._deleteChild16(c, valid) + minChildren = node16Min + deleted = true + + case Node48: + numChildren = an._deleteChild48(c, valid) + minChildren = node48Min + deleted = true + + case Node256: + numChildren = an._deleteChild256(c, valid) + minChildren = node256Min + deleted = true + } + + if deleted && numChildren < minChildren { + newNode := an.shrink() + replaceNode(an, newNode) + return true + } + + return false +} + +func (an *artNode) copyMeta(src *artNode) *artNode { + if src == nil { + return an + } + + d := an.node() + s := src.node() + + d.numChildren = s.numChildren + d.prefixLen = s.prefixLen + + for i, limit := uint32(0), min(s.prefixLen, MaxPrefixLen); i < limit; i++ { + d.prefix[i] = s.prefix[i] + } + + return an +} + +func (an *artNode) grow() *artNode { + switch an.kind { + case Node4: + node := factory.newNode16().copyMeta(an) + + d := node.node16() + s := an.node4() + d.zeroChild = s.zeroChild + + for i := uint16(0); i < s.numChildren; i++ { + if s.present[i] != 0 { + d.keys[i] = s.keys[i] + d.present |= (1 << i) + d.children[i] = s.children[i] + } + } + + return node + + case Node16: + node := factory.newNode48().copyMeta(an) + + d := node.node48() + s := an.node16() + d.zeroChild = s.zeroChild + + var numChildren byte + for i := uint16(0); i < s.numChildren; i++ { + if s.present&(1<>n48s] |= (1 << (ch % n48m)) + d.children[numChildren] = s.children[i] + 
numChildren++ + } + } + + return node + + case Node48: + node := factory.newNode256().copyMeta(an) + + d := node.node256() + s := an.node48() + d.zeroChild = s.zeroChild + + for i := uint16(0); i < node256Max; i++ { + if s.present[i>>n48s]&(1<<(i%n48m)) != 0 { + d.children[i] = s.children[s.keys[i]] + } + } + + return node + } + + return nil +} + +func (an *artNode) shrink() *artNode { + switch an.kind { + case Node4: + node4 := an.node4() + child := node4.children[0] + if child == nil { + child = node4.zeroChild + } + + if child.isLeaf() { + return child + } + + curPrefixLen := node4.prefixLen + if curPrefixLen < MaxPrefixLen { + node4.prefix[curPrefixLen] = node4.keys[0] + curPrefixLen++ + } + + childNode := child.node() + if curPrefixLen < MaxPrefixLen { + childPrefixLen := min(childNode.prefixLen, MaxPrefixLen-curPrefixLen) + for i := uint32(0); i < childPrefixLen; i++ { + node4.prefix[curPrefixLen+i] = childNode.prefix[i] + } + curPrefixLen += childPrefixLen + } + + for i := uint32(0); i < min(curPrefixLen, MaxPrefixLen); i++ { + childNode.prefix[i] = node4.prefix[i] + } + childNode.prefixLen += node4.prefixLen + 1 + + return child + + case Node16: + node16 := an.node16() + + newNode := factory.newNode4().copyMeta(an) + node4 := newNode.node4() + node4.numChildren = 0 + for i := uint16(0); i < node4Max; i++ { + node4.keys[i] = node16.keys[i] + if node16.present&(1<>n48s]&(1<<(uint16(i)%n48m)) == 0 { + continue + } + + if child := node48.children[idx]; child != nil { + node16.children[node16.numChildren] = child + node16.keys[node16.numChildren] = byte(i) + node16.present |= (1 << node16.numChildren) + node16.numChildren++ + } + } + + node16.zeroChild = node48.zeroChild + + return newNode + + case Node256: + node256 := an.node256() + + newNode := factory.newNode48().copyMeta(an) + node48 := newNode.node48() + node48.numChildren = 0 + for i, child := range node256.children { + if child != nil { + node48.children[node48.numChildren] = child + node48.keys[byte(i)] 
= byte(node48.numChildren) + node48.present[uint16(i)>>n48s] |= (1 << (uint16(i) % n48m)) + node48.numChildren++ + } + } + + node48.zeroChild = node256.zeroChild + + return newNode + } + + return nil +} + +// Leaf methods +func (l *leaf) match(key Key) bool { + if len(key) == 0 && len(l.key) == 0 { + return true + } + + if key == nil || len(l.key) != len(key) { + return false + } + + return bytes.Compare(l.key[:len(key)], key) == 0 +} + +func (l *leaf) prefixMatch(key Key) bool { + if key == nil || len(l.key) < len(key) { + return false + } + + return bytes.Compare(l.key[:len(key)], key) == 0 +} + +// Base node methods +func (an *artNode) match(key Key, depth uint32) uint32 /* 1st mismatch index*/ { + idx := uint32(0) + if len(key)-int(depth) < 0 { + return idx + } + + node := an.node() + + limit := min(min(node.prefixLen, MaxPrefixLen), uint32(len(key))-depth) + for ; idx < limit; idx++ { + if node.prefix[idx] != key[idx+depth] { + return idx + } + } + + return idx +} + +// Node helpers +func replaceRef(oldNode **artNode, newNode *artNode) { + *oldNode = newNode +} + +func replaceNode(oldNode *artNode, newNode *artNode) { + *oldNode = *newNode +} diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/tree.go b/vendor/github.com/plar/go-adaptive-radix-tree/tree.go new file mode 100644 index 0000000..84b8e54 --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/tree.go @@ -0,0 +1,238 @@ +package art + +type tree struct { + // version field is updated by each tree modification + version int + + root *artNode + size int +} + +// make sure that tree implements all methods from the Tree interface +var _ Tree = &tree{} + +func (t *tree) Insert(key Key, value Value) (Value, bool) { + oldValue, updated := t.recursiveInsert(&t.root, key, value, 0) + if !updated { + t.version++ + t.size++ + } + + return oldValue, updated +} + +func (t *tree) Delete(key Key) (Value, bool) { + value, deleted := t.recursiveDelete(&t.root, key, 0) + if deleted { + t.version++ + 
t.size-- + return value, true + } + + return nil, false +} + +func (t *tree) Search(key Key) (Value, bool) { + current := t.root + depth := uint32(0) + for current != nil { + if current.isLeaf() { + leaf := current.leaf() + if leaf.match(key) { + return leaf.value, true + } + + return nil, false + } + + curNode := current.node() + + if curNode.prefixLen > 0 { + prefixLen := current.match(key, depth) + if prefixLen != min(curNode.prefixLen, MaxPrefixLen) { + return nil, false + } + depth += curNode.prefixLen + } + + next := current.findChild(key.charAt(int(depth)), key.valid(int(depth))) + if *next != nil { + current = *next + } else { + current = nil + } + depth++ + } + + return nil, false +} + +func (t *tree) Minimum() (value Value, found bool) { + if t == nil || t.root == nil { + return nil, false + } + + leaf := t.root.minimum() + + return leaf.value, true +} + +func (t *tree) Maximum() (value Value, found bool) { + if t == nil || t.root == nil { + return nil, false + } + + leaf := t.root.maximum() + + return leaf.value, true +} + +func (t *tree) Size() int { + if t == nil || t.root == nil { + return 0 + } + + return t.size +} + +func (t *tree) recursiveInsert(curNode **artNode, key Key, value Value, depth uint32) (Value, bool) { + current := *curNode + if current == nil { + replaceRef(curNode, factory.newLeaf(key, value)) + return nil, false + } + + if current.isLeaf() { + leaf := current.leaf() + + // update exists value + if leaf.match(key) { + oldValue := leaf.value + leaf.value = value + return oldValue, true + } + // new value, split the leaf into new node4 + newLeaf := factory.newLeaf(key, value) + leaf2 := newLeaf.leaf() + leafsLCP := t.longestCommonPrefix(leaf, leaf2, depth) + + newNode := factory.newNode4() + newNode.setPrefix(key[depth:], leafsLCP) + depth += leafsLCP + + newNode.addChild(leaf.key.charAt(int(depth)), leaf.key.valid(int(depth)), current) + newNode.addChild(leaf2.key.charAt(int(depth)), leaf2.key.valid(int(depth)), newLeaf) + 
replaceRef(curNode, newNode) + + return nil, false + } + + node := current.node() + if node.prefixLen > 0 { + prefixMismatchIdx := current.matchDeep(key, depth) + if prefixMismatchIdx >= node.prefixLen { + depth += node.prefixLen + goto NEXT_NODE + } + + newNode := factory.newNode4() + node4 := newNode.node() + node4.prefixLen = prefixMismatchIdx + for i := 0; i < int(min(prefixMismatchIdx, MaxPrefixLen)); i++ { + node4.prefix[i] = node.prefix[i] + } + + if node.prefixLen <= MaxPrefixLen { + node.prefixLen -= (prefixMismatchIdx + 1) + newNode.addChild(node.prefix[prefixMismatchIdx], true, current) + + for i, limit := uint32(0), min(node.prefixLen, MaxPrefixLen); i < limit; i++ { + node.prefix[i] = node.prefix[prefixMismatchIdx+i+1] + } + + } else { + node.prefixLen -= (prefixMismatchIdx + 1) + leaf := current.minimum() + newNode.addChild(leaf.key.charAt(int(depth+prefixMismatchIdx)), leaf.key.valid(int(depth+prefixMismatchIdx)), current) + + for i, limit := uint32(0), min(node.prefixLen, MaxPrefixLen); i < limit; i++ { + node.prefix[i] = leaf.key[depth+prefixMismatchIdx+i+1] + } + } + + // Insert the new leaf + newNode.addChild(key.charAt(int(depth+prefixMismatchIdx)), key.valid(int(depth+prefixMismatchIdx)), factory.newLeaf(key, value)) + replaceRef(curNode, newNode) + + return nil, false + } + +NEXT_NODE: + + // Find a child to recursive to + next := current.findChild(key.charAt(int(depth)), key.valid(int(depth))) + if *next != nil { + return t.recursiveInsert(next, key, value, depth+1) + } + + // No Child, artNode goes with us + current.addChild(key.charAt(int(depth)), key.valid(int(depth)), factory.newLeaf(key, value)) + + return nil, false +} + +func (t *tree) recursiveDelete(curNode **artNode, key Key, depth uint32) (Value, bool) { + if t == nil || *curNode == nil || len(key) == 0 { + return nil, false + } + + current := *curNode + if current.isLeaf() { + leaf := current.leaf() + if leaf.match(key) { + replaceRef(curNode, nil) + return leaf.value, true + } + 
+ return nil, false + } + + node := current.node() + if node.prefixLen > 0 { + prefixLen := current.match(key, depth) + if prefixLen != min(node.prefixLen, MaxPrefixLen) { + return nil, false + } + + depth += node.prefixLen + } + + next := current.findChild(key.charAt(int(depth)), key.valid(int(depth))) + if *next == nil { + return nil, false + } + + if (*next).isLeaf() { + leaf := (*next).leaf() + if leaf.match(key) { + current.deleteChild(key.charAt(int(depth)), key.valid(int(depth))) + return leaf.value, true + } + + return nil, false + } + + return t.recursiveDelete(next, key, depth+1) +} + +func (t *tree) longestCommonPrefix(l1 *leaf, l2 *leaf, depth uint32) uint32 { + l1key, l2key := l1.key, l2.key + idx, limit := depth, min(uint32(len(l1key)), uint32(len(l2key))) + for ; idx < limit; idx++ { + if l1key[idx] != l2key[idx] { + break + } + } + + return idx - depth +} diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/tree_dump.go b/vendor/github.com/plar/go-adaptive-radix-tree/tree_dump.go new file mode 100644 index 0000000..704e821 --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/tree_dump.go @@ -0,0 +1,306 @@ +package art + +import ( + "bytes" + "fmt" +) + +const ( + printValuesAsChar = 1 << iota + printValuesAsDecimal + printValuesAsHex + + printValueDefault = printValuesAsChar +) + +type depthStorage struct { + childNum int + childrenTotal int +} + +type treeStringer struct { + storage []depthStorage + buf *bytes.Buffer +} + +// String returns tree in the human readable format, see DumpNode for examples +func (t *tree) String() string { + return DumpNode(t.root) +} + +func (ts *treeStringer) generatePads(depth int, childNum int, childrenTotal int) (pad0, pad string) { + ts.storage[depth] = depthStorage{childNum, childrenTotal} + + for d := 0; d <= depth; d++ { + if d < depth { + if ts.storage[d].childNum+1 < ts.storage[d].childrenTotal { + pad0 += "│ " + } else { + pad0 += " " + } + } else { + if childrenTotal == 0 { + pad0 += "─" + 
} else if ts.storage[d].childNum+1 < ts.storage[d].childrenTotal { + pad0 += "├" + } else { + pad0 += "└" + } + pad0 += "──" + } + + } + pad0 += " " + + for d := 0; d <= depth; d++ { + if childNum+1 < childrenTotal && childrenTotal > 0 { + if ts.storage[d].childNum+1 < ts.storage[d].childrenTotal { + pad += "│ " + } else { + pad += " " + } + } else if d < depth && ts.storage[d].childNum+1 < ts.storage[d].childrenTotal { + pad += "│ " + } else { + pad += " " + } + + } + + return +} + +func (ts *treeStringer) append(v interface{}, opts ...int) *treeStringer { + options := 0 + for _, opt := range opts { + options |= opt + } + + if options == 0 { + options = printValueDefault + } + + switch v.(type) { + + case string: + str, _ := v.(string) + ts.buf.WriteString(str) + + case []byte: + arr, _ := v.([]byte) + ts.append("[") + for i, b := range arr { + if (options & printValuesAsChar) != 0 { + if b > 0 { + ts.append(fmt.Sprintf("%c", b)) + } else { + ts.append("·") + } + + } else if (options & printValuesAsDecimal) != 0 { + ts.append(fmt.Sprintf("%d", b)) + } + if (options&printValuesAsDecimal) != 0 && i+1 < len(arr) { + ts.append(" ") + } + } + ts.append("]") + + case Key: + k, _ := v.(Key) + ts.append([]byte(k)) + + default: + ts.append("[") + ts.append(fmt.Sprintf("%#v", v)) + ts.append("]") + } + + return ts +} + +func (ts *treeStringer) appendKey(keys []byte, present []byte, opts ...int) *treeStringer { + options := 0 + for _, opt := range opts { + options |= opt + } + + if options == 0 { + options = printValueDefault + } + + ts.append("[") + for i, b := range keys { + if (options & printValuesAsChar) != 0 { + if present[i] != 0 { + ts.append(fmt.Sprintf("%c", b)) + } else { + ts.append("·") + } + + } else if (options & printValuesAsDecimal) != 0 { + if present[i] != 0 { + ts.append(fmt.Sprintf("%2d", b)) + } else { + ts.append("·") + } + } else if (options & printValuesAsHex) != 0 { + if present[i] != 0 { + ts.append(fmt.Sprintf("%2x", b)) + } else { + 
ts.append("·") + } + } + if (options&(printValuesAsDecimal|printValuesAsHex)) != 0 && i+1 < len(keys) { + ts.append(" ") + } + } + ts.append("]") + + return ts +} + +func (ts *treeStringer) children(children []*artNode, numChildred uint16, depth int, zeroChild *artNode) { + for i, child := range children { + ts.baseNode(child, depth, i, len(children)+1) + } + + ts.baseNode(zeroChild, depth, len(children)+1, len(children)+1) +} + +func (ts *treeStringer) node(pad string, prefixLen uint32, prefix []byte, keys []byte, present []byte, children []*artNode, numChildren uint16, depth int, zeroChild *artNode) { + if prefix != nil { + ts.append(pad).append(fmt.Sprintf("prefix(%x): %v", prefixLen, prefix)) + ts.append(prefix).append("\n") + } + + if keys != nil { + ts.append(pad).append("keys: ").appendKey(keys, present, printValuesAsDecimal).append(" ") + ts.appendKey(keys, present, printValuesAsChar).append("\n") + } + + ts.append(pad).append(fmt.Sprintf("children(%v): %+v <%v>\n", numChildren, children, zeroChild)) + ts.children(children, numChildren, depth+1, zeroChild) +} + +func (ts *treeStringer) baseNode(an *artNode, depth int, childNum int, childrenTotal int) { + padHeader, pad := ts.generatePads(depth, childNum, childrenTotal) + if an == nil { + ts.append(padHeader).append("nil").append("\n") + return + } + + ts.append(padHeader) + ts.append(fmt.Sprintf("%v (%p)\n", an.kind, an)) + switch an.kind { + case Node4: + nn := an.node4() + + ts.node(pad, nn.prefixLen, nn.prefix[:], nn.keys[:], nn.present[:], nn.children[:], nn.numChildren, depth, nn.zeroChild) + + case Node16: + nn := an.node16() + + var present []byte + for i := 0; i < len(nn.keys); i++ { + if (nn.present & (1 << uint16(i))) != 0 { + present = append(present, 1) + } else { + present = append(present, 0) + } + } + + ts.node(pad, nn.prefixLen, nn.prefix[:], nn.keys[:], present, nn.children[:], nn.numChildren, depth, nn.zeroChild) + + case Node48: + nn := an.node48() + + var present []byte + for i := 0; i < 
len(nn.keys); i++ { + if (nn.present[uint16(i)>>n48s] & (1 << (uint16(i) % n48m))) != 0 { + present = append(present, 1) + } else { + present = append(present, 0) + } + } + + ts.node(pad, nn.prefixLen, nn.prefix[:], nn.keys[:], present, nn.children[:], nn.numChildren, depth, nn.zeroChild) + + case Node256: + nn := an.node256() + ts.node(pad, nn.prefixLen, nn.prefix[:], nil, nil, nn.children[:], nn.numChildren, depth, nn.zeroChild) + + case Leaf: + n := an.leaf() + ts.append(pad).append(fmt.Sprintf("key(%d): %v ", len(n.key), n.key)).append(n.key[:]).append("\n") + + if s, ok := n.value.(string); ok { + ts.append(pad).append(fmt.Sprintf("val: %v\n", s)) + } else if b, ok := n.value.([]byte); ok { + ts.append(pad).append(fmt.Sprintf("val: %v\n", string(b))) + } else { + ts.append(pad).append(fmt.Sprintf("val: %v\n", n.value)) + } + + } + + ts.append(pad).append("\n") +} + +func (ts *treeStringer) rootNode(an *artNode) { + ts.baseNode(an, 0, 0, 0) +} + +/* +DumpNode returns Tree in the human readable format: + package main + + import ( + "fmt" + "github.com/plar/go-adaptive-radix-tree" + ) + + func main() { + tree := art.New() + terms := []string{"A", "a", "aa"} + for _, term := range terms { + tree.Insert(art.Key(term), term) + } + fmt.Println(tree) + } + + Output: + ─── Node4 (0xc00008a240) + prefix(0): [0 0 0 0 0 0 0 0 0 0][··········] + keys: [65 97 · ·] [Aa··] + children(2): [0xc00008a210 0xc00008a270 ] + ├── Leaf (0xc00008a210) + │ key(1): [65] [A] + │ val: A + │ + ├── Node4 (0xc00008a270) + │ prefix(0): [0 0 0 0 0 0 0 0 0 0][··········] + │ keys: [97 · · ·] [a···] + │ children(1): [0xc00008a260 0xc00008a230] + │ ├── Leaf (0xc00008a260) + │ │ key(2): [97 97] [aa] + │ │ val: aa + │ │ + │ ├── nil + │ ├── nil + │ ├── nil + │ └── Leaf (0xc00008a230) + │ key(1): [97] [a] + │ val: a + │ + │ + ├── nil + ├── nil + └── nil +*/ +func DumpNode(root *artNode) string { + ts := &treeStringer{make([]depthStorage, 4096), bytes.NewBufferString("")} + ts.rootNode(root) + + return 
ts.buf.String() +} diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/tree_traversal.go b/vendor/github.com/plar/go-adaptive-radix-tree/tree_traversal.go new file mode 100644 index 0000000..fc954ff --- /dev/null +++ b/vendor/github.com/plar/go-adaptive-radix-tree/tree_traversal.go @@ -0,0 +1,352 @@ +package art + +type traverseAction int + +const ( + traverseStop traverseAction = iota + traverseContinue +) + +type iteratorLevel struct { + node *artNode + childIdx int +} + +type iterator struct { + version int // tree version + + tree *tree + nextNode *artNode + depthLevel int + depth []*iteratorLevel +} + +type bufferedIterator struct { + options int + nextNode Node + err error + it *iterator +} + +func traverseOptions(opts ...int) int { + options := 0 + for _, opt := range opts { + options |= opt + } + options &= TraverseAll + if options == 0 { + // By default filter only leafs + options = TraverseLeaf + } + + return options +} + +func traverseFilter(options int, callback Callback) Callback { + if options == TraverseAll { + return callback + } + + return func(node Node) bool { + if options&TraverseLeaf == TraverseLeaf && node.Kind() == Leaf { + return callback(node) + } else if options&TraverseNode == TraverseNode && node.Kind() != Leaf { + return callback(node) + } + + return true + } +} + +func (t *tree) ForEach(callback Callback, opts ...int) { + options := traverseOptions(opts...) 
+	t.recursiveForEach(t.root, traverseFilter(options, callback))
+}
+
+// recursiveForEach performs a depth-first, pre-order walk of the subtree rooted
+// at current: callback is invoked on current before any of its children are
+// visited, and a false result from callback aborts the whole walk.
+//
+// It returns traverseStop when callback asked to stop and traverseContinue
+// otherwise, including when current is nil.
+func (t *tree) recursiveForEach(current *artNode, callback Callback) traverseAction {
+	if current == nil {
+		return traverseContinue
+	}
+
+	if !callback(current) {
+		return traverseStop
+	}
+
+	switch current.kind {
+	case Node4:
+		return t.forEachChildren(current.node().zeroChild, current.node4().children[:], callback)
+
+	case Node16:
+		return t.forEachChildren(current.node().zeroChild, current.node16().children[:], callback)
+
+	case Node48:
+		node := current.node48()
+
+		// The zero child is visited before any keyed child. A nil child is a
+		// no-op because recursiveForEach returns traverseContinue for nil.
+		if t.recursiveForEach(node.zeroChild, callback) == traverseStop {
+			return traverseStop
+		}
+
+		for i, c := range node.keys {
+			// Skip key slots whose bit is clear in the presence bitmap.
+			if node.present[uint16(i)>>n48s]&(1<<(uint16(i)%n48m)) == 0 {
+				continue
+			}
+
+			if t.recursiveForEach(node.children[c], callback) == traverseStop {
+				return traverseStop
+			}
+		}
+
+	case Node256:
+		return t.forEachChildren(current.node().zeroChild, current.node256().children[:], callback)
+	}
+
+	return traverseContinue
+}
+
+// forEachChildren walks nullChild first and then every non-nil entry of
+// children, in slice order, using recursiveForEach.
+//
+// It returns traverseStop as soon as one of the walks returns traverseStop and
+// traverseContinue once all children have been visited.
+func (t *tree) forEachChildren(nullChild *artNode, children []*artNode, callback Callback) traverseAction {
+	if t.recursiveForEach(nullChild, callback) == traverseStop {
+		return traverseStop
+	}
+
+	for _, child := range children {
+		// nullChild was already visited above, so it is skipped if it also
+		// appears in children.
+		if child == nil || child == nullChild {
+			continue
+		}
+
+		if t.recursiveForEach(child, callback) == traverseStop {
+			return traverseStop
+		}
+	}
+
+	return traverseContinue
+}
+
+// ForEachPrefix runs forEachPrefix from the root of the tree with the given key
+// and callback. The traverse action returned by forEachPrefix is discarded.
+func (t *tree) ForEachPrefix(key Key, callback Callback) {
+	t.forEachPrefix(t.root, key, callback)
+}
+
+// forEachPrefix descends from current along the bytes of key and reports the
+// nodes found under that prefix to callback:
+//
+//   - when a leaf is reached, callback is invoked on it only if
+//     leaf.prefixMatch(key) holds, and the descent ends;
+//   - when every byte of key has been consumed, the whole subtree rooted at the
+//     current node is walked with recursiveForEach, provided the subtree's
+//     minimum leaf prefix-matches key;
+//   - when the current node's compressed prefix consumes the remainder of key,
+//     the whole subtree is walked with recursiveForEach;
+//   - otherwise the descent continues into the child selected by the next key
+//     byte, and ends when there is no such child.
+//
+// It returns traverseStop when callback returned false and traverseContinue in
+// every other case, including when current is nil.
+func (t *tree) forEachPrefix(current *artNode, key Key, callback Callback) traverseAction {
+	depth := uint32(0)
+	for current != nil {
+		if current.isLeaf() {
+			if current.leaf().prefixMatch(key) && !callback(current) {
+				return traverseStop
+			}
+			break
+		}
+
+		if depth == uint32(len(key)) {
+			if current.minimum().prefixMatch(key) {
+				return t.recursiveForEach(current, callback)
+			}
+			break
+		}
+
+		node := current.node()
+		if node.prefixLen > 0 {
+			// matchDeep may report more matching bytes than this node's
+			// prefix length, so the result is clamped to node.prefixLen.
+			prefixLen := min(current.matchDeep(key, depth), node.prefixLen)
+			if prefixLen == 0 {
+				break
+			}
+			if depth+prefixLen == uint32(len(key)) {
+				return t.recursiveForEach(current, callback)
+			}
+			depth += node.prefixLen
+		}
+
+		// Find the child to recurse into.
+		next := current.findChild(key.charAt(int(depth)), key.valid(int(depth)))
+		if *next == nil {
+			break
+		}
+		current = *next
+		depth++
+	}
+
+	return traverseContinue
+}
+
+// Iterator returns an iterator positioned at the root of the tree.
+//
+// opts are combined by traverseOptions. When the combined options include
+// TraverseAll the plain iterator is returned; otherwise it is wrapped in a
+// bufferedIterator, which only yields the node kinds selected by the options.
+func (t *tree) Iterator(opts ...int) Iterator {
+	options := traverseOptions(opts...)
+
+	it := &iterator{
+		version:    t.version,
+		tree:       t,
+		nextNode:   t.root,
+		depthLevel: 0,
+		depth:      []*iteratorLevel{{t.root, nullIdx}},
+	}
+
+	if options&TraverseAll == TraverseAll {
+		return it
+	}
+
+	return &bufferedIterator{
+		options: options,
+		it:      it,
+	}
+}
+
+// checkConcurrentModification returns ErrConcurrentModification when the tree
+// version differs from the version recorded when the iterator was created, and
+// nil otherwise.
+func (ti *iterator) checkConcurrentModification() error {
+	if ti.version == ti.tree.version {
+		return nil
+	}
+
+	return ErrConcurrentModification
+}
+
+// HasNext reports whether Next has another node to return. It is safe to call
+// on a nil iterator and does not advance the iterator.
+func (ti *iterator) HasNext() bool {
+	return ti != nil && ti.nextNode != nil
+}
+
+// Next returns the pending node and advances the iterator.
+//
+// It returns ErrNoMoreNodes when the iteration is finished and
+// ErrConcurrentModification when the tree version changed since the iterator
+// was created; in the latter case the iterator is not advanced.
+func (ti *iterator) Next() (Node, error) {
+	if !ti.HasNext() {
+		return nil, ErrNoMoreNodes
+	}
+
+	if err := ti.checkConcurrentModification(); err != nil {
+		return nil, err
+	}
+
+	cur := ti.nextNode
+	ti.next()
+
+	return cur, nil
+}
+
+// nullIdx is the child index of an iterator level none of whose children has
+// been visited yet.
+const nullIdx = -1
+
+// nextChild selects the next child to visit for a node whose null-suffix child
+// is nullChild and whose keyed children are children.
+//
+// With childIdx == nullIdx it returns nullChild (if non-nil) and index 0.
+// Otherwise it scans children from childIdx for the first entry that is neither
+// nil nor nullChild and returns it together with the index to resume from.
+// It returns (0, nil) when no child is left.
+func nextChild(childIdx int, nullChild *artNode, children []*artNode) (nextChildIdx int, nextNode *artNode) {
+	if childIdx == nullIdx {
+		if nullChild != nil {
+			return 0, nullChild
+		}
+
+		childIdx = 0
+	}
+
+	for i := childIdx; i < len(children); i++ {
+		child := children[i]
+		if child != nil && child != nullChild {
+			return i + 1, child
+		}
+	}
+
+	return 0, nil
+}
+
+// next advances the iterator to the node that follows the current one in
+// depth-first order and stores it in ti.nextNode; ti.nextNode is set to nil
+// when the traversal is complete.
+//
+// ti.depth is the stack of levels on the path from the root and ti.depthLevel
+// is the index of its top. Each level remembers the child index to resume from.
+func (ti *iterator) next() {
+	for {
+		var nextNode *artNode
+		nextChildIdx := nullIdx
+
+		curNode := ti.depth[ti.depthLevel].node
+		curChildIdx := ti.depth[ti.depthLevel].childIdx
+
+		switch curNode.kind {
+		case Node4:
+			nextChildIdx, nextNode = nextChild(curChildIdx, curNode.node().zeroChild, curNode.node4().children[:])
+
+		case Node16:
+			nextChildIdx, nextNode = nextChild(curChildIdx, curNode.node().zeroChild, curNode.node16().children[:])
+
+		case Node48:
+			node := curNode.node48()
+			nullChild := node.zeroChild
+			if curChildIdx == nullIdx {
+				if nullChild != nil {
+					// The child with the null suffix is visited first.
+					nextChildIdx, nextNode = 0, nullChild
+					break // leaves the switch
+				}
+				curChildIdx = 0 // no null child: start from the first key slot
+			}
+
+			for i := curChildIdx; i < len(node.keys); i++ {
+				// Skip key slots whose bit is clear in the presence bitmap.
+				if node.present[uint16(i)>>n48s]&(1<<(uint16(i)%n48m)) == 0 {
+					continue
+				}
+
+				child := node.children[node.keys[i]]
+				if child != nil && child != nullChild {
+					nextChildIdx, nextNode = i+1, child
+					break // leaves the for loop
+				}
+			}
+
+		case Node256:
+			nextChildIdx, nextNode = nextChild(curChildIdx, curNode.node().zeroChild, curNode.node256().children[:])
+		}
+
+		if nextNode == nil {
+			if ti.depthLevel == 0 {
+				ti.nextNode = nil // the root level is exhausted: done
+				return
+			}
+
+			// This level is exhausted: return to the previous level.
+			ti.depthLevel--
+			continue
+		}
+
+		// Start from the next child when we come back from this child's subtree.
+		ti.depth[ti.depthLevel].childIdx = nextChildIdx
+		ti.nextNode = nextNode
+
+		// Push a level for the child. Slots left over from earlier descents are
+		// reused; append grows the stack with amortized cost instead of copying
+		// it in full for every additional level.
+		ti.depthLevel++
+		level := &iteratorLevel{nextNode, nullIdx}
+		if ti.depthLevel < len(ti.depth) {
+			ti.depth[ti.depthLevel] = level
+		} else {
+			ti.depth = append(ti.depth, level)
+		}
+		return
+	}
+}
+
+// HasNext reports whether Next has a result to return. It pulls nodes from the
+// wrapped iterator until one matches the traverse options (leaf nodes for
+// TraverseLeaf, non-leaf nodes for TraverseNode) or the wrapped iterator
+// returns an error, and buffers that result for Next.
+//
+// Calling HasNext again before Next does not advance the iteration: an already
+// buffered result is kept until Next consumes it.
+func (bti *bufferedIterator) HasNext() bool {
+	if bti.nextNode != nil || bti.err != nil {
+		return true // a result is already buffered and not yet consumed
+	}
+
+	for bti.it.HasNext() {
+		bti.nextNode, bti.err = bti.it.Next()
+		if bti.err != nil {
+			return true
+		}
+
+		isLeaf := bti.nextNode.Kind() == Leaf
+		if isLeaf && bti.options&TraverseLeaf == TraverseLeaf {
+			return true
+		}
+		if !isLeaf && bti.options&TraverseNode == TraverseNode {
+			return true
+		}
+	}
+
+	bti.nextNode = nil
+	bti.err = nil
+
+	return false
+}
+
+// Next returns the next node that matches the traverse options, or the error
+// reported by the wrapped iterator, and consumes it so that the following call
+// moves on. It may be called without a preceding HasNext.
+//
+// It returns ErrNoMoreNodes when no matching node is left.
+func (bti *bufferedIterator) Next() (Node, error) {
+	if !bti.HasNext() {
+		return nil, ErrNoMoreNodes
+	}
+
+	node, err := bti.nextNode, bti.err
+	bti.nextNode, bti.err = nil, nil
+
+	return node, err
+}
diff --git a/vendor/github.com/plar/go-adaptive-radix-tree/utils.go b/vendor/github.com/plar/go-adaptive-radix-tree/utils.go
new file mode 100644
index 0000000..9f1cc7e
--- /dev/null
+++ b/vendor/github.com/plar/go-adaptive-radix-tree/utils.go
@@ -0,0 +1,8 @@
+package art
+
+func min(a, b uint32) uint32 {
+	if a < b {
+		return a
+	}
+	return b
+}
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 0cd7743..2c0fe2f 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -26,6 +26,9 @@ github.com/google/cel-go/common/types/traits
 github.com/google/cel-go/interpreter
 github.com/google/cel-go/parser
 github.com/google/cel-go/parser/gen
+# github.com/plar/go-adaptive-radix-tree v1.0.5
+## explicit; go 1.12
+github.com/plar/go-adaptive-radix-tree
 # github.com/pmezard/go-difflib v1.0.0
 ## explicit
 github.com/pmezard/go-difflib/difflib