Skip to content

Commit

Permalink
ast: Fix rewriting vars in rule args
Browse files Browse the repository at this point in the history
Vars declared as rule args were not being rewritten which lead to
namespace conflicts with built-in functions. For example, definitions
like `f(object) { object[x] = 1 }` would generate type checking errors
because the checker would treat `object[x] = 1` as referring to the
`object` built-in function namespace.

This change updates the compiler to rewrite rule arguments just like
it does other declared variables.

Fixes open-policy-agent#2080

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
tsandall committed Feb 9, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d4c708c commit ceb5dc1
Showing 2 changed files with 125 additions and 19 deletions.
72 changes: 66 additions & 6 deletions ast/compile.go
Original file line number Diff line number Diff line change
@@ -1092,7 +1092,10 @@ func (c *Compiler) rewriteLocalVars() {
}

stack := newLocalDeclaredVars()
body, declared, errs := rewriteLocalVars(gen, stack, rule.Head.Args.Vars(), used, rule.Body)

c.rewriteLocalArgVars(gen, stack, rule)

body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body)
for _, err := range errs {
c.err(err)
}
@@ -1151,6 +1154,66 @@ func (c *Compiler) rewriteLocalVars() {
}
}

func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) {

vis := &ruleArgLocalRewriter{
stack: stack,
gen: gen,
}

for i := range rule.Head.Args {
Walk(vis, rule.Head.Args[i])
}

for i := range vis.errs {
c.err(vis.errs[i])
}
}

type ruleArgLocalRewriter struct {
stack *localDeclaredVars
gen *localVarGenerator
errs []*Error
}

func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {

t, ok := x.(*Term)
if !ok {
return vis
}

switch v := t.Value.(type) {
case Var:
gv, ok := vis.stack.Declared(v)
if !ok {
gv = vis.gen.Generate()
vis.stack.Insert(v, gv, argVar)
}
t.Value = gv
return nil
case Object:
if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) {
vcpy := v.Copy()
Walk(vis, vcpy)
return k, vcpy, nil
}); err != nil {
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error()))
} else {
t.Value = cpy
}
return nil
case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set:
// Scalars are no-ops. Comprehensions are handled above. Sets must not
// contain variables.
return nil
default:
// Recurse on refs, arrays, and calls. Any embedded
// variables can be rewritten.
return vis
}
}

func (c *Compiler) rewriteWithModifiers() {
f := newEqualityFactory(c.localvargen)
for _, name := range c.sorted {
@@ -1333,7 +1396,7 @@ func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, err
func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) {
gen := newLocalVarGenerator("q", body)
stack := newLocalDeclaredVars()
body, _, err := rewriteLocalVars(gen, stack, nil, nil, body)
body, _, err := rewriteLocalVars(gen, stack, nil, body)
if len(err) != 0 {
return nil, err
}
@@ -2910,10 +2973,7 @@ func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
// __local0__ = 1; p[__local0__]
//
// During rewriting, assignees are validated to prevent use before declaration.
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, args VarSet, used VarSet, body Body) (Body, map[Var]Var, Errors) {
for v := range args {
stack.Insert(v, v, argVar)
}
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body) (Body, map[Var]Var, Errors) {
var errs Errors
body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs)
return body, stack.Pop().vs, errs
72 changes: 59 additions & 13 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
@@ -956,17 +956,17 @@ func TestCompilerRewriteExprTerms(t *testing.T) {
expected := MustParseModule(`
package test
p { mul(b, y, __local0__); plus(a, __local0__, __local1__); eq(x, __local1__) }
p { mul(b, y, __local1__); plus(a, __local1__, __local2__); eq(x, __local2__) }
q[[__local2__]] { x = 1; data.test.f(x, __local2__) }
q[[__local3__]] { x = 1; data.test.f(x, __local3__) }
r = [__local3__] { x = 1; data.test.f(x, __local3__) }
r = [__local4__] { x = 1; data.test.f(x, __local4__) }
f(x) = __local4__ { true; data.test.g(x, __local4__) }
f(__local0__) = __local5__ { true; data.test.g(__local0__, __local5__) }
pi = __local5__ { true; plus(3, 0.14, __local5__) }
pi = __local6__ { true; plus(3, 0.14, __local6__) }
with_value { data.test.f(1, __local6__); 1 with input as __local6__ }
with_value { data.test.f(1, __local7__); 1 with input as __local7__ }
`)

if !expected.Equal(compiler.Modules["test"]) {
@@ -1233,7 +1233,7 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) {

tests := []struct {
module string
exp string
exp interface{}
expRewrittenMap map[Var]Var
}{
{
@@ -1256,10 +1256,11 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) {
`,
exp: `
package test
head_vars(a) = __local0__ { __local0__ = a }
head_vars(__local0__) = __local1__ { __local1__ = __local0__ }
`,
expRewrittenMap: map[Var]Var{
Var("__local0__"): Var("b"),
Var("__local0__"): Var("a"),
Var("__local1__"): Var("b"),
},
},
{
@@ -1580,16 +1581,53 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) {
y := 4
}
`,
// Each "else" rule has a separate rule head and the vars in the
// args will be rewritten. Since we cannot currently redefine the
// args, we must parse the module and then manually update the args.
exp: func() *Module {
module := MustParseModule(`
package test
f(__local0__) = __local1__ { __local0__ == 1; __local1__ = 2 } else = __local3__ { __local2__ == 3; __local3__ = 4 }
`)
module.Rules[0].Else.Head.Args[0].Value = Var("__local2__")
return module
},
expRewrittenMap: map[Var]Var{
Var("__local0__"): Var("x"),
Var("__local1__"): Var("y"),
Var("__local2__"): Var("x"),
Var("__local3__"): Var("y"),
},
},
{
module: `
package test
f({"x": [x]}) = y { x == 1; y := 2 }`,
exp: `
package test
f(x) = __local0__ { x == 1; __local0__ = 2 } else = __local1__ { x == 3; __local1__ = 4 }
`,
f({"x": [__local0__]}) = __local1__ { __local0__ == 1; __local1__ = 2 }`,
expRewrittenMap: map[Var]Var{
Var("__local0__"): Var("y"),
Var("__local0__"): Var("x"),
Var("__local1__"): Var("y"),
},
},
{
module: `
package test
f(x, [x]) = x { x == 1 }
`,
exp: `
package test
f(__local0__, [__local0__]) = __local0__ { __local0__ == 1 }
`,
expRewrittenMap: map[Var]Var{
Var("__local0__"): Var("x"),
},
},
}

for i, tc := range tests {
@@ -1601,7 +1639,15 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) {
compileStages(c, c.rewriteLocalVars)
assertNotFailed(t, c)
result := c.Modules["test.rego"]
exp := MustParseModule(tc.exp)
var exp *Module
switch e := tc.exp.(type) {
case string:
exp = MustParseModule(e)
case func() *Module:
exp = e()
default:
panic("expected value must be string or func() *Module")
}
if result.Compare(exp) != 0 {
t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result)
}

0 comments on commit ceb5dc1

Please sign in to comment.