Skip to content

Commit

Permalink
ast/compile: replace rewritten declared vars for undefined func error (
Browse files Browse the repository at this point in the history
…#4034)

Fixes #4031.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus authored Nov 19, 2021
1 parent 61f3bc3 commit 0a840de
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 15 deletions.
14 changes: 9 additions & 5 deletions ast/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/open-policy-agent/opa/util"
)

type rewriteVars func(x Ref) Ref
type varRewriter func(Ref) Ref

// exprChecker defines the interface for executing type checking on a single
// expression. The exprChecker must update the provided TypeEnv with inferred
Expand All @@ -26,7 +26,7 @@ type exprChecker func(*TypeEnv, *Expr) *Error
type typeChecker struct {
errs Errors
exprCheckers map[string]exprChecker
varRewriter rewriteVars
varRewriter varRewriter
ss *SchemaSet
allowNet []string
input types.Type
Expand Down Expand Up @@ -70,7 +70,7 @@ func (tc *typeChecker) WithAllowNet(hosts []string) *typeChecker {
return tc
}

func (tc *typeChecker) WithVarRewriter(f rewriteVars) *typeChecker {
func (tc *typeChecker) WithVarRewriter(f varRewriter) *typeChecker {
tc.varRewriter = f
return tc
}
Expand Down Expand Up @@ -570,10 +570,14 @@ func (tc *typeChecker) err(errors []*Error) {
type refChecker struct {
env *TypeEnv
errs Errors
varRewriter rewriteVars
varRewriter varRewriter
}

func newRefChecker(env *TypeEnv, f rewriteVars) *refChecker {
func rewriteVarsNop(node Ref) Ref {
return node
}

func newRefChecker(env *TypeEnv, f varRewriter) *refChecker {

if f == nil {
f = rewriteVarsNop
Expand Down
13 changes: 5 additions & 8 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -835,13 +835,13 @@ func (c *Compiler) checkRuleConflicts() {
func (c *Compiler) checkUndefinedFuncs() {
for _, name := range c.sorted {
m := c.Modules[name]
for _, err := range checkUndefinedFuncs(m, c.GetArity) {
for _, err := range checkUndefinedFuncs(m, c.GetArity, c.RewrittenVars) {
c.err(err)
}
}
}

func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors {
func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors {

var errs Errors

Expand All @@ -853,6 +853,7 @@ func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors {
if arity(ref) >= 0 {
return false
}
ref = rewriteVarsInRef(rwVars)(ref)
errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
return true
})
Expand Down Expand Up @@ -2024,7 +2025,7 @@ func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error
}

func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
if errs := checkUndefinedFuncs(body, qc.compiler.GetArity); len(errs) > 0 {
if errs := checkUndefinedFuncs(body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 {
return nil, errs
}
return body, nil
Expand Down Expand Up @@ -4362,7 +4363,7 @@ func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}
return errs
}

func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref {
func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
return func(node Ref) Ref {
i, _ := TransformVars(node, func(v Var) (Value, error) {
for _, m := range vars {
Expand All @@ -4375,7 +4376,3 @@ func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref {
return i.(Ref)
}
}

func rewriteVarsNop(node Ref) Ref {
return node
}
37 changes: 35 additions & 2 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1105,8 +1105,18 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) {
deadbeef(x)
}
# NOTE: all the dynamic dispatch examples here are not supported,
# we're checking assertions about the error returned.
undefined_dynamic_dispatch {
x = "f"; data.test2[x](1) # not currently supported
x = "f"; data.test2[x](1)
}
undefined_dynamic_dispatch_declared_var {
y := "f"; data.test2[y](1)
}
undefined_dynamic_dispatch_declared_var_in_array {
z := "f"; data.test2[[z]](1)
}
`

Expand All @@ -1129,15 +1139,38 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) {
"rego_type_error: undefined function data.deadbeef",
"rego_type_error: undefined function deadbeef",
"rego_type_error: undefined function data.test2[x]",
"rego_type_error: undefined function data.test2[y]",
"rego_type_error: undefined function data.test2[[z]]",
}

for _, w := range want {
if !strings.Contains(result, w) {
t.Fatalf("Expected %q in result but got: %v", w, result)
}
}
}

func TestCompilerQueryCompilerCheckUndefinedFuncs(t *testing.T) {
compiler := NewCompiler()

for _, tc := range []struct {
note, query, err string
}{

{note: "undefined function", query: `data.foo(1)`, err: "undefined function data.foo"},
{note: "undefined global function", query: `foo(1)`, err: "undefined function foo"},
{note: "var", query: `x = "f"; data[x](1)`, err: "undefined function data[x]"},
{note: "declared var", query: `x := "f"; data[x](1)`, err: "undefined function data[x]"},
{note: "declared var in array", query: `x := "f"; data[[x]](1)`, err: "undefined function data[[x]]"},
} {
t.Run(tc.note, func(t *testing.T) {
_, err := compiler.QueryCompiler().Compile(MustParseBody(tc.query))
if !strings.Contains(err.Error(), tc.err) {
t.Errorf("Unexpected compilation error: %v (want %s)", err, tc.err)
}
})
}
}

func TestCompilerImportsResolved(t *testing.T) {

modules := map[string]*Module{
Expand Down

0 comments on commit 0a840de

Please sign in to comment.