Skip to content

Commit

Permalink
Add compiler check for consistent rule types
Browse files Browse the repository at this point in the history
Previously, rules could be defined multiple times with different types. This
was not handled in topdown and would result in a panic.

Fixes open-policy-agent#147
  • Loading branch information
tsandall committed Nov 24, 2016
1 parent 9728d35 commit 4aae225
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
32 changes: 32 additions & 0 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func NewCompiler() *Compiler {
stage{c.setRuleTree, "setRuleTree"},
stage{c.setRuleGraph, "setRuleGraph"},
stage{c.rewriteRefsInHead, "rewriteRefsInHead"},
stage{c.checkRuleConflicts, "checkRuleConflicts"},
stage{c.checkBuiltins, "checkBuiltins"},
stage{c.checkSafetyRuleHeads, "checkSafetyRuleHeads"},
stage{c.checkSafetyRuleBodies, "checkSafetyRuleBodies"},
Expand Down Expand Up @@ -297,6 +298,27 @@ func (c *Compiler) checkRecursion() {
}
}

// checkRuleConflicts ensures that rules definitions are not in conflict.
func (c *Compiler) checkRuleConflicts() {
c.RuleTree.DepthFirst(func(node *RuleTreeNode) bool {
if len(node.Rules) == 0 {
return false
}

kinds := map[DocKind]struct{}{}
for _, rule := range node.Rules {
kinds[rule.DocKind()] = struct{}{}
}

if len(kinds) > 1 {
name := Var(node.Key.(String))
c.err(NewError(CompileErr, node.Rules[0].Loc(), "%v: conflicting rule types (all definitions of %v must have the same type)", name, name))
}

return false
})
}

// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
// positions of built-in expressions will be bound when evaluating the rule from left
// to right, re-ordering as necessary.
Expand Down Expand Up @@ -651,6 +673,16 @@ func (n *RuleTreeNode) Size() int {
return s
}

// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
// f returns true, traversal will not continue to the children of n.
func (n *RuleTreeNode) DepthFirst(f func(node *RuleTreeNode) bool) {
if !f(n) {
for _, node := range n.Children {
node.DepthFirst(f)
}
}
}

// builtinChecker verifies that built-in functions are called correctly.
type builtinChecker struct {
errors *Errors
Expand Down
45 changes: 36 additions & 9 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) {

}

func TestCompilerBuiltins(t *testing.T) {
func TestCompilerCheckBuiltins(t *testing.T) {
c := NewCompiler()
c.Modules = map[string]*Module{
"mod": MustParseModule(`
Expand All @@ -390,20 +390,35 @@ func TestCompilerBuiltins(t *testing.T) {
`),
}
compileStages(c, "", "checkBuiltins")
result := compilerErrsToStringSlice(c.Errors)

expected := []string{
"p: wrong number of arguments (expression count(1) must specify 2 arguments to built-in function count)",
"q: wrong number of arguments (expression count([1,2,3], x, 1) must specify 2 arguments to built-in function count)",
"r: unknown built-in function deadbeef",
}
if len(result) != len(expected) {
t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n"))

assertCompilerErrorStrings(t, c, expected)
}

func TestCompilerCheckRuleConflicts(t *testing.T) {
c := NewCompiler()
c.Modules = map[string]*Module{
"mod": MustParseModule(`
package badrules
p[x] :- x=1
p[x] = y :- x = y, x = 1
q[1] :- true
q = {1,2,3} :- true
`),
}
for i := range result {
if expected[i] != result[i] {
t.Errorf("Expected %v but got: %v", expected[i], result[i])
}
compileStages(c, "", "checkRuleConflicts")

expected := []string{
"p: conflicting rule types (all definitions of p must have the same type)",
"q: conflicting rule types (all definitions of q must have the same type)",
}

assertCompilerErrorStrings(t, c, expected)
}

func TestCompilerResolveAllRefs(t *testing.T) {
Expand Down Expand Up @@ -754,7 +769,7 @@ func TestCompilerGetRulesWithPrefix(t *testing.T) {
package a.b.c
p[1] :- true
p[2] :- true
q = true
q[3] :- true
`)

c := NewCompiler()
Expand Down Expand Up @@ -834,6 +849,18 @@ func TestQueryCompiler(t *testing.T) {
}
}

func assertCompilerErrorStrings(t *testing.T, compiler *Compiler, expected []string) {
result := compilerErrsToStringSlice(compiler.Errors)
if len(result) != len(expected) {
t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n"))
}
for i := range result {
if expected[i] != result[i] {
t.Errorf("Expected %v but got: %v", expected[i], result[i])
}
}
}

func assertNotFailed(t *testing.T, c *Compiler) {
if c.Failed() {
t.Errorf("Unexpected compilation error: %v", c.Errors)
Expand Down

0 comments on commit 4aae225

Please sign in to comment.