diff --git a/ast/check.go b/ast/check.go index 6145266599..1e0971bc89 100644 --- a/ast/check.go +++ b/ast/check.go @@ -13,7 +13,7 @@ import ( "github.com/open-policy-agent/opa/util" ) -type rewriteVars func(x Ref) Ref +type varRewriter func(Ref) Ref // exprChecker defines the interface for executing type checking on a single // expression. The exprChecker must update the provided TypeEnv with inferred @@ -26,7 +26,7 @@ type exprChecker func(*TypeEnv, *Expr) *Error type typeChecker struct { errs Errors exprCheckers map[string]exprChecker - varRewriter rewriteVars + varRewriter varRewriter ss *SchemaSet allowNet []string input types.Type @@ -70,7 +70,7 @@ func (tc *typeChecker) WithAllowNet(hosts []string) *typeChecker { return tc } -func (tc *typeChecker) WithVarRewriter(f rewriteVars) *typeChecker { +func (tc *typeChecker) WithVarRewriter(f varRewriter) *typeChecker { tc.varRewriter = f return tc } @@ -570,10 +570,14 @@ func (tc *typeChecker) err(errors []*Error) { type refChecker struct { env *TypeEnv errs Errors - varRewriter rewriteVars + varRewriter varRewriter } -func newRefChecker(env *TypeEnv, f rewriteVars) *refChecker { +func rewriteVarsNop(node Ref) Ref { + return node +} + +func newRefChecker(env *TypeEnv, f varRewriter) *refChecker { if f == nil { f = rewriteVarsNop diff --git a/ast/compile.go b/ast/compile.go index 37d044d041..eddfbd21c2 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -835,13 +835,13 @@ func (c *Compiler) checkRuleConflicts() { func (c *Compiler) checkUndefinedFuncs() { for _, name := range c.sorted { m := c.Modules[name] - for _, err := range checkUndefinedFuncs(m, c.GetArity) { + for _, err := range checkUndefinedFuncs(m, c.GetArity, c.RewrittenVars) { c.err(err) } } } -func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors { +func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors { var errs Errors @@ -853,6 +853,7 @@ func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors { if arity(ref) >= 0 { return false } + ref = rewriteVarsInRef(rwVars)(ref) errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref)) return true }) @@ -2024,7 +2025,7 @@ func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error } func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) { - if errs := checkUndefinedFuncs(body, qc.compiler.GetArity); len(errs) > 0 { + if errs := checkUndefinedFuncs(body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 { return nil, errs } return body, nil @@ -4362,7 +4363,7 @@ func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{} return errs } -func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref { +func rewriteVarsInRef(vars ...map[Var]Var) varRewriter { return func(node Ref) Ref { i, _ := TransformVars(node, func(v Var) (Value, error) { for _, m := range vars { @@ -4375,7 +4376,3 @@ func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref { return i.(Ref) } } - -func rewriteVarsNop(node Ref) Ref { - return node -} diff --git a/ast/compile_test.go b/ast/compile_test.go index f4a583ed9c..a64176e5b4 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -1105,8 +1105,18 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) { deadbeef(x) } + # NOTE: all the dynamic dispatch examples here are not supported, + # we're checking assertions about the error returned. undefined_dynamic_dispatch { - x = "f"; data.test2[x](1) # not currently supported + x = "f"; data.test2[x](1) + } + + undefined_dynamic_dispatch_declared_var { + y := "f"; data.test2[y](1) + } + + undefined_dynamic_dispatch_declared_var_in_array { + z := "f"; data.test2[[z]](1) } ` @@ -1129,8 +1139,9 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) { "rego_type_error: undefined function data.deadbeef", "rego_type_error: undefined function deadbeef", "rego_type_error: undefined function data.test2[x]", + "rego_type_error: undefined function data.test2[y]", + "rego_type_error: undefined function data.test2[[z]]", } - for _, w := range want { if !strings.Contains(result, w) { t.Fatalf("Expected %q in result but got: %v", w, result) @@ -1138,6 +1149,28 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) { } } +func TestCompilerQueryCompilerCheckUndefinedFuncs(t *testing.T) { + compiler := NewCompiler() + + for _, tc := range []struct { + note, query, err string + }{ + + {note: "undefined function", query: `data.foo(1)`, err: "undefined function data.foo"}, + {note: "undefined global function", query: `foo(1)`, err: "undefined function foo"}, + {note: "var", query: `x = "f"; data[x](1)`, err: "undefined function data[x]"}, + {note: "declared var", query: `x := "f"; data[x](1)`, err: "undefined function data[x]"}, + {note: "declared var in array", query: `x := "f"; data[[x]](1)`, err: "undefined function data[[x]]"}, + } { + t.Run(tc.note, func(t *testing.T) { + _, err := compiler.QueryCompiler().Compile(MustParseBody(tc.query)) + if !strings.Contains(err.Error(), tc.err) { + t.Errorf("Unexpected compilation error: %v (want %s)", err, tc.err) + } + }) + } +} + func TestCompilerImportsResolved(t *testing.T) { modules := map[string]*Module{