From ceb5dc12b37c65f117070078f9ef32a7e7e9551e Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Sun, 9 Feb 2020 08:57:04 -0500 Subject: [PATCH] ast: Fix rewriting vars in rule args Vars declared as rule args were not being rewritten which lead to namespace conflicts with built-in functions. For example, definitions like `f(object) { object[x] = 1 }` would generate type checking errors because the checker would treat `object[x] = 1` as referring to the `object` built-in function namespace. This change updates the compiler to rewrite rule arguments just like it does other declared variables. Fixes #2080 Signed-off-by: Torin Sandall --- ast/compile.go | 72 +++++++++++++++++++++++++++++++++++++++++---- ast/compile_test.go | 72 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 125 insertions(+), 19 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index ae43899d14..56edbe52c3 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -1092,7 +1092,10 @@ func (c *Compiler) rewriteLocalVars() { } stack := newLocalDeclaredVars() - body, declared, errs := rewriteLocalVars(gen, stack, rule.Head.Args.Vars(), used, rule.Body) + + c.rewriteLocalArgVars(gen, stack, rule) + + body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body) for _, err := range errs { c.err(err) } @@ -1151,6 +1154,66 @@ func (c *Compiler) rewriteLocalVars() { } } +func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) { + + vis := &ruleArgLocalRewriter{ + stack: stack, + gen: gen, + } + + for i := range rule.Head.Args { + Walk(vis, rule.Head.Args[i]) + } + + for i := range vis.errs { + c.err(vis.errs[i]) + } +} + +type ruleArgLocalRewriter struct { + stack *localDeclaredVars + gen *localVarGenerator + errs []*Error +} + +func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor { + + t, ok := x.(*Term) + if !ok { + return vis + } + + switch v := t.Value.(type) { + case Var: + gv, ok := vis.stack.Declared(v) + if !ok { + gv = vis.gen.Generate() + vis.stack.Insert(v, gv, argVar) + } + t.Value = gv + return nil + case Object: + if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) { + vcpy := v.Copy() + Walk(vis, vcpy) + return k, vcpy, nil + }); err != nil { + vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error())) + } else { + t.Value = cpy + } + return nil + case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set: + // Scalars are no-ops. Comprehensions are handled above. Sets must not + // contain variables. + return nil + default: + // Recurse on refs, arrays, and calls. Any embedded + // variables can be rewritten. + return vis + } +} + func (c *Compiler) rewriteWithModifiers() { f := newEqualityFactory(c.localvargen) for _, name := range c.sorted { @@ -1333,7 +1396,7 @@ func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, err func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) { gen := newLocalVarGenerator("q", body) stack := newLocalDeclaredVars() - body, _, err := rewriteLocalVars(gen, stack, nil, nil, body) + body, _, err := rewriteLocalVars(gen, stack, nil, body) if len(err) != 0 { return nil, err } @@ -2910,10 +2973,7 @@ func (s localDeclaredVars) Occurrence(x Var) varOccurrence { // __local0__ = 1; p[__local0__] // // During rewriting, assignees are validated to prevent use before declaration. -func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, args VarSet, used VarSet, body Body) (Body, map[Var]Var, Errors) { - for v := range args { - stack.Insert(v, v, argVar) - } +func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body) (Body, map[Var]Var, Errors) { var errs Errors body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs) return body, stack.Pop().vs, errs diff --git a/ast/compile_test.go b/ast/compile_test.go index 7335795eda..9be1760434 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -956,17 +956,17 @@ func TestCompilerRewriteExprTerms(t *testing.T) { expected := MustParseModule(` package test - p { mul(b, y, __local0__); plus(a, __local0__, __local1__); eq(x, __local1__) } + p { mul(b, y, __local1__); plus(a, __local1__, __local2__); eq(x, __local2__) } - q[[__local2__]] { x = 1; data.test.f(x, __local2__) } + q[[__local3__]] { x = 1; data.test.f(x, __local3__) } - r = [__local3__] { x = 1; data.test.f(x, __local3__) } + r = [__local4__] { x = 1; data.test.f(x, __local4__) } - f(x) = __local4__ { true; data.test.g(x, __local4__) } + f(__local0__) = __local5__ { true; data.test.g(__local0__, __local5__) } - pi = __local5__ { true; plus(3, 0.14, __local5__) } + pi = __local6__ { true; plus(3, 0.14, __local6__) } - with_value { data.test.f(1, __local6__); 1 with input as __local6__ } + with_value { data.test.f(1, __local7__); 1 with input as __local7__ } `) if !expected.Equal(compiler.Modules["test"]) { @@ -1233,7 +1233,7 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { tests := []struct { module string - exp string + exp interface{} expRewrittenMap map[Var]Var }{ { @@ -1256,10 +1256,11 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { `, exp: ` package test - head_vars(a) = __local0__ { __local0__ = a } + head_vars(__local0__) = __local1__ { __local1__ = __local0__ } `, expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("b"), + Var("__local0__"): Var("a"), + Var("__local1__"): Var("b"), }, }, { @@ -1580,16 +1581,53 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { y := 4 } `, + // Each "else" rule has a separate rule head and the vars in the + // args will be rewritten. Since we cannot currently redefine the + // args, we must parse the module and then manually update the args. + exp: func() *Module { + module := MustParseModule(` + package test + + f(__local0__) = __local1__ { __local0__ == 1; __local1__ = 2 } else = __local3__ { __local2__ == 3; __local3__ = 4 } + `) + module.Rules[0].Else.Head.Args[0].Value = Var("__local2__") + return module + }, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + Var("__local1__"): Var("y"), + Var("__local2__"): Var("x"), + Var("__local3__"): Var("y"), + }, + }, + { + module: ` + package test + f({"x": [x]}) = y { x == 1; y := 2 }`, exp: ` package test - f(x) = __local0__ { x == 1; __local0__ = 2 } else = __local1__ { x == 3; __local1__ = 4 } - `, + f({"x": [__local0__]}) = __local1__ { __local0__ == 1; __local1__ = 2 }`, expRewrittenMap: map[Var]Var{ - Var("__local0__"): Var("y"), + Var("__local0__"): Var("x"), Var("__local1__"): Var("y"), }, }, + { + module: ` + package test + + f(x, [x]) = x { x == 1 } + `, + exp: ` + package test + + f(__local0__, [__local0__]) = __local0__ { __local0__ == 1 } + `, + expRewrittenMap: map[Var]Var{ + Var("__local0__"): Var("x"), + }, + }, } for i, tc := range tests { @@ -1601,7 +1639,15 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { compileStages(c, c.rewriteLocalVars) assertNotFailed(t, c) result := c.Modules["test.rego"] - exp := MustParseModule(tc.exp) + var exp *Module + switch e := tc.exp.(type) { + case string: + exp = MustParseModule(e) + case func() *Module: + exp = e() + default: + panic("expected value must be string or func() *Module") + } if result.Compare(exp) != 0 { t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result) }