Skip to content

Commit

Permalink
Cost tracking for two-variable comprehensions and bindings (#1104)
Browse files Browse the repository at this point in the history
* Updates to the cost estimators to support bind and two-var comprehensions
* Consolidation of local variables
  • Loading branch information
TristonianJones authored Jan 23, 2025
1 parent 7621362 commit 2f7606a
Show file tree
Hide file tree
Showing 8 changed files with 942 additions and 323 deletions.
649 changes: 490 additions & 159 deletions checker/cost.go

Large diffs are not rendered by default.

91 changes: 76 additions & 15 deletions checker/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestCost(t *testing.T) {
nestedMap := types.NewMapType(types.StringType, allMap)

zeroCost := CostEstimate{}
oneCost := CostEstimate{Min: 1, Max: 1}
oneCost := FixedCostEstimate(1)
cases := []struct {
name string
expr string
Expand Down Expand Up @@ -255,6 +255,11 @@ func TestCost(t *testing.T) {
expr: `size("123")`,
wanted: oneCost,
},
{
name: "bytes size",
expr: `size(b"123")`,
wanted: oneCost,
},
{
name: "bytes to string conversion",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)},
Expand Down Expand Up @@ -462,6 +467,36 @@ func TestCost(t *testing.T) {
},
wanted: CostEstimate{Min: 5, Max: 5},
},
{
name: "list size from concat",
expr: `([x, y] + list1 + list2).size()`,
vars: []*decls.VariableDecl{
decls.NewVariable("x", types.IntType),
decls.NewVariable("y", types.IntType),
decls.NewVariable("list1", types.NewListType(types.IntType)),
decls.NewVariable("list2", types.NewListType(types.IntType)),
},
hints: map[string]uint64{
"list1": 10,
"list2": 20,
},
wanted: CostEstimate{Min: 17, Max: 17},
},
{
name: "list cost tracking through comprehension",
expr: `[list1, list2].exists(l, l.exists(v, v.startsWith('hi')))`,
vars: []*decls.VariableDecl{
decls.NewVariable("list1", types.NewListType(types.StringType)),
decls.NewVariable("list2", types.NewListType(types.StringType)),
},
hints: map[string]uint64{
"list1": 10,
"list1.@items": 64,
"list2": 20,
"list2.@items": 128,
},
wanted: CostEstimate{Min: 21, Max: 265},
},
{
name: "str endsWith equality",
expr: `str1.endsWith("abcdefghijklmnopqrstuvwxyz") == str2.endsWith("abcdefghijklmnopqrstuvwxyz")`,
Expand Down Expand Up @@ -539,27 +574,37 @@ func TestCost(t *testing.T) {
wanted: CostEstimate{Min: 61, Max: 61},
},
{
name: "nested array selection",
name: "nested map selection",
expr: `{'a': [1,2], 'b': [1,2], 'c': [1,2], 'd': [1,2], 'e': [1,2]}.b`,
wanted: CostEstimate{Min: 81, Max: 81},
},
{
// Estimated cost does not track the sizes of nested aggregate types
// (lists, maps, ...) and so assumes a worst case cost when an
// expression applies a comprehension to a nested aggregated type,
// even if the size information is available.
// TODO: This should be fixed.
name: "comprehension on nested list",
expr: `[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]].all(y, y.all(y, y == 1))`,
wanted: CostEstimate{Min: 76, Max: 136},
},
{
name: "comprehension on transformed nested list",
expr: `[1,2,3,4,5].map(x, [x, x]).all(y, y.all(y, y == 1))`,
wanted: CostEstimate{Min: 157, Max: 18446744073709551615},
wanted: CostEstimate{Min: 157, Max: 217},
},
{
// Make sure we're accounting for not just the iteration range size,
// but also the overall comprehension size. The chained map calls
// will treat the result of one map as the iteration range of the other,
// so they're planned in reverse; however, the `+` should verify that
// the comprehension result has a size.
name: "comprehension size",
name: "comprehension on nested literal list",
expr: `["a", "ab", "abc", "abcd", "abcde"].map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`,
wanted: CostEstimate{Min: 157, Max: 217},
},
{
name: "comprehension on nested variable list",
expr: `input.map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`,
vars: []*decls.VariableDecl{decls.NewVariable("input", types.NewListType(types.StringType))},
hints: map[string]uint64{
"input": 5,
"input.@items": 10,
},
wanted: CostEstimate{Min: 13, Max: 208},
},
{
name: "comprehension chaining with concat",
expr: `[1,2,3,4,5].map(x, x).map(x, x) + [1]`,
wanted: CostEstimate{Min: 173, Max: 173},
},
Expand All @@ -568,9 +613,25 @@ func TestCost(t *testing.T) {
expr: `[1,2,3].all(i, i in [1,2,3].map(j, j + j))`,
wanted: CostEstimate{Min: 20, Max: 230},
},
{
name: "nested dyn comprehension",
expr: `dyn([1,2,3]).all(i, i in dyn([1,2,3]).map(j, j + j))`,
wanted: CostEstimate{Min: 21, Max: 234},
},
{
name: "literal map access",
expr: `{'hello': 'hi'}['hello'] != {'hello': 'bye'}['hello']`,
wanted: CostEstimate{Min: 63, Max: 63},
},
{
name: "literal list access",
expr: `['hello', 'hi'][0] != ['hello', 'bye'][1]`,
wanted: CostEstimate{Min: 23, Max: 23},
},
}

for _, tc := range cases {
for _, tst := range cases {
tc := tst
t.Run(tc.name, func(t *testing.T) {
if tc.hints == nil {
tc.hints = map[string]uint64{}
Expand Down
105 changes: 75 additions & 30 deletions ext/bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,56 +20,101 @@ import (
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)

var bindingTests = []struct {
expr string
parseOnly bool
name string
expr string
vars []cel.EnvOption
in map[string]any
hints map[string]uint64
estimatedCost checker.CostEstimate
actualCost uint64
}{
{expr: `cel.bind(a, 'hell' + 'o' + '!', [a, a, a].join(', ')) ==
['hell' + 'o' + '!', 'hell' + 'o' + '!', 'hell' + 'o' + '!'].join(', ')`},
// Variable shadowing
{expr: `cel.bind(a,
cel.bind(a, 'world', a + '!'),
'hello ' + a) == 'hello ' + 'world' + '!'`},
{
name: "single bind",
expr: `cel.bind(a, 'hell' + 'o' + '!', "%s, %s, %s".format([a, a, a])) ==
'hello!, hello!, hello' + '!'`,
estimatedCost: checker.CostEstimate{Min: 30, Max: 32},
actualCost: 32,
},
{
name: "multiple binds",
expr: `cel.bind(a, 'hello!',
cel.bind(b, 'goodbye',
a + ' and, ' + b)) == 'hello! and, goodbye'`,
estimatedCost: checker.CostEstimate{Min: 27, Max: 28},
actualCost: 28,
},
{
name: "shadow binds",
expr: `cel.bind(a,
cel.bind(a, 'world', a + '!'),
'hello ' + a) == 'hello ' + 'world' + '!'`,
estimatedCost: checker.CostEstimate{Min: 30, Max: 31},
actualCost: 31,
},
{
name: "nested bind with int list",
expr: `cel.bind(a, x,
cel.bind(b, a[0],
cel.bind(c, a[1], b + c))) == 10`,
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))},
in: map[string]any{
"x": []int64{3, 7},
},
hints: map[string]uint64{
"x": 3,
},
estimatedCost: checker.CostEstimate{Min: 39, Max: 39},
actualCost: 39,
},
{
name: "nested bind with string list",
expr: `cel.bind(a, x,
cel.bind(b, a[0],
cel.bind(c, a[1], b + c))) == "threeseven"`,
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.StringType))},
in: map[string]any{
"x": []string{"three", "seven"},
},
hints: map[string]uint64{
"x": 3,
"x.@items": 10,
},
estimatedCost: checker.CostEstimate{Min: 38, Max: 40},
actualCost: 39,
},
}

func TestBindings(t *testing.T) {
env, err := cel.NewEnv(Bindings(BindingsVersion(0)), Strings())
if err != nil {
t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err)
}
for i, tst := range bindingTests {
for _, tst := range bindingTests {
tc := tst
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
var asts []*cel.Ast
opts := append([]cel.EnvOption{Bindings(BindingsVersion(0)), Strings()}, tc.vars...)
env, err := cel.NewEnv(opts...)
if err != nil {
t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err)
}
pAst, iss := env.Parse(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err())
}
asts = append(asts, pAst)
if !tc.parseOnly {
cAst, iss := env.Check(pAst)
if iss.Err() != nil {
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
}
asts = append(asts, cAst)
cAst, iss := env.Check(pAst)
if iss.Err() != nil {
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
}
testCheckCost(t, env, cAst, tc.hints, tc.estimatedCost)
asts = append(asts, cAst)
for _, ast := range asts {
prg, err := env.Program(ast)
if err != nil {
t.Fatal(err)
}
out, _, err := prg.Eval(cel.NoVars())
if err != nil {
t.Fatal(err)
} else if out.Value() != true {
t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr)
}
testEvalWithCost(t, env, ast, tc.in, tc.actualCost)
}
})
}
Expand Down
Loading

0 comments on commit 2f7606a

Please sign in to comment.