diff --git a/constraint/pkg/client/drivers/local/args.go b/constraint/pkg/client/drivers/local/args.go index 85b0d80b1..159d4dbeb 100644 --- a/constraint/pkg/client/drivers/local/args.go +++ b/constraint/pkg/client/drivers/local/args.go @@ -5,6 +5,7 @@ import ( "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/inmem" + "github.com/open-policy-agent/opa/topdown/print" opatypes "github.com/open-policy-agent/opa/types" ) @@ -45,6 +46,18 @@ func Tracing(enabled bool) Arg { } } +func PrintEnabled(enabled bool) Arg { + return func(d *driver) { + d.printEnabled = enabled + } +} + +func PrintHook(hook print.Hook) Arg { + return func(d *driver) { + d.printHook = hook + } +} + func Modules(modules map[string]*ast.Module) Arg { return func(d *driver) { d.modules = modules diff --git a/constraint/pkg/client/drivers/local/local.go b/constraint/pkg/client/drivers/local/local.go index da4a7d3af..d4dd40b3c 100644 --- a/constraint/pkg/client/drivers/local/local.go +++ b/constraint/pkg/client/drivers/local/local.go @@ -19,6 +19,7 @@ import ( "github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/topdown/print" opatypes "github.com/open-policy-agent/opa/types" "k8s.io/utils/pointer" ) @@ -67,6 +68,8 @@ type driver struct { storage storage.Store capabilities *ast.Capabilities traceEnabled bool + printEnabled bool + printHook print.Hook providerCache *externaldata.ProviderCache } @@ -243,7 +246,8 @@ func (d *driver) alterModules(insert insertParam, remove []string) (int, error) } c := ast.NewCompiler().WithPathConflictsCheck(storage.NonEmpty(ctx, d.storage, txn)). - WithCapabilities(d.capabilities) + WithCapabilities(d.capabilities). + WithEnablePrintStatements(d.printEnabled) if c.Compile(updatedModules); c.Failed() { d.storage.Abort(ctx, txn) @@ -391,6 +395,8 @@ func (d *driver) eval(ctx context.Context, path string, input interface{}, cfg * rego.Store(d.storage), rego.Input(input), rego.Query(path), + rego.EnablePrintStatements(d.printEnabled), + rego.PrintHook(d.printHook), } buf := topdown.NewBufferTracer() diff --git a/constraint/pkg/client/e2e_test.go b/constraint/pkg/client/e2e_test.go index 3aacc1891..244ae7813 100644 --- a/constraint/pkg/client/e2e_test.go +++ b/constraint/pkg/client/e2e_test.go @@ -3,6 +3,7 @@ package client import ( "context" "encoding/json" + "reflect" "strings" "testing" @@ -11,6 +12,7 @@ import ( "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" "github.com/open-policy-agent/frameworks/constraint/pkg/types" + "github.com/open-policy-agent/opa/topdown/print" "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -463,6 +465,91 @@ func TestE2ERemoveTemplate(t *testing.T) { } } +type appendingPrintHook struct { + printed *[]string +} + +func (a appendingPrintHook) Print(_ print.Context, s string) error { + *a.printed = append(*a.printed, s) + return nil +} + +func TestE2EPrint(t *testing.T) { + testCases := []struct { + name string + printEnabled bool + rego string + wantMsg string + wantPrint []string + }{{ + name: "Print enabled", + printEnabled: true, + rego: `package foo + violation[{"msg": "deny with print"}] { + print("denied!") + 1 == 1 + }`, + wantMsg: "deny with print", + wantPrint: []string{"denied!"}, + }, { + name: "Print disabled", + printEnabled: false, + rego: `package foo + violation[{"msg": "deny without print"}] { + print("denied!") + 1 == 1 + }`, + wantMsg: "deny without print", + wantPrint: []string{}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + var printed []string + printHook := appendingPrintHook{printed: &printed} + + d := local.New(local.PrintEnabled(tc.printEnabled), local.PrintHook(printHook)) + b, err := NewBackend(Driver(d)) + if err != nil { + t.Fatal(err) + } + + c, err := b.NewClient(Targets(&handler{})) + if err != nil { + t.Fatal(err) + } + + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego)) + if err != nil { + t.Fatalf("got AddTemplate: %v", err) + } + cstr := newConstraint("Foo", "ph", nil, nil) + if _, err := c.AddConstraint(ctx, cstr); err != nil { + t.Fatalf("got AddConstraint: %v", err) + } + + rsps, err := c.Review(ctx, targetData{Name: "Hanna", ForConstraint: "Foo"}) + if err != nil { + t.Fatalf("got Review: %v", err) + } + + results := rsps.Results() + if len(results) != 1 { + t.Errorf("expected 1 result, got %v", len(results)) + } + if results[0].Msg != tc.wantMsg { + t.Errorf("expected msg %v, got %v", tc.wantMsg, results[0].Msg) + } + + if len(tc.wantPrint)+len(printed) > 0 && !reflect.DeepEqual(tc.wantPrint, printed) { + t.Errorf("Wanted %v printed, got %v", tc.wantPrint, printed) + } + }) + } +} + func TestE2ETracingOff(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) {