diff --git a/ast/compile.go b/ast/compile.go index 46954f7037..acf688a157 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -139,6 +139,7 @@ func NewCompiler() *Compiler { stage{c.setRuleTree, "setRuleTree"}, stage{c.setRuleGraph, "setRuleGraph"}, stage{c.rewriteRefsInHead, "rewriteRefsInHead"}, + stage{c.checkRuleConflicts, "checkRuleConflicts"}, stage{c.checkBuiltins, "checkBuiltins"}, stage{c.checkSafetyRuleHeads, "checkSafetyRuleHeads"}, stage{c.checkSafetyRuleBodies, "checkSafetyRuleBodies"}, @@ -297,6 +298,27 @@ func (c *Compiler) checkRecursion() { } } +// checkRuleConflicts ensures that rules definitions are not in conflict. +func (c *Compiler) checkRuleConflicts() { + c.RuleTree.DepthFirst(func(node *RuleTreeNode) bool { + if len(node.Rules) == 0 { + return false + } + + kinds := map[DocKind]struct{}{} + for _, rule := range node.Rules { + kinds[rule.DocKind()] = struct{}{} + } + + if len(kinds) > 1 { + name := Var(node.Key.(String)) + c.err(NewError(CompileErr, node.Rules[0].Loc(), "%v: conflicting rule types (all definitions of %v must have the same type)", name, name)) + } + + return false + }) +} + // checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target // positions of built-in expressions will be bound when evaluating the rule from left // to right, re-ordering as necessary. @@ -651,6 +673,16 @@ func (n *RuleTreeNode) Size() int { return s } +// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If +// f returns true, traversal will not continue to the children of n. +func (n *RuleTreeNode) DepthFirst(f func(node *RuleTreeNode) bool) { + if !f(n) { + for _, node := range n.Children { + node.DepthFirst(f) + } + } +} + // builtinChecker verifies that built-in functions are called correctly. type builtinChecker struct { errors *Errors diff --git a/ast/compile_test.go b/ast/compile_test.go index 713b60863b..f7e083ff76 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -379,7 +379,7 @@ func TestCompilerCheckSafetyBodyErrors(t *testing.T) { } -func TestCompilerBuiltins(t *testing.T) { +func TestCompilerCheckBuiltins(t *testing.T) { c := NewCompiler() c.Modules = map[string]*Module{ "mod": MustParseModule(` @@ -390,20 +390,35 @@ func TestCompilerBuiltins(t *testing.T) { `), } compileStages(c, "", "checkBuiltins") - result := compilerErrsToStringSlice(c.Errors) + expected := []string{ "p: wrong number of arguments (expression count(1) must specify 2 arguments to built-in function count)", "q: wrong number of arguments (expression count([1,2,3], x, 1) must specify 2 arguments to built-in function count)", "r: unknown built-in function deadbeef", } - if len(result) != len(expected) { - t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n")) + + assertCompilerErrorStrings(t, c, expected) +} + +func TestCompilerCheckRuleConflicts(t *testing.T) { + c := NewCompiler() + c.Modules = map[string]*Module{ + "mod": MustParseModule(` + package badrules + p[x] :- x=1 + p[x] = y :- x = y, x = 1 + q[1] :- true + q = {1,2,3} :- true + `), } - for i := range result { - if expected[i] != result[i] { - t.Errorf("Expected %v but got: %v", expected[i], result[i]) - } + compileStages(c, "", "checkRuleConflicts") + + expected := []string{ + "p: conflicting rule types (all definitions of p must have the same type)", + "q: conflicting rule types (all definitions of q must have the same type)", } + + assertCompilerErrorStrings(t, c, expected) } func TestCompilerResolveAllRefs(t *testing.T) { @@ -754,7 +769,7 @@ func TestCompilerGetRulesWithPrefix(t *testing.T) { package a.b.c p[1] :- true p[2] :- true - q = true + q[3] :- true `) c := NewCompiler() @@ -834,6 +849,18 @@ func TestQueryCompiler(t *testing.T) { } } +func assertCompilerErrorStrings(t *testing.T, compiler *Compiler, expected []string) { + result := compilerErrsToStringSlice(compiler.Errors) + if len(result) != len(expected) { + t.Fatalf("Expected %d:\n%v\nBut got %d:\n%v", len(expected), strings.Join(expected, "\n"), len(result), strings.Join(result, "\n")) + } + for i := range result { + if expected[i] != result[i] { + t.Errorf("Expected %v but got: %v", expected[i], result[i]) + } + } +} + func assertNotFailed(t *testing.T, c *Compiler) { if c.Failed() { t.Errorf("Unexpected compilation error: %v", c.Errors)