From 0a840de5b70ae8f9029ad163d13934b9285dda91 Mon Sep 17 00:00:00 2001
From: Stephan Renatus <stephan.renatus@gmail.com>
Date: Fri, 19 Nov 2021 10:05:17 +0100
Subject: [PATCH] ast/compile: replace rewritten declared vars for undefined
 func error (#4034)

Fixes #4031.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
---
 ast/check.go        | 14 +++++++++-----
 ast/compile.go      | 13 +++++--------
 ast/compile_test.go | 37 +++++++++++++++++++++++++++++++++++--
 3 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/ast/check.go b/ast/check.go
index 6145266599..1e0971bc89 100644
--- a/ast/check.go
+++ b/ast/check.go
@@ -13,7 +13,7 @@ import (
 	"github.com/open-policy-agent/opa/util"
 )
 
-type rewriteVars func(x Ref) Ref
+type varRewriter func(Ref) Ref
 
 // exprChecker defines the interface for executing type checking on a single
 // expression. The exprChecker must update the provided TypeEnv with inferred
@@ -26,7 +26,7 @@ type exprChecker func(*TypeEnv, *Expr) *Error
 type typeChecker struct {
 	errs         Errors
 	exprCheckers map[string]exprChecker
-	varRewriter  rewriteVars
+	varRewriter  varRewriter
 	ss           *SchemaSet
 	allowNet     []string
 	input        types.Type
@@ -70,7 +70,7 @@ func (tc *typeChecker) WithAllowNet(hosts []string) *typeChecker {
 	return tc
 }
 
-func (tc *typeChecker) WithVarRewriter(f rewriteVars) *typeChecker {
+func (tc *typeChecker) WithVarRewriter(f varRewriter) *typeChecker {
 	tc.varRewriter = f
 	return tc
 }
@@ -570,10 +570,14 @@ func (tc *typeChecker) err(errors []*Error) {
 type refChecker struct {
 	env         *TypeEnv
 	errs        Errors
-	varRewriter rewriteVars
+	varRewriter varRewriter
 }
 
-func newRefChecker(env *TypeEnv, f rewriteVars) *refChecker {
+func rewriteVarsNop(node Ref) Ref {
+	return node
+}
+
+func newRefChecker(env *TypeEnv, f varRewriter) *refChecker {
 
 	if f == nil {
 		f = rewriteVarsNop
diff --git a/ast/compile.go b/ast/compile.go
index 37d044d041..eddfbd21c2 100644
--- a/ast/compile.go
+++ b/ast/compile.go
@@ -835,13 +835,13 @@ func (c *Compiler) checkRuleConflicts() {
 func (c *Compiler) checkUndefinedFuncs() {
 	for _, name := range c.sorted {
 		m := c.Modules[name]
-		for _, err := range checkUndefinedFuncs(m, c.GetArity) {
+		for _, err := range checkUndefinedFuncs(m, c.GetArity, c.RewrittenVars) {
 			c.err(err)
 		}
 	}
 }
 
-func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors {
+func checkUndefinedFuncs(x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors {
 
 	var errs Errors
 
@@ -853,6 +853,7 @@ func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors {
 		if arity(ref) >= 0 {
 			return false
 		}
+		ref = rewriteVarsInRef(rwVars)(ref)
 		errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
 		return true
 	})
@@ -2024,7 +2025,7 @@ func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error
 }
 
 func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
-	if errs := checkUndefinedFuncs(body, qc.compiler.GetArity); len(errs) > 0 {
+	if errs := checkUndefinedFuncs(body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 {
 		return nil, errs
 	}
 	return body, nil
@@ -4362,7 +4363,7 @@ func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}
 	return errs
 }
 
-func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref {
+func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
 	return func(node Ref) Ref {
 		i, _ := TransformVars(node, func(v Var) (Value, error) {
 			for _, m := range vars {
@@ -4375,7 +4376,3 @@ func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref {
 		return i.(Ref)
 	}
 }
-
-func rewriteVarsNop(node Ref) Ref {
-	return node
-}
diff --git a/ast/compile_test.go b/ast/compile_test.go
index f4a583ed9c..a64176e5b4 100644
--- a/ast/compile_test.go
+++ b/ast/compile_test.go
@@ -1105,8 +1105,18 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) {
 			deadbeef(x)
 		}
 
+		# NOTE: all the dynamic dispatch examples here are not supported,
+		#       we're checking assertions about the error returned.
 		undefined_dynamic_dispatch {
-			x = "f"; data.test2[x](1)  # not currently supported
+			x = "f"; data.test2[x](1)
+		}
+
+		undefined_dynamic_dispatch_declared_var {
+			y := "f"; data.test2[y](1)
+		}
+
+		undefined_dynamic_dispatch_declared_var_in_array {
+			z := "f"; data.test2[[z]](1)
 		}
 	`
 
@@ -1129,8 +1139,9 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) {
 		"rego_type_error: undefined function data.deadbeef",
 		"rego_type_error: undefined function deadbeef",
 		"rego_type_error: undefined function data.test2[x]",
+		"rego_type_error: undefined function data.test2[y]",
+		"rego_type_error: undefined function data.test2[[z]]",
 	}
-
 	for _, w := range want {
 		if !strings.Contains(result, w) {
 			t.Fatalf("Expected %q in result but got: %v", w, result)
@@ -1138,6 +1149,28 @@ func TestCompilerCheckUndefinedFuncs(t *testing.T) {
 	}
 }
 
+func TestCompilerQueryCompilerCheckUndefinedFuncs(t *testing.T) {
+	compiler := NewCompiler()
+
+	for _, tc := range []struct {
+		note, query, err string
+	}{
+
+		{note: "undefined function", query: `data.foo(1)`, err: "undefined function data.foo"},
+		{note: "undefined global function", query: `foo(1)`, err: "undefined function foo"},
+		{note: "var", query: `x = "f"; data[x](1)`, err: "undefined function data[x]"},
+		{note: "declared var", query: `x := "f"; data[x](1)`, err: "undefined function data[x]"},
+		{note: "declared var in array", query: `x := "f"; data[[x]](1)`, err: "undefined function data[[x]]"},
+	} {
+		t.Run(tc.note, func(t *testing.T) {
+			_, err := compiler.QueryCompiler().Compile(MustParseBody(tc.query))
+			if !strings.Contains(err.Error(), tc.err) {
+				t.Errorf("Unexpected compilation error: %v (want  %s)", err, tc.err)
+			}
+		})
+	}
+}
+
 func TestCompilerImportsResolved(t *testing.T) {
 
 	modules := map[string]*Module{