From cd6826e2d8a71d5e7aa10872ce75739a284520d8 Mon Sep 17 00:00:00 2001 From: Will Beason Date: Wed, 1 Dec 2021 09:24:09 -0800 Subject: [PATCH 1/3] Use context consistently Eliminate all uses of context.TODO(). Eliminate all uses of context.Background() in production code. In all cases, either our callers have a Context they can pass or the k8s framework provides one. This change breaks gatekeeper - since this modifies our external APIs, making this in a non-breaking way would require 5 PRs instead of 2. Once this PR has been submitted, I'll submit a PR which updates the version of frameworks/ that gatekeeper uses and fix the resulting breakages. Signed-off-by: Will Beason --- .../v1/constrainttemplate_types_test.go | 14 +- .../v1alpha1/constrainttemplate_types_test.go | 14 +- .../v1beta1/constrainttemplate_types_test.go | 14 +- constraint/pkg/client/backend.go | 6 +- constraint/pkg/client/client.go | 2 +- .../client/client_addtemplate_bench_test.go | 2 +- constraint/pkg/client/client_test.go | 1372 +++++++---------- .../pkg/client/drivers/local/local_test.go | 35 +- .../pkg/client/drivers/remote/remote_test.go | 4 +- constraint/pkg/client/e2e_test.go | 69 +- constraint/pkg/client/regolib/rego_test.go | 4 +- 11 files changed, 642 insertions(+), 894 deletions(-) diff --git a/constraint/pkg/apis/templates/v1/constrainttemplate_types_test.go b/constraint/pkg/apis/templates/v1/constrainttemplate_types_test.go index 7cd687ae3..0fee7a221 100644 --- a/constraint/pkg/apis/templates/v1/constrainttemplate_types_test.go +++ b/constraint/pkg/apis/templates/v1/constrainttemplate_types_test.go @@ -33,6 +33,8 @@ import ( ) func TestStorageConstraintTemplate(t *testing.T) { + ctx := context.Background() + key := types.NamespacedName{ Name: "foo", } @@ -45,22 +47,22 @@ func TestStorageConstraintTemplate(t *testing.T) { // Test Create fetched := &ConstraintTemplate{} - g.Expect(c.Create(context.TODO(), created)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Create(ctx, created)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).NotTo(gomega.HaveOccurred()) g.Expect(fetched).To(gomega.Equal(created)) // Test Updating the Labels updated := fetched.DeepCopy() updated.Labels = map[string]string{"hello": "world"} - g.Expect(c.Update(context.TODO(), updated)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Update(ctx, updated)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).NotTo(gomega.HaveOccurred()) g.Expect(fetched).To(gomega.Equal(updated)) // Test Delete - g.Expect(c.Delete(context.TODO(), fetched)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).To(gomega.HaveOccurred()) + g.Expect(c.Delete(ctx, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).To(gomega.HaveOccurred()) } func TestTypeConversion(t *testing.T) { diff --git a/constraint/pkg/apis/templates/v1alpha1/constrainttemplate_types_test.go b/constraint/pkg/apis/templates/v1alpha1/constrainttemplate_types_test.go index 141af5164..f5d374b4f 100644 --- a/constraint/pkg/apis/templates/v1alpha1/constrainttemplate_types_test.go +++ b/constraint/pkg/apis/templates/v1alpha1/constrainttemplate_types_test.go @@ -33,6 +33,8 @@ import ( ) func TestStorageConstraintTemplate(t *testing.T) { + ctx := context.Background() + key := types.NamespacedName{ Name: "foo", } @@ -45,22 +47,22 @@ func TestStorageConstraintTemplate(t *testing.T) { // Test Create fetched := &ConstraintTemplate{} - g.Expect(c.Create(context.TODO(), created)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Create(ctx, created)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).NotTo(gomega.HaveOccurred()) g.Expect(fetched).To(gomega.Equal(created)) // Test Updating the Labels updated := fetched.DeepCopy() updated.Labels = map[string]string{"hello": "world"} - g.Expect(c.Update(context.TODO(), updated)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Update(ctx, updated)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).NotTo(gomega.HaveOccurred()) g.Expect(fetched).To(gomega.Equal(updated)) // Test Delete - g.Expect(c.Delete(context.TODO(), fetched)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).To(gomega.HaveOccurred()) + g.Expect(c.Delete(ctx, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).To(gomega.HaveOccurred()) } func TestTypeConversion(t *testing.T) { diff --git a/constraint/pkg/apis/templates/v1beta1/constrainttemplate_types_test.go b/constraint/pkg/apis/templates/v1beta1/constrainttemplate_types_test.go index afebb2ff5..53687c055 100644 --- a/constraint/pkg/apis/templates/v1beta1/constrainttemplate_types_test.go +++ b/constraint/pkg/apis/templates/v1beta1/constrainttemplate_types_test.go @@ -33,6 +33,8 @@ import ( ) func TestStorageConstraintTemplate(t *testing.T) { + ctx := context.Background() + key := types.NamespacedName{ Name: "foo", } @@ -45,22 +47,22 @@ func TestStorageConstraintTemplate(t *testing.T) { // Test Create fetched := &ConstraintTemplate{} - g.Expect(c.Create(context.TODO(), created)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Create(ctx, created)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).NotTo(gomega.HaveOccurred()) g.Expect(fetched).To(gomega.Equal(created)) // Test Updating the Labels updated := fetched.DeepCopy() updated.Labels = map[string]string{"hello": "world"} - g.Expect(c.Update(context.TODO(), updated)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Update(ctx, updated)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).NotTo(gomega.HaveOccurred()) g.Expect(fetched).To(gomega.Equal(updated)) // Test Delete - g.Expect(c.Delete(context.TODO(), fetched)).NotTo(gomega.HaveOccurred()) - g.Expect(c.Get(context.TODO(), key, fetched)).To(gomega.HaveOccurred()) + g.Expect(c.Delete(ctx, fetched)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Get(ctx, key, fetched)).To(gomega.HaveOccurred()) } func TestTypeConversion(t *testing.T) { diff --git a/constraint/pkg/client/backend.go b/constraint/pkg/client/backend.go index 586b29aa8..4a58d165c 100644 --- a/constraint/pkg/client/backend.go +++ b/constraint/pkg/client/backend.go @@ -45,7 +45,7 @@ func NewBackend(opts ...BackendOpt) (*Backend, error) { } // NewClient creates a new client for the supplied backend. -func (b *Backend) NewClient(opts ...Opt) (*Client, error) { +func (b *Backend) NewClient(ctx context.Context, opts ...Opt) (*Client, error) { if b.hasClient { return nil, fmt.Errorf("%w: only one client per backend is allowed", ErrCreatingClient) @@ -81,11 +81,11 @@ func (b *Backend) NewClient(opts ...Opt) (*Client, error) { ErrCreatingClient) } - if err := b.driver.Init(context.Background()); err != nil { + if err := b.driver.Init(ctx); err != nil { return nil, err } - if err := c.init(); err != nil { + if err := c.init(ctx); err != nil { return nil, err } diff --git a/constraint/pkg/client/client.go b/constraint/pkg/client/client.go index e37ae1b51..4e4fd2750 100644 --- a/constraint/pkg/client/client.go +++ b/constraint/pkg/client/client.go @@ -684,7 +684,7 @@ func (c *Client) ValidateConstraint(_ context.Context, constraint *unstructured. } // init initializes the OPA backend for the client. -func (c *Client) init() error { +func (c *Client) init(ctx context.Context) error { for _, t := range c.targets { hooks := fmt.Sprintf(`hooks["%s"]`, t.GetName()) templMap := map[string]string{"Target": t.GetName()} diff --git a/constraint/pkg/client/client_addtemplate_bench_test.go b/constraint/pkg/client/client_addtemplate_bench_test.go index 025aa5fbf..ad7f7028d 100644 --- a/constraint/pkg/client/client_addtemplate_bench_test.go +++ b/constraint/pkg/client/client_addtemplate_bench_test.go @@ -92,7 +92,7 @@ func BenchmarkClient_AddTemplate(b *testing.B) { b.Fatal(err) } - c, err := backend.NewClient(targets) + c, err := backend.NewClient(ctx, targets) if err != nil { b.Fatal(err) } diff --git a/constraint/pkg/client/client_test.go b/constraint/pkg/client/client_test.go index 86c9f942a..5ff15366e 100644 --- a/constraint/pkg/client/client_test.go +++ b/constraint/pkg/client/client_test.go @@ -3,19 +3,17 @@ package client import ( "context" "errors" + "reflect" "strings" "testing" "text/template" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" + constraintlib "github.com/open-policy-agent/frameworks/constraint/pkg/core/constraints" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" "github.com/open-policy-agent/frameworks/constraint/pkg/types" "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" - "k8s.io/utils/pointer" ) const badRego = `asd{` @@ -45,10 +43,11 @@ matching_reviews_and_constraints[[r,c]] {r = data.r; c = data.c}`)) } func (h *badHandler) MatchSchema() apiextensions.JSONSchemaProps { - return apiextensions.JSONSchemaProps{XPreserveUnknownFields: pointer.Bool(true)} + trueBool := true + return apiextensions.JSONSchemaProps{XPreserveUnknownFields: &trueBool} } -func (h *badHandler) ProcessData(_ interface{}) (bool, string, interface{}, error) { +func (h *badHandler) ProcessData(obj interface{}) (bool, string, interface{}, error) { if h.Errors { return false, "", nil, errors.New("some error") } @@ -58,241 +57,245 @@ func (h *badHandler) ProcessData(_ interface{}) (bool, string, interface{}, erro return true, "projects/something", nil, nil } -func (h *badHandler) HandleReview(_ interface{}) (bool, interface{}, error) { +func (h *badHandler) HandleReview(obj interface{}) (bool, interface{}, error) { return false, "", nil } -func (h *badHandler) HandleViolation(_ *types.Result) error { +func (h *badHandler) HandleViolation(result *types.Result) error { return nil } -func (h *badHandler) ValidateConstraint(_ *unstructured.Unstructured) error { +func (h *badHandler) ValidateConstraint(u *unstructured.Unstructured) error { return nil } func TestInvalidTargetName(t *testing.T) { - tcs := []struct { - name string - handler TargetHandler - wantError error + tc := []struct { + Name string + Handler TargetHandler + ErrorExpected bool }{ { - name: "Acceptable name", - handler: &badHandler{Name: "Hello8", HasLib: true}, - wantError: nil, + Name: "Acceptable Name", + Handler: &badHandler{Name: "Hello8", HasLib: true}, + ErrorExpected: false, }, { - name: "No name", - handler: &badHandler{Name: ""}, - wantError: ErrCreatingClient, + Name: "No Name", + Handler: &badHandler{Name: ""}, + ErrorExpected: true, }, { - name: "Dots not allowed", - handler: &badHandler{Name: "asdf.asdf"}, - wantError: ErrCreatingClient, + Name: "No Dots", + Handler: &badHandler{Name: "asdf.asdf"}, + ErrorExpected: true, }, { - name: "Spaces not allowed", - handler: &badHandler{Name: "asdf asdf"}, - wantError: ErrCreatingClient, + Name: "No Spaces", + Handler: &badHandler{Name: "asdf asdf"}, + ErrorExpected: true, }, { - name: "Must start with a letter", - handler: &badHandler{Name: "8asdf"}, - wantError: ErrCreatingClient, + Name: "Must start with a letter", + Handler: &badHandler{Name: "8asdf"}, + ErrorExpected: true, }, } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - - _, err = b.NewClient(Targets(tc.handler)) - if !errors.Is(err, tc.wantError) { - t.Errorf("got NewClient() error = %v, want %v", - err, tc.wantError) + _, err = b.NewClient(Targets(tt.Handler)) + if (err == nil) && tt.ErrorExpected { + t.Fatalf("err = nil; want non-nil") + } + if (err != nil) && !tt.ErrorExpected { + t.Fatalf("err = \"%s\"; want nil", err) } }) } } func TestAddData(t *testing.T) { - tcs := []struct { - name string - handler1 TargetHandler - handler2 TargetHandler - wantHandled map[string]bool - wantError map[string]bool + tc := []struct { + Name string + Handler1 TargetHandler + Handler2 TargetHandler + ErroredBy []string + HandledBy []string }{ { - name: "Handled By Both", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, - wantHandled: map[string]bool{"h1": true, "h2": true}, - wantError: nil, + Name: "Handled By Both", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, + HandledBy: []string{"h1", "h2"}, }, { - name: "Handled By One", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Handled By One", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + HandledBy: []string{"h1"}, }, { - name: "Errored By One", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - wantHandled: map[string]bool{"h1": true}, - wantError: map[string]bool{"h2": true}, + Name: "Errored By One", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + HandledBy: []string{"h1"}, + ErroredBy: []string{"h2"}, }, { - name: "Errored By Both", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - wantError: map[string]bool{"h1": true, "h2": true}, + Name: "Errored By Both", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + ErroredBy: []string{"h1", "h2"}, }, { - name: "Handled By None", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, - wantHandled: nil, - wantError: nil, + Name: "Handled By None", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, }, } + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tc.handler1, tc.handler2)) + c, err := b.NewClient(Targets(tt.Handler1, tt.Handler2)) if err != nil { t.Fatal(err) } - - r, err := c.AddData(context.Background(), nil) - if err != nil && len(tc.wantError) == 0 { + r, err := c.AddData(ctx, nil) + if err != nil && len(tt.ErroredBy) == 0 { t.Fatalf("err = %s; want nil", err) } - gotErrs := make(map[string]bool) - if e, ok := err.(*ErrorMap); ok { - for k := range *e { - gotErrs[k] = true + expectedErr := make(map[string]bool) + actualErr := make(map[string]bool) + for _, v := range tt.ErroredBy { + expectedErr[v] = true + } + if e, ok := err.(ErrorMap); ok { + for k := range e { + actualErr[k] = true } } - - if diff := cmp.Diff(tc.wantError, gotErrs, cmpopts.EquateEmpty()); diff != "" { - t.Errorf(diff) + if !reflect.DeepEqual(actualErr, expectedErr) { + t.Errorf("errSet = %v; wanted %v", actualErr, expectedErr) + } + expectedHandled := make(map[string]bool) + for _, v := range tt.HandledBy { + expectedHandled[v] = true } - if r == nil { t.Fatal("got AddTemplate() == nil, want non-nil") } - - if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("handledSet = %v; wanted %v", r.Handled, expectedHandled) + } + if r.HandledCount() != len(expectedHandled) { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), len(expectedHandled)) } }) } } func TestRemoveData(t *testing.T) { - tcs := []struct { - name string - handler1 TargetHandler - handler2 TargetHandler - wantHandled map[string]bool - wantError map[string]bool + tc := []struct { + Name string + Handler1 TargetHandler + Handler2 TargetHandler + ErroredBy []string + HandledBy []string }{ { - name: "Handled By Both", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, - wantHandled: map[string]bool{"h1": true, "h2": true}, - wantError: nil, + Name: "Handled By Both", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, + HandledBy: []string{"h1", "h2"}, }, { - name: "Handled By One", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Handled By One", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + HandledBy: []string{"h1"}, }, { - name: "Errored By One", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - wantHandled: map[string]bool{"h1": true}, - wantError: map[string]bool{"h2": true}, + Name: "Errored By One", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + HandledBy: []string{"h1"}, + ErroredBy: []string{"h2"}, }, { - name: "Errored By Both", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - wantHandled: nil, - wantError: map[string]bool{"h1": true, "h2": true}, + Name: "Errored By Both", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + ErroredBy: []string{"h1", "h2"}, }, { - name: "Handled By None", - handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, - handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, - wantHandled: nil, - wantError: nil, + Name: "Handled By None", + Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, + Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, }, } + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tc.handler1, tc.handler2)) + c, err := b.NewClient(Targets(tt.Handler1, tt.Handler2)) if err != nil { t.Fatal(err) } - - r, err := c.RemoveData(context.Background(), nil) - if err != nil && len(tc.wantError) == 0 { + r, err := c.RemoveData(ctx, nil) + if err != nil && len(tt.ErroredBy) == 0 { t.Fatalf("err = %s; want nil", err) } - - gotErrs := make(map[string]bool) - if e, ok := err.(*ErrorMap); ok { - for k := range *e { - gotErrs[k] = true + expectedErr := make(map[string]bool) + actualErr := make(map[string]bool) + for _, v := range tt.ErroredBy { + expectedErr[v] = true + } + if e, ok := err.(ErrorMap); ok { + for k := range e { + actualErr[k] = true } } - - if diff := cmp.Diff(tc.wantError, gotErrs, cmpopts.EquateEmpty()); diff != "" { - t.Errorf(diff) + if !reflect.DeepEqual(actualErr, expectedErr) { + t.Errorf("errSet = %v; wanted %v", actualErr, expectedErr) + } + expectedHandled := make(map[string]bool) + for _, v := range tt.HandledBy { + expectedHandled[v] = true } if r == nil { t.Fatal("got RemoveData() == nil, want non-nil") } - - if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("handledSet = %v; wanted %v", r.Handled, expectedHandled) + } + if r.HandledCount() != len(expectedHandled) { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), len(expectedHandled)) } }) } } -func TestClient_AddTemplate(t *testing.T) { +func TestAddTemplate(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego @@ -307,97 +310,100 @@ some_rule[r] { emptyRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) emptyRegoTempl.Spec.Targets[0].Rego = "" - tcs := []struct { - name string - handler TargetHandler - template *templates.ConstraintTemplate - wantHandled map[string]bool - wantError error + tc := []struct { + Name string + Handler TargetHandler + Template *templates.ConstraintTemplate + ErrorExpected bool }{ { - name: "Good Template", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Good Template", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), + ErrorExpected: false, }, { - name: "Unknown Target", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Unknown Target", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + ErrorExpected: true, }, { - name: "Bad CRD", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fakes"), targets("h1")), - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Bad CRD", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fakes"), targets("h1")), + ErrorExpected: true, }, { - name: "No name", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(crdNames("Fake"), targets("h1")), - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "No Name", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(crdNames("Fake"), targets("h1")), + ErrorExpected: true, }, { - name: "Bad Rego", - handler: &badHandler{Name: "h1", HasLib: true}, - template: badRegoTempl, - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Bad Rego", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: badRegoTempl, + ErrorExpected: true, }, { - name: "No Rego", - handler: &badHandler{Name: "h1", HasLib: true}, - template: emptyRegoTempl, - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "No Rego", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: emptyRegoTempl, + ErrorExpected: true, }, { - name: "Missing Rule", - handler: &badHandler{Name: "h1", HasLib: true}, - template: missingRuleTempl, - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Missing Rule", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: missingRuleTempl, + ErrorExpected: true, }, } + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tc.handler)) + c, err := b.NewClient(Targets(tt.Handler)) if err != nil { t.Fatal(err) } - r, err := c.AddTemplate(context.Background(), tc.template) - if !errors.Is(err, tc.wantError) { - t.Fatalf("got AddTemplate() error = %v, want %v", - err, tc.wantError) + r, err := c.AddTemplate(tt.Template) + if err != nil && !tt.ErrorExpected { + t.Fatalf("err = %v; want nil", err) + } + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") } + expectedCount := 0 + expectedHandled := make(map[string]bool) + if !tt.ErrorExpected { + expectedCount = 1 + expectedHandled = map[string]bool{"h1": true} + } if r == nil { t.Fatal("got AddTemplate() == nil, want non-nil") } + if r.HandledCount() != expectedCount { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) + } + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) + } - if diff := cmp.Diff(r.Handled, tc.wantHandled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + cached, err := c.GetTemplate(tt.Template) + if err == nil && tt.ErrorExpected { + t.Fatal("retrieved template when error was expected") } - cached, err := c.GetTemplate(context.Background(), tc.template) - if tc.wantError != nil { - if err == nil { - t.Fatalf("got GetTemplate() error = %v, want non-nil", err) - } + if tt.ErrorExpected { return } @@ -405,20 +411,17 @@ some_rule[r] { t.Fatalf("could not retrieve template when error was expected: %v", err) } - if !cached.SemanticEqual(tc.template) { + if !cached.SemanticEqual(tt.Template) { t.Error("cached template does not equal stored template") } - - r2, err := c.RemoveTemplate(context.Background(), tc.template) + r2, err := c.RemoveTemplate(ctx, tt.Template) if err != nil { t.Fatal("could not remove template") } - if r2.HandledCount() != 1 { t.Error("more targets handled than expected") } - - if _, err := c.GetTemplate(context.Background(), tc.template); err == nil { + if _, err := c.GetTemplate(tt.Template); err == nil { t.Error("template not cleared from cache") } }) @@ -428,62 +431,67 @@ some_rule[r] { func TestRemoveTemplate(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - tcs := []struct { - name string - handler TargetHandler - template *templates.ConstraintTemplate - wantHandled map[string]bool - wantError error + tc := []struct { + Name string + Handler TargetHandler + Template *templates.ConstraintTemplate + ErrorExpected bool }{ { - name: "Good Template", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Good Template", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + ErrorExpected: false, }, { - name: "Unknown Target", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Unknown Target", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + ErrorExpected: true, }, { - name: "Bad CRD", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fake"), targets("h1")), - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Bad CRD", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), targets("h1")), + ErrorExpected: true, }, } - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - d := local.New() + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() + d := local.New() b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tc.handler)) + c, err := b.NewClient(Targets(tt.Handler)) if err != nil { t.Fatal(err) } - - _, err = c.AddTemplate(context.Background(), tc.template) - if !errors.Is(err, tc.wantError) { - t.Fatalf("got AddTemplate() error = %v, want %v", - err, tc.wantError) + _, err = c.AddTemplate(tt.Template) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) } - - r, err := c.RemoveTemplate(context.Background(), tc.template) + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") + } + r, err := c.RemoveTemplate(ctx, tt.Template) if err != nil { t.Errorf("err = %v; want nil", err) } - - if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + expectedCount := 0 + expectedHandled := make(map[string]bool) + if !tt.ErrorExpected { + expectedCount = 1 + expectedHandled = map[string]bool{"h1": true} + } + if r.HandledCount() != expectedCount { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) + } + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) } }) } @@ -492,66 +500,69 @@ func TestRemoveTemplate(t *testing.T) { func TestRemoveTemplateByNameOnly(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - tcs := []struct { - name string - handler TargetHandler - template *templates.ConstraintTemplate - wantHandled map[string]bool - wantError error + tc := []struct { + Name string + Handler TargetHandler + Template *templates.ConstraintTemplate + ErrorExpected bool }{ { - name: "Good Template", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Good Template", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + ErrorExpected: false, }, { - name: "Unknown Target", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Unknown Target", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + ErrorExpected: true, }, { - name: "Bad CRD", - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fake"), targets("h1")), - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Bad CRD", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), targets("h1")), + ErrorExpected: true, }, } + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tc.handler)) + c, err := b.NewClient(Targets(tt.Handler)) if err != nil { t.Fatal(err) } - - _, err = c.AddTemplate(context.Background(), tc.template) - if !errors.Is(err, tc.wantError) { - t.Fatalf("got AddTemplate() error = %v, want %v", - err, tc.wantError) + _, err = c.AddTemplate(tt.Template) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) + } + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") } - sparseTemplate := &templates.ConstraintTemplate{} - sparseTemplate.Name = tc.template.Name - - r, err := c.RemoveTemplate(context.Background(), sparseTemplate) + sparseTemplate.Name = tt.Template.Name + r, err := c.RemoveTemplate(ctx, sparseTemplate) if err != nil { t.Errorf("err = %v; want nil", err) } - - if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + expectedCount := 0 + expectedHandled := make(map[string]bool) + if !tt.ErrorExpected { + expectedCount = 1 + expectedHandled = map[string]bool{"h1": true} + } + if r.HandledCount() != expectedCount { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) + } + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) } }) } @@ -560,69 +571,61 @@ func TestRemoveTemplateByNameOnly(t *testing.T) { func TestGetTemplate(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - - tcs := []struct { - name string - handler TargetHandler - wantTemplate *templates.ConstraintTemplate - wantAddError error - wantGetError error + tc := []struct { + Name string + Handler TargetHandler + Template *templates.ConstraintTemplate + ErrorExpected bool }{ { - name: "Good Template", - handler: &badHandler{Name: "h1", HasLib: true}, - wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - wantAddError: nil, - wantGetError: nil, + Name: "Good Template", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + ErrorExpected: false, }, { - name: "Unknown Target", - handler: &badHandler{Name: "h1", HasLib: true}, - wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - wantAddError: ErrInvalidConstraintTemplate, - wantGetError: ErrMissingConstraintTemplate, + Name: "Unknown Target", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + ErrorExpected: true, }, { - name: "Bad CRD", - handler: &badHandler{Name: "h1", HasLib: true}, - wantTemplate: createTemplate(name("fake"), targets("h1")), - wantAddError: ErrInvalidConstraintTemplate, - wantGetError: ErrMissingConstraintTemplate, + Name: "Bad CRD", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), targets("h1")), + ErrorExpected: true, }, } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tc.handler)) + c, err := b.NewClient(Targets(tt.Handler)) if err != nil { t.Fatal(err) } - - _, err = c.AddTemplate(context.Background(), tc.wantTemplate) - if !errors.Is(err, tc.wantAddError) { - t.Fatalf("got AddTemplate() error = %v, want %v", - err, tc.wantAddError) + _, err = c.AddTemplate(tt.Template) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) } - - gotTemplate, err := c.GetTemplate(context.Background(), tc.wantTemplate) - if !errors.Is(err, tc.wantGetError) { - t.Fatalf("got GetTemplate() error = %v, want %v", - err, tc.wantGetError) + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") } - - if tc.wantAddError != nil { - return + tmpl, err := c.GetTemplate(tt.Template) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) } - - if diff := cmp.Diff(tc.wantTemplate, gotTemplate); diff != "" { - t.Error(diff) + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") + } + if !tt.ErrorExpected { + if !reflect.DeepEqual(tmpl, tt.Template) { + t.Error("Stored and retrieved template differ") + } } }) } @@ -631,78 +634,71 @@ func TestGetTemplate(t *testing.T) { func TestGetTemplateByNameOnly(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - - tcs := []struct { - name string - handler TargetHandler - wantTemplate *templates.ConstraintTemplate - wantAddError error - wantGetError error + tc := []struct { + Name string + Handler TargetHandler + Template *templates.ConstraintTemplate + ErrorExpected bool }{ { - name: "Good Template", - handler: &badHandler{Name: "h1", HasLib: true}, - wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - wantAddError: nil, - wantGetError: nil, + Name: "Good Template", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + ErrorExpected: false, }, { - name: "Unknown Target", - handler: &badHandler{Name: "h1", HasLib: true}, - wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - wantAddError: ErrInvalidConstraintTemplate, - wantGetError: ErrMissingConstraintTemplate, + Name: "Unknown Target", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + ErrorExpected: true, }, { - name: "Bad CRD", - handler: &badHandler{Name: "h1", HasLib: true}, - wantTemplate: createTemplate(name("fake"), targets("h1")), - wantAddError: ErrInvalidConstraintTemplate, - wantGetError: ErrMissingConstraintTemplate, + Name: "Bad CRD", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fake"), targets("h1")), + ErrorExpected: true, }, } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tc.handler)) + c, err := b.NewClient(Targets(tt.Handler)) if err != nil { t.Fatal(err) } - - _, err = c.AddTemplate(context.Background(), tc.wantTemplate) - if !errors.Is(err, tc.wantAddError) { - t.Fatalf("got AddTemplate() error = %v, want %v", - err, tc.wantAddError) + _, err = c.AddTemplate(tt.Template) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) + } + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") } - sparseTemplate := &templates.ConstraintTemplate{} - sparseTemplate.Name = tc.wantTemplate.Name - - gotTemplate, err := c.GetTemplate(context.Background(), sparseTemplate) - if !errors.Is(err, tc.wantGetError) { - t.Fatalf("Got GetTemplate() error = %v, want %v", - err, tc.wantGetError) + sparseTemplate.Name = tt.Template.Name + tmpl, err := c.GetTemplate(sparseTemplate) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) } - - if tc.wantGetError != nil { - return + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") } - - if diff := cmp.Diff(tc.wantTemplate, gotTemplate); diff != "" { - t.Error(diff) + if !tt.ErrorExpected { + if !reflect.DeepEqual(tmpl, tt.Template) { + t.Error("Stored and retrieved template differ") + } } }) } } func TestTemplateCascadingDelete(t *testing.T) { + ctx := context.Background() + handler := &badHandler{Name: "h1", HasLib: true} d := local.New() @@ -715,38 +711,35 @@ func TestTemplateCascadingDelete(t *testing.T) { if err != nil { t.Fatal(err) } - templ := createTemplate(name("cascadingdelete"), crdNames("CascadingDelete"), targets("h1")) - if _, err = c.AddTemplate(context.Background(), templ); err != nil { + if _, err = c.AddTemplate(templ); err != nil { t.Errorf("err = %v; want nil", err) } cst1 := newConstraint("CascadingDelete", "cascadingdelete", nil, nil) - if _, err = c.AddConstraint(context.Background(), cst1); err != nil { + if _, err = c.AddConstraint(ctx, cst1); err != nil { t.Error("could not add first constraint") } - cst2 := newConstraint("CascadingDelete", "cascadingdelete2", nil, nil) - if _, err = c.AddConstraint(context.Background(), cst2); err != nil { + if _, err = c.AddConstraint(ctx, cst2); err != nil { t.Error("could not add second constraint") } template2 := createTemplate(name("stillpersists"), crdNames("StillPersists"), targets("h1")) - if _, err = c.AddTemplate(context.Background(), template2); err != nil { + if _, err = c.AddTemplate(template2); err != nil { t.Errorf("err = %v; want nil", err) } cst3 := newConstraint("StillPersists", "stillpersists", nil, nil) - if _, err = c.AddConstraint(context.Background(), cst3); err != nil { + if _, err = c.AddConstraint(ctx, cst3); err != nil { t.Error("could not add third constraint") } - cst4 := newConstraint("StillPersists", "stillpersists2", nil, nil) - if _, err = c.AddConstraint(context.Background(), cst4); err != nil { + if _, err = c.AddConstraint(ctx, cst4); err != nil { t.Error("could not add fourth constraint") } - orig, err := c.Dump(context.Background()) + orig, err := c.Dump(ctx) if err != nil { t.Errorf("could not dump original state: %s", err) } @@ -762,24 +755,21 @@ func TestTemplateCascadingDelete(t *testing.T) { t.Errorf("preservation candidate not cached: %s", orig) } - if _, err = c.RemoveTemplate(context.Background(), templ); err != nil { + if _, err = c.RemoveTemplate(ctx, templ); err != nil { t.Error("could not remove template") } - if len(c.constraints) != 1 { t.Errorf("constraint cache expected to have only 1 entry: %+v", c.constraints) } - s, err := c.Dump(context.Background()) + s, err := c.Dump(ctx) if err != nil { t.Errorf("could not dump OPA cache") } - sLower := strings.ToLower(s) if strings.Contains(sLower, "cascadingdelete") { t.Errorf("Template not removed from cache: %s", s) } - finalPreserved := strings.Count(sLower, "stillpersists") if finalPreserved != origPreserved { t.Errorf("finalPreserved = %d, expected %d :: %s", finalPreserved, origPreserved, s) @@ -787,169 +777,146 @@ func TestTemplateCascadingDelete(t *testing.T) { } func TestAddConstraint(t *testing.T) { - handler := &badHandler{Name: "h1", HasLib: true} - - tcs := []struct { - name string - template *templates.ConstraintTemplate - constraint *unstructured.Unstructured - wantHandled map[string]bool - wantAddConstraintError error - wantGetConstraintError error + tc := []struct { + Name string + Constraint *unstructured.Unstructured + OmitTemplate bool + ErrorExpected bool }{ { - name: "Good Constraint", - template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), - constraint: newConstraint("Foos", "foo", nil, nil), - wantHandled: map[string]bool{"h1": true}, - wantAddConstraintError: nil, - wantGetConstraintError: nil, + Name: "Good Constraint", + Constraint: newConstraint("Foos", "foo", nil, nil), }, { - name: "No Name", - template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), - constraint: newConstraint("Foos", "", nil, nil), - wantHandled: nil, - wantAddConstraintError: ErrInvalidConstraint, - wantGetConstraintError: ErrInvalidConstraint, + Name: "No Name", + Constraint: newConstraint("Foos", "", nil, nil), + ErrorExpected: true, }, { - name: "No Kind", - template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), - constraint: newConstraint("", "foo", nil, nil), - wantHandled: nil, - wantAddConstraintError: ErrInvalidConstraint, - wantGetConstraintError: ErrInvalidConstraint, + Name: "No Kind", + Constraint: newConstraint("", "foo", nil, nil), + ErrorExpected: true, }, { - name: "No Template", - template: nil, - constraint: newConstraint("Foo", "foo", nil, nil), - wantHandled: nil, - wantAddConstraintError: ErrMissingConstraintTemplate, - wantGetConstraintError: ErrMissingConstraint, + Name: "No Template", + Constraint: newConstraint("Foo", "foo", nil, nil), + OmitTemplate: true, + ErrorExpected: true, }, } + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { - t.Fatal(err) + t.Fatalf("Could not create backend: %s", err) } + handler := &badHandler{Name: "h1", HasLib: true} c, err := b.NewClient(Targets(handler)) if err != nil { t.Fatal(err) } - - if tc.template != nil { - _, err = c.AddTemplate(context.Background(), tc.template) + if !tt.OmitTemplate { + tmpl := createTemplate(name("foos"), crdNames("Foos"), targets("h1")) + _, err := c.AddTemplate(tmpl) if err != nil { t.Fatal(err) } } - - r, err := c.AddConstraint(context.Background(), tc.constraint) - if !errors.Is(err, tc.wantAddConstraintError) { - t.Fatalf("got AddConstraint() error = %v, want %v", - err, tc.wantAddConstraintError) - } - - if r == nil { - t.Fatal("got AddConstraint() == nil, want non-nil") + r, err := c.AddConstraint(ctx, tt.Constraint) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) } - - if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") } - - cached, err := c.GetConstraint(context.Background(), tc.constraint) - if !errors.Is(err, tc.wantGetConstraintError) { - t.Fatalf("got GetConstraint() error = %v, want %v", - err, tc.wantGetConstraintError) + expectedCount := 0 + expectedHandled := make(map[string]bool) + if !tt.ErrorExpected { + expectedCount = 1 + expectedHandled = map[string]bool{"h1": true} } - if tc.wantGetConstraintError != nil { - return + if r == nil { + t.Fatal("got AddConstraint() == nil, want non-nil") } - - if diff := cmp.Diff(tc.constraint.Object["spec"], cached.Object["spec"]); diff != "" { - t.Error("cached constraint does not equal stored constraint") + if r.HandledCount() != expectedCount { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) } - - r2, err := c.RemoveConstraint(context.Background(), tc.constraint) - if err != nil { - t.Error("could not remove constraint") + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) } - - if r2 == nil { - t.Fatal("got RemoveConstraint() == nil, want non-nil") + cached, err := c.GetConstraint(tt.Constraint) + if err == nil && tt.ErrorExpected { + t.Error("retrieved constraint when error was expected") } - - if r2.HandledCount() != 1 { - t.Error("more targets handled than expected") + if err != nil && !tt.ErrorExpected { + t.Error("could not retrieve constraint when error was expected") } + if !tt.ErrorExpected { + if !constraintlib.SemanticEqual(cached, tt.Constraint) { + t.Error("cached constraint does not equal stored constraint") + } + r2, err := c.RemoveConstraint(ctx, tt.Constraint) + if err != nil { + t.Error("could not remove constraint") + } - if _, err := c.GetConstraint(context.Background(), tc.constraint); err == nil { - t.Error("constraint not cleared from cache") + if r2 == nil { + t.Fatal("got RemoveConstraint() == nil, want non-nil") + } + if r2.HandledCount() != 1 { + t.Error("more targets handled than expected") + } + if _, err := c.GetConstraint(tt.Constraint); err == nil { + t.Error("constraint not cleared from cache") + } } }) } } func TestRemoveConstraint(t *testing.T) { - tcs := []struct { - name string - template *templates.ConstraintTemplate - constraint *unstructured.Unstructured - toRemove *unstructured.Unstructured - wantHandled map[string]bool - wantError error + tc := []struct { + Name string + Constraint *unstructured.Unstructured + OmitTemplate bool + ErrorExpected bool + ExpectedErrorType string }{ { - name: "Good Constraint", - template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), - constraint: newConstraint("Foos", "foo", nil, nil), - toRemove: newConstraint("Foos", "foo", nil, nil), - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Good Constraint", + Constraint: newConstraint("Foos", "foo", nil, nil), }, { - name: "No name", - template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), - constraint: newConstraint("Foos", "foo", nil, nil), - toRemove: newConstraint("Foos", "", nil, nil), - wantHandled: nil, - wantError: ErrInvalidConstraint, + Name: "No Name", + Constraint: newConstraint("Foos", "", nil, nil), + ErrorExpected: true, }, { - name: "No Kind", - template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), - constraint: newConstraint("Foos", "foo", nil, nil), - toRemove: newConstraint("", "foo", nil, nil), - wantHandled: nil, - wantError: ErrInvalidConstraint, + Name: "No Kind", + Constraint: newConstraint("", "foo", nil, nil), + ErrorExpected: true, }, { - name: "No Template", - toRemove: newConstraint("Foos", "foo", nil, nil), - wantHandled: nil, - wantError: ErrMissingConstraintTemplate, + Name: "No Template", + Constraint: newConstraint("Foo", "foo", nil, nil), + OmitTemplate: true, + ErrorExpected: true, }, { - name: "No Constraint", - template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), - toRemove: newConstraint("Foos", "bar", nil, nil), - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Unrecognized Constraint", + Constraint: newConstraint("Bar", "bar", nil, nil), + OmitTemplate: true, + ErrorExpected: true, + ExpectedErrorType: "*client.UnrecognizedConstraintError", }, } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { ctx := context.Background() d := local.New() @@ -963,34 +930,38 @@ func TestRemoveConstraint(t *testing.T) { if err != nil { t.Fatal(err) } - - if tc.template != nil { - _, err = c.AddTemplate(ctx, tc.template) + if !tt.OmitTemplate { + tmpl := createTemplate(name("foos"), crdNames("Foos"), targets("h1")) + _, err := c.AddTemplate(tmpl) if err != nil { t.Fatal(err) } } - - if tc.constraint != nil { - _, err = c.AddConstraint(ctx, tc.constraint) - if err != nil { - t.Fatal(err) - } + r, err := c.RemoveConstraint(ctx, tt.Constraint) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) } - - r, err := c.RemoveConstraint(context.Background(), tc.toRemove) - - if !errors.Is(err, tc.wantError) { - t.Errorf("got RemoveConstraint error = %v, want %v", - err, tc.wantError) + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") + } + if tt.ErrorExpected && tt.ExpectedErrorType != "" && reflect.TypeOf(err).String() != tt.ExpectedErrorType { + t.Errorf("err type = %s; want %s", reflect.TypeOf(err).String(), tt.ExpectedErrorType) + } + expectedCount := 0 + expectedHandled := make(map[string]bool) + if !tt.ErrorExpected { + expectedCount = 1 + expectedHandled = map[string]bool{"h1": true} } if r == nil { t.Fatal("got RemoveConstraint() == nil, want non-nil") } - - if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + if r.HandledCount() != expectedCount { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) + } + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) } }) } @@ -1006,385 +977,128 @@ violation[{"msg": "msg"}] { } ` - tcs := []struct { - name string - allowedFields []string - handler TargetHandler - template *templates.ConstraintTemplate - wantHandled map[string]bool - wantError error + tc := []struct { + Name string + Handler TargetHandler + Template *templates.ConstraintTemplate + ErrorExpected bool + InvAllowed bool }{ { - name: "Inventory Not Used", - allowedFields: []string{}, - handler: &badHandler{Name: "h1", HasLib: true}, - template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Inventory Not Used", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), + ErrorExpected: false, }, { - name: "Inventory used but not allowed", - allowedFields: []string{}, - handler: &badHandler{Name: "h1", HasLib: true}, - template: inventoryTempl, - wantHandled: nil, - wantError: ErrInvalidConstraintTemplate, + Name: "Inventory Used", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: inventoryTempl, + ErrorExpected: true, }, { - name: "Inventory used and allowed", - allowedFields: []string{"inventory"}, - handler: &badHandler{Name: "h1", HasLib: true}, - template: inventoryTempl, - wantHandled: map[string]bool{"h1": true}, - wantError: nil, + Name: "Inventory Used But Allowed", + Handler: &badHandler{Name: "h1", HasLib: true}, + Template: inventoryTempl, + ErrorExpected: false, + InvAllowed: true, }, } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { - t.Fatal(err) + t.Fatalf("Could not create backend: %s", err) + } + f := AllowedDataFields() + if tt.InvAllowed { + f = AllowedDataFields("inventory") } - c, err := b.NewClient(Targets(tc.handler), AllowedDataFields(tc.allowedFields...)) + c, err := b.NewClient(Targets(tt.Handler), f) if err != nil { t.Fatal(err) } - - r, err := c.AddTemplate(context.Background(), tc.template) - if !errors.Is(err, tc.wantError) { - t.Fatalf("got AddTemplate() error = %v, want %v", - err, tc.wantError) + r, err := c.AddTemplate(tt.Template) + if err != nil && !tt.ErrorExpected { + t.Errorf("err = %v; want nil", err) + } + if err == nil && tt.ErrorExpected { + t.Error("err = nil; want non-nil") + } + expectedCount := 0 + expectedHandled := make(map[string]bool) + if !tt.ErrorExpected { + expectedCount = 1 + expectedHandled = map[string]bool{"h1": true} } if r == nil { t.Fatal("got AddTemplate() == nil, want non-nil") } - - if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { - t.Error(diff) + if r.HandledCount() != expectedCount { + t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) + } + if !reflect.DeepEqual(r.Handled, expectedHandled) { + t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) } }) } } func TestAllowedDataFieldsIntersection(t *testing.T) { - tcs := []struct { - name string - allowed Opt - want []string - wantError error + tc := []struct { + Name string + Allowed Opt + Expected []string + wantError bool }{ { - name: "No AllowedDataFields specified", - want: []string{"inventory"}, + Name: "No AllowedDataFields specified", + Expected: []string{"inventory"}, }, { - name: "Empty AllowedDataFields Used", - allowed: AllowedDataFields(), - want: nil, + Name: "Empty AllowedDataFields Used", + Allowed: AllowedDataFields(), + Expected: nil, }, { - name: "Inventory Used", - allowed: AllowedDataFields("inventory"), - want: []string{"inventory"}, + Name: "Inventory Used", + Allowed: AllowedDataFields("inventory"), + Expected: []string{"inventory"}, }, { - name: "Invalid Data Field", - allowed: AllowedDataFields("no_overlap"), - want: []string{}, - wantError: ErrCreatingClient, + Name: "Invalid Data Field", + Allowed: AllowedDataFields("no_overlap"), + Expected: []string{}, + wantError: true, }, } - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { + for _, tt := range tc { + t.Run(tt.Name, func(t *testing.T) { d := local.New() - b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - opts := []Opt{Targets(&badHandler{Name: "h1", HasLib: true})} - if tc.allowed != nil { - opts = append(opts, tc.allowed) + if tt.Allowed != nil { + opts = append(opts, tt.Allowed) } c, err := b.NewClient(opts...) - if !errors.Is(err, tc.wantError) { - t.Fatalf("got NewClient() error = %v, want %v", - err, tc.wantError) - } - - if tc.wantError != nil { + if tt.wantError { + if err == nil { + t.Fatalf("Expectd error, got nil") + } return } - - if diff := cmp.Diff(tc.want, c.allowedDataFields); diff != "" { - t.Error(diff) - } - }) - } -} - -func TestClient_CreateCRD(t *testing.T) { - testCases := []struct { - name string - targets []TargetHandler - template *templates.ConstraintTemplate - want *apiextensions.CustomResourceDefinition - wantErr error - }{ - { - name: "nil", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: nil, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "empty", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{}, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "no CRD kind", - targets: []TargetHandler{&handler{}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - }, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "name-kind mismatch", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Bar", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler", - Rego: `package foo - -violation[msg] {msg := "always"}`, - }}, - }, - }, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "no targets", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - }, - }, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "wrong target", - targets: []TargetHandler{&badHandler{Name: "handler.1", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler.2", - }}, - }, - }, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "no rego", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler", - }}, - }, - }, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "empty rego package", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler", - Rego: `package foo`, - }}, - }, - }, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "multiple targets", - targets: []TargetHandler{ - &badHandler{Name: "handler", HasLib: true}, - &badHandler{Name: "handler.2", HasLib: true}, - }, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler", - Rego: `package foo - -violation[msg] {msg := "always"}`, - }, { - Target: "handler.2", - Rego: `package foo - -violation[msg] {msg := "always"}`, - }}, - }, - }, - want: nil, - wantErr: ErrInvalidConstraintTemplate, - }, - { - name: "minimal working", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler", - Rego: `package foo - -violation[msg] {msg := "always"}`, - }}, - }, - }, - want: &apiextensions.CustomResourceDefinition{ - ObjectMeta: v1.ObjectMeta{ - Name: "foo.constraints.gatekeeper.sh", - Labels: map[string]string{"gatekeeper.sh/constraint": "yes"}, - }, - Spec: apiextensions.CustomResourceDefinitionSpec{ - Group: "constraints.gatekeeper.sh", - Version: "v1beta1", - Names: apiextensions.CustomResourceDefinitionNames{ - Plural: "foo", - Singular: "foo", - Kind: "Foo", - ListKind: "FooList", - Categories: []string{"constraint", "constraints"}, - }, - Scope: apiextensions.ClusterScoped, - Subresources: &apiextensions.CustomResourceSubresources{ - Status: &apiextensions.CustomResourceSubresourceStatus{}, - }, - Versions: []apiextensions.CustomResourceDefinitionVersion{{ - Name: "v1beta1", Served: true, Storage: true, - }, { - Name: "v1alpha1", Served: true, - }}, - Conversion: &apiextensions.CustomResourceConversion{ - Strategy: apiextensions.NoneConverter, - }, - PreserveUnknownFields: pointer.BoolPtr(false), - }, - Status: apiextensions.CustomResourceDefinitionStatus{ - StoredVersions: []string{"v1beta1"}, - }, - }, - wantErr: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - - d := local.New() - - b, err := NewBackend(Driver(d)) - if err != nil { - t.Fatal(err) - } - - c, err := b.NewClient(Targets(tc.targets...)) if err != nil { t.Fatal(err) } - - t.Log(c.targets) - - got, err := c.CreateCRD(ctx, tc.template) - - if !errors.Is(err, tc.wantErr) { - t.Fatalf("got CreateTemplate() error = %v, want %v", - err, tc.wantErr) - } - - if diff := cmp.Diff(tc.want, got, - cmpopts.IgnoreFields(apiextensions.CustomResourceDefinitionSpec{}, "Validation")); diff != "" { - t.Error(diff) + if !reflect.DeepEqual(c.allowedDataFields, tt.Expected) { + t.Errorf("c.allowedDataFields = %v; want %v", c.allowedDataFields, tt.Expected) } }) } diff --git a/constraint/pkg/client/drivers/local/local_test.go b/constraint/pkg/client/drivers/local/local_test.go index c0b353b47..747469a7b 100644 --- a/constraint/pkg/client/drivers/local/local_test.go +++ b/constraint/pkg/client/drivers/local/local_test.go @@ -84,6 +84,7 @@ func (tt *compositeTestCase) run(t *testing.T) { for idx, a := range tt.Actions { t.Run(fmt.Sprintf("action idx %d", idx), func(t *testing.T) { ctx := context.Background() + switch a.Op { case addModule: for _, r := range a.Rules { @@ -140,7 +141,7 @@ func (tt *compositeTestCase) run(t *testing.T) { evalPath = a.EvalPath } - res, _, err := d.eval(context.Background(), evalPath, nil, &drivers.QueryCfg{}) + res, _, err := d.eval(ctx, evalPath, nil, &drivers.QueryCfg{}) if err != nil { t.Errorf("Eval error: %s", err) } @@ -397,6 +398,8 @@ func TestPutModule(t *testing.T) { } for _, tt := range tc { t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() + dr := New() d, ok := dr.(*driver) if !ok { @@ -404,7 +407,9 @@ func TestPutModule(t *testing.T) { } for _, r := range tt.Rules { - err := d.PutModule(context.Background(), r.Path, r.Content) + ctx := context.Background() + + err := d.PutModule(ctx, r.Path, r.Content) if (err == nil) && tt.ErrorExpected { t.Fatalf("err = nil; want non-nil") } @@ -412,7 +417,7 @@ func TestPutModule(t *testing.T) { t.Fatalf("err = \"%s\"; want nil", err) } } - res, _, err := d.eval(context.Background(), "data.hello.r[a]", nil, &drivers.QueryCfg{}) + res, _, err := d.eval(ctx, "data.hello.r[a]", nil, &drivers.QueryCfg{}) if err != nil { t.Errorf("Eval error: %s", err) } @@ -502,6 +507,8 @@ func TestPutData(t *testing.T) { } for _, tt := range tc { t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() + dr := New() d, ok := dr.(*driver) if !ok { @@ -510,14 +517,14 @@ func TestPutData(t *testing.T) { for _, data := range tt.Data { for k, v := range data { - err := d.PutData(context.Background(), k, v) + err := d.PutData(ctx, k, v) if (err == nil) && tt.ErrorExpected { t.Fatalf("err = nil; want non-nil") } if (err != nil) && !tt.ErrorExpected { t.Fatalf("err = \"%s\"; want nil", err) } - res, _, err := d.eval(context.Background(), makeDataPath(k), nil, &drivers.QueryCfg{}) + res, _, err := d.eval(ctx, makeDataPath(k), nil, &drivers.QueryCfg{}) if err != nil { t.Errorf("Eval error: %s", err) } @@ -578,6 +585,8 @@ func TestDeleteData(t *testing.T) { } for _, tt := range tc { t.Run(tt.Name, func(t *testing.T) { + ctx := context.Background() + dr := New() d, ok := dr.(*driver) if !ok { @@ -589,14 +598,14 @@ func TestDeleteData(t *testing.T) { for k, v := range data { switch a.Op { case addData: - err := d.PutData(context.Background(), k, v) + err := d.PutData(ctx, k, v) if (err == nil) && a.ErrorExpected { t.Fatalf("PUT err = nil; want non-nil") } if (err != nil) && !a.ErrorExpected { t.Fatalf("PUT err = \"%s\"; want nil", err) } - res, _, err := d.eval(context.Background(), makeDataPath(k), nil, &drivers.QueryCfg{}) + res, _, err := d.eval(ctx, makeDataPath(k), nil, &drivers.QueryCfg{}) if err != nil { t.Errorf("Eval error: %s", err) } @@ -607,7 +616,7 @@ func TestDeleteData(t *testing.T) { t.Errorf("%v != %v", v, res[0].Expressions[0].Value) } case deleteData: - b, err := d.DeleteData(context.Background(), k) + b, err := d.DeleteData(ctx, k) if (err == nil) && a.ErrorExpected { t.Fatalf("DELETE err = nil; want non-nil") } @@ -617,7 +626,7 @@ func TestDeleteData(t *testing.T) { if b != a.ExpectedBool { t.Fatalf("DeleteModule(\"%s\") = %t; want %t", k, b, a.ExpectedBool) } - res, _, err := d.eval(context.Background(), makeDataPath(k), nil, &drivers.QueryCfg{}) + res, _, err := d.eval(ctx, makeDataPath(k), nil, &drivers.QueryCfg{}) if err != nil { t.Errorf("Eval error: %s", err) } @@ -668,18 +677,20 @@ func TestQuery(t *testing.T) { } t.Run("Parse Response", func(t *testing.T) { + ctx := context.Background() + d := New() for i, v := range intResponses { - if err := d.PutData(context.Background(), fmt.Sprintf("/constraints/%d", i), v); err != nil { + if err := d.PutData(ctx, fmt.Sprintf("/constraints/%d", i), v); err != nil { t.Fatal(err) } } - if err := d.PutModule(context.Background(), "test", `package hooks violation[r] { r = data.constraints[_] }`); err != nil { + if err := d.PutModule(ctx, "test", `package hooks violation[r] { r = data.constraints[_] }`); err != nil { t.Fatal(err) } - res, err := d.Query(context.Background(), "hooks.violation", nil) + res, err := d.Query(ctx, "hooks.violation", nil) if err != nil { t.Fatal(err) } diff --git a/constraint/pkg/client/drivers/remote/remote_test.go b/constraint/pkg/client/drivers/remote/remote_test.go index 44a2e4f17..88913c65f 100644 --- a/constraint/pkg/client/drivers/remote/remote_test.go +++ b/constraint/pkg/client/drivers/remote/remote_test.go @@ -76,8 +76,10 @@ const response = ` func TestQuery(t *testing.T) { t.Run("Parse Response", func(t *testing.T) { + ctx := context.Background() + d := driver{opa: newTestClient(response)} - res, err := d.Query(context.Background(), "random", nil) + res, err := d.Query(ctx, "random", nil) if err != nil { t.Fatal(err) } diff --git a/constraint/pkg/client/e2e_test.go b/constraint/pkg/client/e2e_test.go index 804af7a96..d206bebb1 100644 --- a/constraint/pkg/client/e2e_test.go +++ b/constraint/pkg/client/e2e_test.go @@ -106,24 +106,25 @@ var denyAllCases = []struct { libs: []string{denyTemplateWithLibLib}, }} -func newTestClient() (*Client, error) { +func newTestClient(ctx context.Context) (*Client, error) { d := local.New() b, err := NewBackend(Driver(d)) if err != nil { return nil, err } - return b.NewClient(Targets(&handler{})) + return b.NewClient(ctx, Targets(&handler{})) } func TestE2EAddTemplate(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatal(err) @@ -135,12 +136,13 @@ func TestE2EAddTemplate(t *testing.T) { func TestE2EDenyAll(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -172,12 +174,13 @@ func TestE2EDenyAll(t *testing.T) { func TestE2EAudit(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -214,12 +217,13 @@ func TestE2EAudit(t *testing.T) { func TestE2EAuditX2(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -271,12 +275,13 @@ func TestE2EAutoreject(t *testing.T) { // Constraint. for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", denyTemplateRego)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -346,12 +351,13 @@ func TestE2EAutoreject(t *testing.T) { func TestE2ERemoveConstraint(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", denyTemplateRego)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -404,12 +410,13 @@ func TestE2ERemoveConstraint(t *testing.T) { func TestE2ERemoveTemplate(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() tmpl := newConstraintTemplate("Foo", denyTemplateRego) _, err = c.AddTemplate(ctx, tmpl) if err != nil { @@ -461,12 +468,13 @@ func TestE2ERemoveTemplate(t *testing.T) { func TestE2ETracingOff(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", denyTemplateRego)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -495,12 +503,13 @@ func TestE2ETracingOff(t *testing.T) { func TestE2ETracingOn(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -529,12 +538,13 @@ func TestE2ETracingOn(t *testing.T) { func TestE2EAuditTracingOn(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -571,12 +581,13 @@ func TestE2EAuditTracingOn(t *testing.T) { func TestE2EAuditTracingOff(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) @@ -613,12 +624,13 @@ func TestE2EAuditTracingOff(t *testing.T) { func TestE2EDryrunAll(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", `package foo violation[{"msg": "DRYRUN", "details": {}}] { "always" == "always" @@ -654,12 +666,13 @@ violation[{"msg": "DRYRUN", "details": {}}] { func TestE2EDenyByParameter(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - c, err := newTestClient() + ctx := context.Background() + + c, err := newTestClient(ctx) if err != nil { t.Fatal(err) } - ctx := context.Background() _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", `package foo violation[{"msg": "DENIED", "details": {}}] { input.parameters.name == input.review.Name diff --git a/constraint/pkg/client/regolib/rego_test.go b/constraint/pkg/client/regolib/rego_test.go index 808180a4e..a3910f2be 100644 --- a/constraint/pkg/client/regolib/rego_test.go +++ b/constraint/pkg/client/regolib/rego_test.go @@ -22,6 +22,8 @@ func TestRegoExecutes(t *testing.T) { } for _, tt := range tc { t.Run(tt.Template.Name(), func(t *testing.T) { + ctx := context.Background() + b := &bytes.Buffer{} if err := tt.Template.Execute(b, map[string]string{"Target": "foo"}); err != nil { t.Fatalf("Could not execute template: %s", tt.Template.Name()) @@ -31,7 +33,7 @@ func TestRegoExecutes(t *testing.T) { t.Fatalf("Could not parse rego for template %s: %s", tt.Template.Name(), err) } r := rego.New(rego.Query(fmt.Sprintf("data.hooks.foo.%s", strings.ToLower(tt.Template.Name()))), rego.Compiler(compiler)) - if _, err := r.Eval(context.Background()); err != nil { + if _, err := r.Eval(ctx); err != nil { t.Fatalf("Could not execute rego for template %s: %s", tt.Template.Name(), err) } }) From 8fb4e8bb4b039b76a2f7ae8f603abb26c4cc45a5 Mon Sep 17 00:00:00 2001 From: Will Beason Date: Thu, 2 Dec 2021 08:30:58 -0800 Subject: [PATCH 2/3] Remove unnecessary contexts There were a lot of Contexts we asked for in interfaces, but didn't actually use. Since this pollutes call sites and gives the false impression that these calls are cancellable/etc, we shouldn't have them. This commit removes these unnecessary contexts. Signed-off-by: Will Beason --- constraint/pkg/client/backend.go | 33 +- constraint/pkg/client/client.go | 311 ++++++++---------- .../client/client_addtemplate_bench_test.go | 6 +- constraint/pkg/client/drivers/interface.go | 10 +- constraint/pkg/client/drivers/local/local.go | 23 +- .../drivers/local/local_benchmark_test.go | 4 +- .../pkg/client/drivers/local/local_test.go | 14 +- .../client/drivers/local/local_unit_test.go | 32 +- .../pkg/client/drivers/remote/remote.go | 18 +- constraint/pkg/client/e2e_test.go | 58 ++-- 10 files changed, 229 insertions(+), 280 deletions(-) diff --git a/constraint/pkg/client/backend.go b/constraint/pkg/client/backend.go index 4a58d165c..ebe64b69f 100644 --- a/constraint/pkg/client/backend.go +++ b/constraint/pkg/client/backend.go @@ -1,7 +1,7 @@ package client import ( - "context" + "errors" "fmt" "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers" @@ -23,10 +23,8 @@ func Driver(d drivers.Driver) BackendOpt { } } -// NewBackend creates a new backend. A backend could be a connection to a remote -// server or a new local OPA instance. -// -// A BackendOpt setting driver, such as Driver() must be passed. +// NewBackend creates a new backend. A backend could be a connection to a remote server or +// a new local OPA instance. func NewBackend(opts ...BackendOpt) (*Backend, error) { helper, err := newCRDHelper() if err != nil { @@ -38,17 +36,16 @@ func NewBackend(opts ...BackendOpt) (*Backend, error) { } if b.driver == nil { - return nil, fmt.Errorf("%w: no driver supplied", ErrCreatingBackend) + return nil, errors.New("no driver supplied to the backend") } return b, nil } // NewClient creates a new client for the supplied backend. -func (b *Backend) NewClient(ctx context.Context, opts ...Opt) (*Client, error) { +func (b *Backend) NewClient(opts ...Opt) (*Client, error) { if b.hasClient { - return nil, fmt.Errorf("%w: only one client per backend is allowed", - ErrCreatingClient) + return nil, errors.New("currently only one client per backend is supported") } var fields []string @@ -63,32 +60,32 @@ func (b *Backend) NewClient(ctx context.Context, opts ...Opt) (*Client, error) { allowedDataFields: fields, } + var errs Errors for _, opt := range opts { if err := opt(c); err != nil { - return nil, err + errs = append(errs, err) } } + if len(errs) > 0 { + return nil, errs + } for _, field := range c.allowedDataFields { if !validDataFields[field] { - return nil, fmt.Errorf("%w: invalid data field %q; allowed fields are: %v", - ErrCreatingClient, field, validDataFields) + return nil, fmt.Errorf("invalid data field %s", field) } } if len(c.targets) == 0 { - return nil, fmt.Errorf("%w: must specify at least one target with client.Targets", - ErrCreatingClient) + return nil, errors.New("no targets registered: please register a target via client.Targets()") } - if err := b.driver.Init(ctx); err != nil { + if err := b.driver.Init(); err != nil { return nil, err } - if err := c.init(ctx); err != nil { + if err := c.init(); err != nil { return nil, err } - - b.hasClient = true return c, nil } diff --git a/constraint/pkg/client/client.go b/constraint/pkg/client/client.go index 4e4fd2750..d3aef1a47 100644 --- a/constraint/pkg/client/client.go +++ b/constraint/pkg/client/client.go @@ -6,12 +6,11 @@ import ( "errors" "fmt" "path" - "sort" + "regexp" "strings" "sync" "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers" - "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" "github.com/open-policy-agent/frameworks/constraint/pkg/client/regolib" constraintlib "github.com/open-policy-agent/frameworks/constraint/pkg/core/constraints" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" @@ -25,6 +24,46 @@ import ( const constraintGroup = "constraints.gatekeeper.sh" +type Opt func(*Client) error + +// Client options + +var targetNameRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9.]*$`) + +func Targets(ts ...TargetHandler) Opt { + return func(c *Client) error { + var errs Errors + handlers := make(map[string]TargetHandler, len(ts)) + for _, t := range ts { + name := t.GetName() + switch { + case name == "": + errs = append(errs, errors.New("invalid target: a target is returning an empty string for GetName()")) + case !targetNameRegex.MatchString(name): + errs = append(errs, fmt.Errorf("target name %q is not of the form %q", name, targetNameRegex.String())) + default: + handlers[name] = t + } + } + c.targets = handlers + + if len(errs) > 0 { + return errs + } + return nil + } +} + +// AllowedDataFields sets the fields under `data` that Rego in ConstraintTemplates +// can access. If unset, all fields can be accessed. Only fields recognized by +// the system can be enabled. +func AllowedDataFields(fields ...string) Opt { + return func(c *Client) error { + c.allowedDataFields = fields + return nil + } +} + type templateEntry struct { template *templates.ConstraintTemplate CRD *apiextensions.CustomResourceDefinition @@ -32,14 +71,11 @@ type templateEntry struct { } type Client struct { - backend *Backend - targets map[string]TargetHandler - - // mtx guards access to both templates and constraints. - mtx sync.RWMutex - templates map[templateKey]*templateEntry - constraints map[schema.GroupKind]map[string]*unstructured.Unstructured - + backend *Backend + targets map[string]TargetHandler + constraintsMux sync.RWMutex + templates map[templateKey]*templateEntry + constraints map[schema.GroupKind]map[string]*unstructured.Unstructured allowedDataFields []string } @@ -76,7 +112,7 @@ func (c *Client) AddData(ctx context.Context, data interface{}) (*types.Response if len(errMap) == 0 { return resp, nil } - return resp, &errMap + return resp, errMap } // RemoveData removes data from OPA for every target that can handle the data. @@ -103,7 +139,7 @@ func (c *Client) RemoveData(ctx context.Context, data interface{}) (*types.Respo if len(errMap) == 0 { return resp, nil } - return resp, &errMap + return resp, errMap } // createTemplatePath returns the package path for a given template: templates... @@ -122,14 +158,14 @@ func (c *Client) validateTargets(templ *templates.ConstraintTemplate) (*template return nil, nil, err } + if len(templ.Spec.Targets) != 1 { + return nil, nil, fmt.Errorf("expected exactly 1 item in targets, got %v", templ.Spec.Targets) + } + targetSpec := &templ.Spec.Targets[0] targetHandler, found := c.targets[targetSpec.Target] - if !found { - knownTargets := c.knownTargets() - - return nil, nil, fmt.Errorf("%w: target %s not recognized, known targets %v", - ErrInvalidConstraintTemplate, targetSpec.Target, knownTargets) + return nil, nil, fmt.Errorf("target %s not recognized", targetSpec.Target) } return targetSpec, targetHandler, nil @@ -161,10 +197,9 @@ func (a *rawCTArtifacts) Key() templateKey { // createRawTemplateArtifacts creates the "free" artifacts for a template, avoiding more // complex tasks like rewriting Rego. Provides minimal validation. func (c *Client) createRawTemplateArtifacts(templ *templates.ConstraintTemplate) (*rawCTArtifacts, error) { - if templ.GetName() == "" { - return nil, fmt.Errorf("%w: missing name", ErrInvalidConstraintTemplate) + if templ.ObjectMeta.Name == "" { + return nil, errors.New("invalid Template: missing name") } - return &rawCTArtifacts{template: templ}, nil } @@ -210,16 +245,8 @@ func (c *Client) createBasicTemplateArtifacts(templ *templates.ConstraintTemplat if err != nil { return nil, err } - - kind := templ.Spec.CRD.Spec.Names.Kind - if kind == "" { - return nil, fmt.Errorf("%w: ConstraintTemplate %q does not specify CRD Kind", - ErrInvalidConstraintTemplate, templ.GetName()) - } - - if !strings.EqualFold(templ.ObjectMeta.Name, kind) { - return nil, fmt.Errorf("%w: the ConstraintTemplate's name %q is not equal to the lowercase of CRD's Kind: %q", - ErrInvalidConstraintTemplate, templ.ObjectMeta.Name, strings.ToLower(kind)) + if !strings.EqualFold(templ.ObjectMeta.Name, templ.Spec.CRD.Spec.Names.Kind) { + return nil, fmt.Errorf("the ConstraintTemplate's name %q is not equal to the lowercase of CRD's Kind: %q", templ.ObjectMeta.Name, strings.ToLower(templ.Spec.CRD.Spec.Names.Kind)) } targetSpec, targetHandler, err := c.validateTargets(templ) @@ -231,14 +258,12 @@ func (c *Client) createBasicTemplateArtifacts(templ *templates.ConstraintTemplat if err != nil { return nil, err } - crd, err := c.backend.crd.createCRD(templ, sch) if err != nil { return nil, err } - if err = c.backend.crd.validateCRD(crd); err != nil { - return nil, fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) + return nil, err } entryPointPath := createTemplatePath(targetHandler.GetName(), templ.Spec.CRD.Spec.Names.Kind) @@ -267,63 +292,55 @@ func (c *Client) createTemplateArtifacts(templ *templates.ConstraintTemplate) (* } libPrefix := templateLibPrefix(artifacts.targetHandler.GetName(), artifacts.crd.Spec.Names.Kind) - rr, err := regorewriter.New( regorewriter.NewPackagePrefixer(libPrefix), []string{"data.lib"}, externs) if err != nil { - return nil, fmt.Errorf("creating rego rewriter: %w", err) + return nil, err } entryPoint, err := parseModule(artifacts.namePrefix, artifacts.targetSpec.Rego) if err != nil { - return nil, fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) + return nil, err } - if entryPoint == nil { - return nil, fmt.Errorf("%w: failed to parse module for unknown reason", - ErrInvalidConstraintTemplate) + return nil, fmt.Errorf("failed to parse module for unknown reason") } - if err = rewriteModulePackage(artifacts.namePrefix, entryPoint); err != nil { + if err := rewriteModulePackage(artifacts.namePrefix, entryPoint); err != nil { return nil, err } req := map[string]struct{}{"violation": {}} - if err = requireModuleRules(entryPoint, req); err != nil { - return nil, fmt.Errorf("%w: invalid rego: %v", - ErrInvalidConstraintTemplate, err) + if err := requireRulesModule(entryPoint, req); err != nil { + return nil, fmt.Errorf("invalid rego: %s", err) } rr.AddEntryPointModule(artifacts.namePrefix, entryPoint) for idx, libSrc := range artifacts.targetSpec.Libs { libPath := fmt.Sprintf(`%s["lib_%d"]`, libPrefix, idx) - if err = rr.AddLib(libPath, libSrc); err != nil { - return nil, fmt.Errorf("%w: %v", - ErrInvalidConstraintTemplate, err) + if err := rr.AddLib(libPath, libSrc); err != nil { + return nil, err } } sources, err := rr.Rewrite() if err != nil { - return nil, fmt.Errorf("%w: %v", - ErrInvalidConstraintTemplate, err) + return nil, err } var mods []string - err = sources.ForEachModule(func(m *regorewriter.Module) error { - content, err2 := m.Content() - if err2 != nil { - return err2 + if err := sources.ForEachModule(func(m *regorewriter.Module) error { + content, err := m.Content() + if err != nil { + return err } mods = append(mods, string(content)) return nil - }) - if err != nil { - return nil, fmt.Errorf("%w: %v", - ErrInvalidConstraintTemplate, err) + }); err != nil { + return nil, err } return &ctArtifacts{ @@ -333,12 +350,7 @@ func (c *Client) createTemplateArtifacts(templ *templates.ConstraintTemplate) (* } // CreateCRD creates a CRD from template. -func (c *Client) CreateCRD(_ context.Context, templ *templates.ConstraintTemplate) (*apiextensions.CustomResourceDefinition, error) { - if templ == nil { - return nil, fmt.Errorf("%w: got nil ConstraintTemplate", - ErrInvalidConstraintTemplate) - } - +func (c *Client) CreateCRD(templ *templates.ConstraintTemplate) (*apiextensions.CustomResourceDefinition, error) { artifacts, err := c.createTemplateArtifacts(templ) if err != nil { return nil, err @@ -349,7 +361,7 @@ func (c *Client) CreateCRD(_ context.Context, templ *templates.ConstraintTemplat // AddTemplate adds the template source code to OPA and registers the CRD with the client for // schema validation on calls to AddConstraint. On error, the responses return value // will still be populated so that partial results can be analyzed. -func (c *Client) AddTemplate(ctx context.Context, templ *templates.ConstraintTemplate) (*types.Responses, error) { +func (c *Client) AddTemplate(templ *templates.ConstraintTemplate) (*types.Responses, error) { resp := types.NewResponses() basicArtifacts, err := c.createBasicTemplateArtifacts(templ) @@ -358,7 +370,7 @@ func (c *Client) AddTemplate(ctx context.Context, templ *templates.ConstraintTem } // return immediately if no change - if cached, err := c.GetTemplate(ctx, templ); err == nil && cached.SemanticEqual(templ) { + if cached, err := c.GetTemplate(templ); err == nil && cached.SemanticEqual(templ) { resp.Handled[basicArtifacts.targetHandler.GetName()] = true return resp, nil } @@ -368,11 +380,11 @@ func (c *Client) AddTemplate(ctx context.Context, templ *templates.ConstraintTem return resp, err } - c.mtx.Lock() - defer c.mtx.Unlock() + c.constraintsMux.Lock() + defer c.constraintsMux.Unlock() - if err = c.backend.driver.PutModules(ctx, artifacts.namePrefix, artifacts.modules); err != nil { - return resp, fmt.Errorf("%w: %v", local.ErrCompile, err) + if err := c.backend.driver.PutModules(artifacts.namePrefix, artifacts.modules); err != nil { + return resp, err } cpy := templ.DeepCopy() @@ -401,12 +413,12 @@ func (c *Client) RemoveTemplate(ctx context.Context, templ *templates.Constraint return resp, err } - c.mtx.Lock() - defer c.mtx.Unlock() + c.constraintsMux.Lock() + defer c.constraintsMux.Unlock() - template, err := c.getTemplateNoLock(rawArtifacts.Key()) + template, err := c.getTemplateNoLock(rawArtifacts) if err != nil { - if errors.Is(err, ErrMissingConstraintTemplate) { + if IsMissingTemplateError(err) { return resp, nil } return resp, err @@ -417,7 +429,7 @@ func (c *Client) RemoveTemplate(ctx context.Context, templ *templates.Constraint return resp, err } - if _, err := c.backend.driver.DeleteModules(ctx, artifacts.namePrefix); err != nil { + if _, err := c.backend.driver.DeleteModules(artifacts.namePrefix); err != nil { return resp, err } @@ -438,24 +450,22 @@ func (c *Client) RemoveTemplate(ctx context.Context, templ *templates.Constraint } // GetTemplate gets the currently recognized template. -func (c *Client) GetTemplate(_ context.Context, templ *templates.ConstraintTemplate) (*templates.ConstraintTemplate, error) { +func (c *Client) GetTemplate(templ *templates.ConstraintTemplate) (*templates.ConstraintTemplate, error) { artifacts, err := c.createRawTemplateArtifacts(templ) if err != nil { return nil, err } - c.mtx.RLock() - defer c.mtx.RUnlock() - return c.getTemplateNoLock(artifacts.Key()) + c.constraintsMux.Lock() + defer c.constraintsMux.Unlock() + return c.getTemplateNoLock(artifacts) } -func (c *Client) getTemplateNoLock(key templateKey) (*templates.ConstraintTemplate, error) { - t, ok := c.templates[key] +func (c *Client) getTemplateNoLock(artifacts keyableArtifact) (*templates.ConstraintTemplate, error) { + t, ok := c.templates[artifacts.Key()] if !ok { - return nil, fmt.Errorf("%w: template for %q not found", - ErrMissingConstraintTemplate, key) + return nil, NewMissingTemplateError(string(artifacts.Key())) } - ret := t.template.DeepCopy() return ret, nil } @@ -464,20 +474,15 @@ func (c *Client) getTemplateNoLock(key templateKey) (*templates.ConstraintTempla // for each target: cluster.... func createConstraintSubPath(constraint *unstructured.Unstructured) (string, error) { if constraint.GetName() == "" { - return "", fmt.Errorf("%w: missing name", ErrInvalidConstraint) + return "", errors.New("invalid Constraint: missing name") } - gvk := constraint.GroupVersionKind() if gvk.Group == "" { - return "", fmt.Errorf("%w: empty group for constrant %q", - ErrInvalidConstraint, constraint.GetName()) + return "", fmt.Errorf("empty group for the constrant named %s", constraint.GetName()) } - if gvk.Kind == "" { - return "", fmt.Errorf("%w: empty kind for constraint %q", - ErrInvalidConstraint, constraint.GetName()) + return "", fmt.Errorf("empty kind for the constraint named %s", constraint.GetName()) } - return path.Join(createConstraintGKSubPath(gvk.GroupKind()), constraint.GetName()), nil } @@ -510,31 +515,19 @@ func constraintPathMerge(target, subpath string) string { func (c *Client) getTemplateEntry(constraint *unstructured.Unstructured, lock bool) (*templateEntry, error) { kind := constraint.GetKind() if kind == "" { - return nil, fmt.Errorf("%w: kind missing from Constraint %q", - ErrInvalidConstraint, constraint.GetName()) + return nil, fmt.Errorf("kind missing from Constraint %q", constraint.GetName()) } - if constraint.GroupVersionKind().Group != constraintGroup { - return nil, fmt.Errorf("%w: wrong API Group for Constraint %q, need %q", - ErrInvalidConstraint, constraint.GetName(), constraintGroup) + return nil, fmt.Errorf("wrong API Group for Constraint %q", constraint.GetName()) } - if lock { - c.mtx.RLock() - defer c.mtx.RUnlock() + c.constraintsMux.RLock() + defer c.constraintsMux.RUnlock() } - entry, ok := c.templates[templateKeyFromConstraint(constraint)] if !ok { - var known []string - for k := range c.templates { - known = append(known, string(k)) - } - - return nil, fmt.Errorf("%w: Constraint kind %q is not recognized, known kinds %v", - ErrMissingConstraintTemplate, kind, known) + return nil, NewUnrecognizedConstraintError(kind) } - return entry, nil } @@ -542,30 +535,25 @@ func (c *Client) getTemplateEntry(constraint *unstructured.Unstructured, lock bo // On error, the responses return value will still be populated so that // partial results can be analyzed. func (c *Client) AddConstraint(ctx context.Context, constraint *unstructured.Unstructured) (*types.Responses, error) { - c.mtx.Lock() - defer c.mtx.Unlock() - + c.constraintsMux.RLock() + defer c.constraintsMux.RUnlock() resp := types.NewResponses() errMap := make(ErrorMap) entry, err := c.getTemplateEntry(constraint, false) if err != nil { return resp, err } - subPath, err := createConstraintSubPath(constraint) if err != nil { - return resp, fmt.Errorf("creating Constraint subpath: %w", err) + return resp, err } - // return immediately if no change - cached, err := c.getConstraintNoLock(constraint) - if err == nil && constraintlib.SemanticEqual(cached, constraint) { + if cached, err := c.getConstraintNoLock(constraint); err == nil && constraintlib.SemanticEqual(cached, constraint) { for _, target := range entry.Targets { resp.Handled[target] = true } return resp, nil } - if err := c.validateConstraint(constraint, false); err != nil { return resp, err } @@ -583,21 +571,18 @@ func (c *Client) AddConstraint(ctx context.Context, constraint *unstructured.Uns } resp.Handled[target] = true } - if len(errMap) == 0 { c.constraints[constraint.GroupVersionKind().GroupKind()][subPath] = constraint.DeepCopy() return resp, nil } - - return resp, &errMap + return resp, errMap } // RemoveConstraint removes a constraint from OPA. On error, the responses // return value will still be populated so that partial results can be analyzed. func (c *Client) RemoveConstraint(ctx context.Context, constraint *unstructured.Unstructured) (*types.Responses, error) { - c.mtx.Lock() - defer c.mtx.Unlock() - + c.constraintsMux.RLock() + defer c.constraintsMux.RUnlock() return c.removeConstraintNoLock(ctx, constraint) } @@ -631,7 +616,7 @@ func (c *Client) removeConstraintNoLock(ctx context.Context, constraint *unstruc delete(c.constraints[constraint.GroupVersionKind().GroupKind()], subPath) return resp, nil } - return resp, &errMap + return resp, errMap } // getConstraintNoLock gets the currently recognized constraint without the lock. @@ -641,20 +626,17 @@ func (c *Client) getConstraintNoLock(constraint *unstructured.Unstructured) (*un return nil, err } - gk := constraint.GroupVersionKind().GroupKind() - cstr, ok := c.constraints[gk][subPath] + cstr, ok := c.constraints[constraint.GroupVersionKind().GroupKind()][subPath] if !ok { - return nil, fmt.Errorf("%w %v %q", - ErrMissingConstraint, gk, constraint.GetName()) + return nil, NewMissingConstraintError(subPath) } return cstr.DeepCopy(), nil } // GetConstraint gets the currently recognized constraint. -func (c *Client) GetConstraint(_ context.Context, constraint *unstructured.Unstructured) (*unstructured.Unstructured, error) { - c.mtx.RLock() - defer c.mtx.RUnlock() - +func (c *Client) GetConstraint(constraint *unstructured.Unstructured) (*unstructured.Unstructured, error) { + c.constraintsMux.Lock() + defer c.constraintsMux.Unlock() return c.getConstraintNoLock(constraint) } @@ -679,12 +661,12 @@ func (c *Client) validateConstraint(constraint *unstructured.Unstructured, lock // ValidateConstraint returns an error if the constraint is not recognized or does not conform to // the registered CRD for that constraint. -func (c *Client) ValidateConstraint(_ context.Context, constraint *unstructured.Unstructured) error { +func (c *Client) ValidateConstraint(constraint *unstructured.Unstructured) error { return c.validateConstraint(constraint, true) } // init initializes the OPA backend for the client. -func (c *Client) init(ctx context.Context) error { +func (c *Client) init() error { for _, t := range c.targets { hooks := fmt.Sprintf(`hooks["%s"]`, t.GetName()) templMap := map[string]string{"Target": t.GetName()} @@ -694,18 +676,16 @@ func (c *Client) init(ctx context.Context) error { return err } - builtinPath := fmt.Sprintf("%s.hooks_builtin", hooks) - err := c.backend.driver.PutModule(context.Background(), builtinPath, libBuiltin.String()) + moduleName := fmt.Sprintf("%s.hooks_builtin", hooks) + err := c.backend.driver.PutModule(moduleName, libBuiltin.String()) if err != nil { return err } libTempl := t.Library() if libTempl == nil { - return fmt.Errorf("%w: target %q has no Rego library template", - ErrCreatingClient, t.GetName()) + return fmt.Errorf("target %q has no Rego library template", t.GetName()) } - libBuf := &bytes.Buffer{} if err := libTempl.Execute(libBuf, map[string]string{ "ConstraintsRoot": fmt.Sprintf(`data.constraints["%s"].cluster["%s"]`, t.GetName(), constraintGroup), @@ -713,41 +693,30 @@ func (c *Client) init(ctx context.Context) error { }); err != nil { return err } - lib := libBuf.String() req := map[string]struct{}{ "autoreject_review": {}, "matching_reviews_and_constraints": {}, "matching_constraints": {}, } - modulePath := fmt.Sprintf("%s.library", hooks) libModule, err := parseModule(modulePath, lib) if err != nil { return fmt.Errorf("failed to parse module: %w", err) } - - err = requireModuleRules(libModule, req) - if err != nil { - return fmt.Errorf("problem with the below Rego for %q target:\n\n====%s\n====\n%w", - t.GetName(), lib, err) + if err := requireRulesModule(libModule, req); err != nil { + return fmt.Errorf("problem with the below Rego for %q target:\n\n====%s\n====\n%s", t.GetName(), lib, err) } - err = rewriteModulePackage(modulePath, libModule) if err != nil { return err } - src, err := format.Ast(libModule) if err != nil { - return fmt.Errorf("%w: could not re-format Rego source: %v", - ErrCreatingClient, err) + return fmt.Errorf("could not re-format Rego source: %v", err) } - - err = c.backend.driver.PutModule(context.TODO(), modulePath, string(src)) - if err != nil { - return fmt.Errorf("%w: error %s from compiled source:\n%s", - ErrCreatingClient, err, src) + if err := c.backend.driver.PutModule(modulePath, string(src)); err != nil { + return fmt.Errorf("error %s from compiled source:\n%s", err, src) } } @@ -756,9 +725,8 @@ func (c *Client) init(ctx context.Context) error { // Reset the state of OPA. func (c *Client) Reset(ctx context.Context) error { - c.mtx.Lock() - defer c.mtx.Unlock() - + c.constraintsMux.Lock() + defer c.constraintsMux.Unlock() for name := range c.targets { if _, err := c.backend.driver.DeleteData(ctx, fmt.Sprintf("/external/%s", name)); err != nil { return err @@ -769,7 +737,7 @@ func (c *Client) Reset(ctx context.Context) error { } for name, v := range c.templates { for _, t := range v.Targets { - if _, err := c.backend.driver.DeleteModule(ctx, fmt.Sprintf(`templates["%s"]["%s"]`, t, name)); err != nil { + if _, err := c.backend.driver.DeleteModule(fmt.Sprintf(`templates["%s"]["%s"]`, t, name)); err != nil { return err } } @@ -779,6 +747,18 @@ func (c *Client) Reset(ctx context.Context) error { return nil } +type queryCfg struct { + enableTracing bool +} + +type QueryOpt func(*queryCfg) + +func Tracing(enabled bool) QueryOpt { + return func(cfg *queryCfg) { + cfg.enableTracing = enabled + } +} + // Review makes sure the provided object satisfies all stored constraints. // On error, the responses return value will still be populated so that // partial results can be analyzed. @@ -818,7 +798,7 @@ TargetLoop: if len(errMap) == 0 { return responses, nil } - return responses, &errMap + return responses, errMap } // Audit makes sure the cached state of the system satisfies all stored constraints. @@ -851,21 +831,10 @@ TargetLoop: if len(errMap) == 0 { return responses, nil } - return responses, &errMap + return responses, errMap } // Dump dumps the state of OPA to aid in debugging. func (c *Client) Dump(ctx context.Context) (string, error) { return c.backend.driver.Dump(ctx) } - -// knownTargets returns a sorted list of currently-known target names. -func (c *Client) knownTargets() []string { - var knownTargets []string - for known := range c.targets { - knownTargets = append(knownTargets, known) - } - sort.Strings(knownTargets) - - return knownTargets -} diff --git a/constraint/pkg/client/client_addtemplate_bench_test.go b/constraint/pkg/client/client_addtemplate_bench_test.go index ad7f7028d..c7edaf43a 100644 --- a/constraint/pkg/client/client_addtemplate_bench_test.go +++ b/constraint/pkg/client/client_addtemplate_bench_test.go @@ -1,7 +1,6 @@ package client import ( - "context" "fmt" "testing" @@ -82,7 +81,6 @@ func BenchmarkClient_AddTemplate(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() - ctx := context.Background() targets := Targets(&handler{}) d := local.New() @@ -92,7 +90,7 @@ func BenchmarkClient_AddTemplate(b *testing.B) { b.Fatal(err) } - c, err := backend.NewClient(ctx, targets) + c, err := backend.NewClient(targets) if err != nil { b.Fatal(err) } @@ -100,7 +98,7 @@ func BenchmarkClient_AddTemplate(b *testing.B) { b.StartTimer() for _, ct := range cts { - _, _ = c.AddTemplate(ctx, ct) + _, _ = c.AddTemplate(ct) } } }) diff --git a/constraint/pkg/client/drivers/interface.go b/constraint/pkg/client/drivers/interface.go index 4d8918ce3..fdd06397a 100644 --- a/constraint/pkg/client/drivers/interface.go +++ b/constraint/pkg/client/drivers/interface.go @@ -19,16 +19,16 @@ func Tracing(enabled bool) QueryOpt { } type Driver interface { - Init(ctx context.Context) error + Init() error - PutModule(ctx context.Context, name string, src string) error + PutModule(name string, src string) error // PutModules upserts a number of modules under a given prefix. - PutModules(ctx context.Context, namePrefix string, srcs []string) error - DeleteModule(ctx context.Context, name string) (bool, error) + PutModules(namePrefix string, srcs []string) error + DeleteModule(name string) (bool, error) // DeleteModules deletes all modules under a given prefix and returns the // count of modules deleted. Deletion of non-existing prefix will // result in 0, nil being returned. - DeleteModules(ctx context.Context, namePrefix string) (int, error) + DeleteModules(namePrefix string) (int, error) PutData(ctx context.Context, path string, data interface{}) error DeleteData(ctx context.Context, path string) (bool, error) diff --git a/constraint/pkg/client/drivers/local/local.go b/constraint/pkg/client/drivers/local/local.go index 31e1db283..4efeabae4 100644 --- a/constraint/pkg/client/drivers/local/local.go +++ b/constraint/pkg/client/drivers/local/local.go @@ -70,7 +70,7 @@ type driver struct { providerCache *externaldata.ProviderCache } -func (d *driver) Init(ctx context.Context) error { +func (d *driver) Init() error { if d.providerCache != nil { rego.RegisterBuiltin1( ®o.Function{ @@ -168,7 +168,7 @@ func toModuleSetName(prefix string, idx int) string { return fmt.Sprintf("%s%d", toModuleSetPrefix(prefix), idx) } -func (d *driver) PutModule(ctx context.Context, name string, src string) error { +func (d *driver) PutModule(name string, src string) error { if err := d.checkModuleName(name); err != nil { return err } @@ -181,12 +181,12 @@ func (d *driver) PutModule(ctx context.Context, name string, src string) error { d.modulesMux.Lock() defer d.modulesMux.Unlock() - _, err := d.alterModules(ctx, insert, nil) + _, err := d.alterModules(insert, nil) return err } // PutModules implements drivers.Driver. -func (d *driver) PutModules(ctx context.Context, namePrefix string, srcs []string) error { +func (d *driver) PutModules(namePrefix string, srcs []string) error { if err := d.checkModuleSetName(namePrefix); err != nil { return err } @@ -210,13 +210,13 @@ func (d *driver) PutModules(ctx context.Context, namePrefix string, srcs []strin } } - _, err := d.alterModules(ctx, insert, remove) + _, err := d.alterModules(insert, remove) return err } // DeleteModule deletes a rule from OPA. Returns true if a rule was found and deleted, false // if a rule was not found, and any errors. -func (d *driver) DeleteModule(ctx context.Context, name string) (bool, error) { +func (d *driver) DeleteModule(name string) (bool, error) { if err := d.checkModuleName(name); err != nil { return false, err } @@ -228,7 +228,7 @@ func (d *driver) DeleteModule(ctx context.Context, name string) (bool, error) { return false, nil } - count, err := d.alterModules(ctx, nil, []string{name}) + count, err := d.alterModules(nil, []string{name}) return count == 1, err } @@ -236,7 +236,10 @@ func (d *driver) DeleteModule(ctx context.Context, name string) (bool, error) { // alterModules alters the modules in the driver by inserting and removing // the provided modules then returns the count of modules removed. // alterModules expects that the caller is holding the modulesMux lock. -func (d *driver) alterModules(ctx context.Context, insert insertParam, remove []string) (int, error) { +func (d *driver) alterModules(insert insertParam, remove []string) (int, error) { + // TODO(davis-haba): Remove this Context once it is no longer necessary. + ctx := context.TODO() + updatedModules := copyModules(d.modules) for _, name := range remove { delete(updatedModules, name) @@ -284,7 +287,7 @@ func (d *driver) alterModules(ctx context.Context, insert insertParam, remove [] } // DeleteModules implements drivers.Driver. -func (d *driver) DeleteModules(ctx context.Context, namePrefix string) (int, error) { +func (d *driver) DeleteModules(namePrefix string) (int, error) { if err := d.checkModuleSetName(namePrefix); err != nil { return 0, err } @@ -292,7 +295,7 @@ func (d *driver) DeleteModules(ctx context.Context, namePrefix string) (int, err d.modulesMux.Lock() defer d.modulesMux.Unlock() - return d.alterModules(ctx, nil, d.listModuleSet(namePrefix)) + return d.alterModules(nil, d.listModuleSet(namePrefix)) } // listModuleSet returns the list of names corresponding to a given module diff --git a/constraint/pkg/client/drivers/local/local_benchmark_test.go b/constraint/pkg/client/drivers/local/local_benchmark_test.go index 94fa13b35..2154ec13b 100644 --- a/constraint/pkg/client/drivers/local/local_benchmark_test.go +++ b/constraint/pkg/client/drivers/local/local_benchmark_test.go @@ -1,7 +1,6 @@ package local import ( - "context" "fmt" "testing" ) @@ -11,13 +10,12 @@ func BenchmarkDriver_PutModule(b *testing.B) { b.Run(fmt.Sprintf("%d templates", n), func(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() - ctx := context.Background() d := New() b.StartTimer() for j := 0; j < n; j++ { name := fmt.Sprintf("foo-%d", j) - err := d.PutModule(ctx, name, Module) + err := d.PutModule(name, Module) if err != nil { b.Fatal(err) } diff --git a/constraint/pkg/client/drivers/local/local_test.go b/constraint/pkg/client/drivers/local/local_test.go index 747469a7b..6e707fd68 100644 --- a/constraint/pkg/client/drivers/local/local_test.go +++ b/constraint/pkg/client/drivers/local/local_test.go @@ -88,7 +88,7 @@ func (tt *compositeTestCase) run(t *testing.T) { switch a.Op { case addModule: for _, r := range a.Rules { - err := d.PutModule(ctx, r.Path, r.Content) + err := d.PutModule(r.Path, r.Content) if (err == nil) && a.ErrorExpected { t.Fatalf("PUT err = nil; want non-nil") } @@ -99,7 +99,7 @@ func (tt *compositeTestCase) run(t *testing.T) { case deleteModule: for _, r := range a.Rules { - b, err := d.DeleteModule(ctx, r.Path) + b, err := d.DeleteModule(r.Path) if (err == nil) && a.ErrorExpected { t.Fatalf("DELETE err = nil; want non-nil") } @@ -112,7 +112,7 @@ func (tt *compositeTestCase) run(t *testing.T) { } case putModules: - err := d.PutModules(ctx, a.RuleNamePrefix, a.Rules.srcs()) + err := d.PutModules(a.RuleNamePrefix, a.Rules.srcs()) if (err == nil) && a.ErrorExpected { t.Fatalf("PutModules err = nil; want non-nil") } @@ -121,7 +121,7 @@ func (tt *compositeTestCase) run(t *testing.T) { } case deleteModules: - count, err := d.DeleteModules(ctx, a.RuleNamePrefix) + count, err := d.DeleteModules(a.RuleNamePrefix) if (err == nil) && a.ErrorExpected { t.Fatalf("DeleteModules err = nil; want non-nil") } @@ -407,9 +407,7 @@ func TestPutModule(t *testing.T) { } for _, r := range tt.Rules { - ctx := context.Background() - - err := d.PutModule(ctx, r.Path, r.Content) + err := d.PutModule(r.Path, r.Content) if (err == nil) && tt.ErrorExpected { t.Fatalf("err = nil; want non-nil") } @@ -687,7 +685,7 @@ func TestQuery(t *testing.T) { } } - if err := d.PutModule(ctx, "test", `package hooks violation[r] { r = data.constraints[_] }`); err != nil { + if err := d.PutModule("test", `package hooks violation[r] { r = data.constraints[_] }`); err != nil { t.Fatal(err) } res, err := d.Query(ctx, "hooks.violation", nil) diff --git a/constraint/pkg/client/drivers/local/local_unit_test.go b/constraint/pkg/client/drivers/local/local_unit_test.go index d3c2f7ba5..6ea5ab3c8 100644 --- a/constraint/pkg/client/drivers/local/local_unit_test.go +++ b/constraint/pkg/client/drivers/local/local_unit_test.go @@ -114,8 +114,6 @@ func TestDriver_PutModule(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - d := New(Modules(tc.beforeModules)) dr, ok := d.(*driver) @@ -124,7 +122,7 @@ func TestDriver_PutModule(t *testing.T) { d, &driver{}) } - gotErr := d.PutModule(ctx, tc.moduleName, tc.moduleSrc) + gotErr := d.PutModule(tc.moduleName, tc.moduleSrc) if !errors.Is(gotErr, tc.wantErr) { t.Fatalf("got PutModule() error = %v, want %v", gotErr, tc.wantErr) } @@ -243,12 +241,10 @@ func TestDriver_PutModules(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - d := New() for prefix, src := range tc.beforeModules { - err := d.PutModules(ctx, prefix, src) + err := d.PutModules(prefix, src) if err != nil { t.Fatal(err) } @@ -260,7 +256,7 @@ func TestDriver_PutModules(t *testing.T) { d, &driver{}) } - gotErr := d.PutModules(ctx, tc.prefix, tc.srcs) + gotErr := d.PutModules(tc.prefix, tc.srcs) if !errors.Is(gotErr, tc.wantErr) { t.Fatalf("got PutModules() error = %v, want %v", gotErr, tc.wantErr) } @@ -309,11 +305,9 @@ func TestDriver_PutModules_StorageErrors(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - d := New(Storage(tc.storage)) - err := d.PutModule(ctx, "foo", Module) + err := d.PutModule("foo", Module) if tc.wantErr && err == nil { t.Fatalf("got PutModule() err %v, want error", nil) @@ -383,12 +377,10 @@ func TestDriver_DeleteModule(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - d := New() for _, name := range tc.beforeModules { - err := d.PutModule(ctx, name, Module) + err := d.PutModule(name, Module) if err != nil { t.Fatal(err) } @@ -400,7 +392,7 @@ func TestDriver_DeleteModule(t *testing.T) { d, &driver{}) } - gotDeleted, gotErr := d.DeleteModule(ctx, tc.moduleName) + gotDeleted, gotErr := d.DeleteModule(tc.moduleName) if gotDeleted != tc.wantDeleted { t.Errorf("got DeleteModule() = %t, want %t", gotDeleted, tc.wantDeleted) } @@ -443,16 +435,14 @@ func TestDriver_DeleteModule_StorageErrors(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - d := New(Storage(tc.storage)) - err := d.PutModule(ctx, "foo", Module) + err := d.PutModule("foo", Module) if err != nil { t.Fatal(err) } - _, err = d.DeleteModule(ctx, "foo") + _, err = d.DeleteModule("foo") if tc.wantErr && err == nil { t.Fatalf("got DeleteModule() err %v, want error", nil) @@ -543,8 +533,6 @@ func TestDriver_DeleteModules(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - d := New() for prefix, count := range tc.beforeModules { @@ -552,7 +540,7 @@ func TestDriver_DeleteModules(t *testing.T) { for i := 0; i < count; i++ { modules[i] = Module } - err := d.PutModules(ctx, prefix, modules) + err := d.PutModules(prefix, modules) if err != nil { t.Fatal(err) } @@ -564,7 +552,7 @@ func TestDriver_DeleteModules(t *testing.T) { d, &driver{}) } - gotDeleted, gotErr := d.DeleteModules(ctx, tc.prefix) + gotDeleted, gotErr := d.DeleteModules(tc.prefix) if gotDeleted != tc.wantDeleted { t.Errorf("got DeleteModules() = %v, want %v", gotDeleted, tc.wantDeleted) } diff --git a/constraint/pkg/client/drivers/remote/remote.go b/constraint/pkg/client/drivers/remote/remote.go index 9d14b8263..ed382c5b1 100644 --- a/constraint/pkg/client/drivers/remote/remote.go +++ b/constraint/pkg/client/drivers/remote/remote.go @@ -64,7 +64,7 @@ type driver struct { traceEnabled bool } -func (d *driver) Init(ctx context.Context) error { +func (d *driver) Init() error { return nil } @@ -72,18 +72,18 @@ func (d *driver) addTrace(path string) string { return path + "?explain=full&pretty=true" } -func (d *driver) PutModule(ctx context.Context, name string, src string) error { +func (d *driver) PutModule(name string, src string) error { return d.opa.InsertPolicy(name, []byte(src)) } // PutModules implements drivers.Driver. -func (d *driver) PutModules(ctx context.Context, namePrefix string, srcs []string) error { +func (d *driver) PutModules(namePrefix string, srcs []string) error { panic("not implemented") } // DeleteModule deletes a rule from OPA and returns true if a rule was found and deleted, false // if a rule was not found, and any errors. -func (d *driver) DeleteModule(ctx context.Context, name string) (bool, error) { +func (d *driver) DeleteModule(name string) (bool, error) { err := d.opa.DeletePolicy(name) if err != nil { e := &Error{} @@ -97,17 +97,17 @@ func (d *driver) DeleteModule(ctx context.Context, name string) (bool, error) { } // DeleteModules implements drivers.Driver. -func (d *driver) DeleteModules(ctx context.Context, namePrefix string) (int, error) { +func (d *driver) DeleteModules(namePrefix string) (int, error) { panic("not implemented") } -func (d *driver) PutData(ctx context.Context, path string, data interface{}) error { +func (d *driver) PutData(_ context.Context, path string, data interface{}) error { return d.opa.PutData(path, data) } // DeleteData deletes data from OPA and returns true if data was found and deleted, false // if data was not found, and any errors. -func (d *driver) DeleteData(ctx context.Context, path string) (bool, error) { +func (d *driver) DeleteData(_ context.Context, path string) (bool, error) { err := d.opa.DeleteData(path) if err != nil { e := &Error{} @@ -165,7 +165,7 @@ func makeURLPath(path string) (string, error) { return strings.Join(pieces, "/"), nil } -func (d *driver) Query(ctx context.Context, path string, input interface{}, opts ...drivers.QueryOpt) (*ctypes.Response, error) { +func (d *driver) Query(_ context.Context, path string, input interface{}, opts ...drivers.QueryOpt) (*ctypes.Response, error) { cfg := &drivers.QueryCfg{} for _, opt := range opts { opt(cfg) @@ -213,7 +213,7 @@ func (d *driver) Query(ctx context.Context, path string, input interface{}, opts return resp, nil } -func (d *driver) Dump(ctx context.Context) (string, error) { +func (d *driver) Dump(_ context.Context) (string, error) { response, err := d.opa.Query("", nil) if err != nil { return "", err diff --git a/constraint/pkg/client/e2e_test.go b/constraint/pkg/client/e2e_test.go index d206bebb1..3aacc1891 100644 --- a/constraint/pkg/client/e2e_test.go +++ b/constraint/pkg/client/e2e_test.go @@ -106,26 +106,24 @@ var denyAllCases = []struct { libs: []string{denyTemplateWithLibLib}, }} -func newTestClient(ctx context.Context) (*Client, error) { +func newTestClient() (*Client, error) { d := local.New() b, err := NewBackend(Driver(d)) if err != nil { return nil, err } - return b.NewClient(ctx, Targets(&handler{})) + return b.NewClient(Targets(&handler{})) } func TestE2EAddTemplate(t *testing.T) { for _, tc := range denyAllCases { t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatal(err) } @@ -138,12 +136,12 @@ func TestE2EDenyAll(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -176,12 +174,12 @@ func TestE2EAudit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -219,12 +217,12 @@ func TestE2EAuditX2(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -277,12 +275,12 @@ func TestE2EAutoreject(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", denyTemplateRego)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", denyTemplateRego)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -353,12 +351,12 @@ func TestE2ERemoveConstraint(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", denyTemplateRego)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", denyTemplateRego)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -412,13 +410,13 @@ func TestE2ERemoveTemplate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } tmpl := newConstraintTemplate("Foo", denyTemplateRego) - _, err = c.AddTemplate(ctx, tmpl) + _, err = c.AddTemplate(tmpl) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -470,12 +468,12 @@ func TestE2ETracingOff(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", denyTemplateRego)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", denyTemplateRego)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -505,12 +503,12 @@ func TestE2ETracingOn(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -540,12 +538,12 @@ func TestE2EAuditTracingOn(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -583,12 +581,12 @@ func TestE2EAuditTracingOff(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", tc.rego, tc.libs...)) + _, err = c.AddTemplate(newConstraintTemplate("Foo", tc.rego, tc.libs...)) if err != nil { t.Fatalf("got AddTemplate: %v", err) } @@ -626,12 +624,12 @@ func TestE2EDryrunAll(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", `package foo + _, err = c.AddTemplate(newConstraintTemplate("Foo", `package foo violation[{"msg": "DRYRUN", "details": {}}] { "always" == "always" }`)) @@ -668,12 +666,12 @@ func TestE2EDenyByParameter(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - c, err := newTestClient(ctx) + c, err := newTestClient() if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(ctx, newConstraintTemplate("Foo", `package foo + _, err = c.AddTemplate(newConstraintTemplate("Foo", `package foo violation[{"msg": "DENIED", "details": {}}] { input.parameters.name == input.review.Name }`)) From b808d17c7e7cbc195c6be6dd8b65a590787fddb4 Mon Sep 17 00:00:00 2001 From: Will Beason Date: Mon, 6 Dec 2021 09:22:31 -0800 Subject: [PATCH 3/3] Merge Signed-off-by: Will Beason --- constraint/pkg/client/backend.go | 26 +- constraint/pkg/client/client.go | 293 +++--- constraint/pkg/client/client_test.go | 1366 ++++++++++++++++---------- 3 files changed, 1001 insertions(+), 684 deletions(-) diff --git a/constraint/pkg/client/backend.go b/constraint/pkg/client/backend.go index ebe64b69f..77ca9ceed 100644 --- a/constraint/pkg/client/backend.go +++ b/constraint/pkg/client/backend.go @@ -1,7 +1,6 @@ package client import ( - "errors" "fmt" "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers" @@ -23,8 +22,10 @@ func Driver(d drivers.Driver) BackendOpt { } } -// NewBackend creates a new backend. A backend could be a connection to a remote server or -// a new local OPA instance. +// NewBackend creates a new backend. A backend could be a connection to a remote +// server or a new local OPA instance. +// +// A BackendOpt setting driver, such as Driver() must be passed. func NewBackend(opts ...BackendOpt) (*Backend, error) { helper, err := newCRDHelper() if err != nil { @@ -36,7 +37,7 @@ func NewBackend(opts ...BackendOpt) (*Backend, error) { } if b.driver == nil { - return nil, errors.New("no driver supplied to the backend") + return nil, fmt.Errorf("%w: no driver supplied", ErrCreatingBackend) } return b, nil @@ -45,7 +46,8 @@ func NewBackend(opts ...BackendOpt) (*Backend, error) { // NewClient creates a new client for the supplied backend. func (b *Backend) NewClient(opts ...Opt) (*Client, error) { if b.hasClient { - return nil, errors.New("currently only one client per backend is supported") + return nil, fmt.Errorf("%w: only one client per backend is allowed", + ErrCreatingClient) } var fields []string @@ -60,24 +62,22 @@ func (b *Backend) NewClient(opts ...Opt) (*Client, error) { allowedDataFields: fields, } - var errs Errors for _, opt := range opts { if err := opt(c); err != nil { - errs = append(errs, err) + return nil, err } } - if len(errs) > 0 { - return nil, errs - } for _, field := range c.allowedDataFields { if !validDataFields[field] { - return nil, fmt.Errorf("invalid data field %s", field) + return nil, fmt.Errorf("%w: invalid data field %q; allowed fields are: %v", + ErrCreatingClient, field, validDataFields) } } if len(c.targets) == 0 { - return nil, errors.New("no targets registered: please register a target via client.Targets()") + return nil, fmt.Errorf("%w: must specify at least one target with client.Targets", + ErrCreatingClient) } if err := b.driver.Init(); err != nil { @@ -87,5 +87,7 @@ func (b *Backend) NewClient(opts ...Opt) (*Client, error) { if err := c.init(); err != nil { return nil, err } + + b.hasClient = true return c, nil } diff --git a/constraint/pkg/client/client.go b/constraint/pkg/client/client.go index d3aef1a47..783a6b78e 100644 --- a/constraint/pkg/client/client.go +++ b/constraint/pkg/client/client.go @@ -6,11 +6,12 @@ import ( "errors" "fmt" "path" - "regexp" + "sort" "strings" "sync" "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers" + "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" "github.com/open-policy-agent/frameworks/constraint/pkg/client/regolib" constraintlib "github.com/open-policy-agent/frameworks/constraint/pkg/core/constraints" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" @@ -24,46 +25,6 @@ import ( const constraintGroup = "constraints.gatekeeper.sh" -type Opt func(*Client) error - -// Client options - -var targetNameRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9.]*$`) - -func Targets(ts ...TargetHandler) Opt { - return func(c *Client) error { - var errs Errors - handlers := make(map[string]TargetHandler, len(ts)) - for _, t := range ts { - name := t.GetName() - switch { - case name == "": - errs = append(errs, errors.New("invalid target: a target is returning an empty string for GetName()")) - case !targetNameRegex.MatchString(name): - errs = append(errs, fmt.Errorf("target name %q is not of the form %q", name, targetNameRegex.String())) - default: - handlers[name] = t - } - } - c.targets = handlers - - if len(errs) > 0 { - return errs - } - return nil - } -} - -// AllowedDataFields sets the fields under `data` that Rego in ConstraintTemplates -// can access. If unset, all fields can be accessed. Only fields recognized by -// the system can be enabled. -func AllowedDataFields(fields ...string) Opt { - return func(c *Client) error { - c.allowedDataFields = fields - return nil - } -} - type templateEntry struct { template *templates.ConstraintTemplate CRD *apiextensions.CustomResourceDefinition @@ -71,11 +32,14 @@ type templateEntry struct { } type Client struct { - backend *Backend - targets map[string]TargetHandler - constraintsMux sync.RWMutex - templates map[templateKey]*templateEntry - constraints map[schema.GroupKind]map[string]*unstructured.Unstructured + backend *Backend + targets map[string]TargetHandler + + // mtx guards access to both templates and constraints. + mtx sync.RWMutex + templates map[templateKey]*templateEntry + constraints map[schema.GroupKind]map[string]*unstructured.Unstructured + allowedDataFields []string } @@ -112,7 +76,7 @@ func (c *Client) AddData(ctx context.Context, data interface{}) (*types.Response if len(errMap) == 0 { return resp, nil } - return resp, errMap + return resp, &errMap } // RemoveData removes data from OPA for every target that can handle the data. @@ -139,7 +103,7 @@ func (c *Client) RemoveData(ctx context.Context, data interface{}) (*types.Respo if len(errMap) == 0 { return resp, nil } - return resp, errMap + return resp, &errMap } // createTemplatePath returns the package path for a given template: templates... @@ -158,14 +122,14 @@ func (c *Client) validateTargets(templ *templates.ConstraintTemplate) (*template return nil, nil, err } - if len(templ.Spec.Targets) != 1 { - return nil, nil, fmt.Errorf("expected exactly 1 item in targets, got %v", templ.Spec.Targets) - } - targetSpec := &templ.Spec.Targets[0] targetHandler, found := c.targets[targetSpec.Target] + if !found { - return nil, nil, fmt.Errorf("target %s not recognized", targetSpec.Target) + knownTargets := c.knownTargets() + + return nil, nil, fmt.Errorf("%w: target %s not recognized, known targets %v", + ErrInvalidConstraintTemplate, targetSpec.Target, knownTargets) } return targetSpec, targetHandler, nil @@ -197,9 +161,10 @@ func (a *rawCTArtifacts) Key() templateKey { // createRawTemplateArtifacts creates the "free" artifacts for a template, avoiding more // complex tasks like rewriting Rego. Provides minimal validation. func (c *Client) createRawTemplateArtifacts(templ *templates.ConstraintTemplate) (*rawCTArtifacts, error) { - if templ.ObjectMeta.Name == "" { - return nil, errors.New("invalid Template: missing name") + if templ.GetName() == "" { + return nil, fmt.Errorf("%w: missing name", ErrInvalidConstraintTemplate) } + return &rawCTArtifacts{template: templ}, nil } @@ -245,8 +210,16 @@ func (c *Client) createBasicTemplateArtifacts(templ *templates.ConstraintTemplat if err != nil { return nil, err } - if !strings.EqualFold(templ.ObjectMeta.Name, templ.Spec.CRD.Spec.Names.Kind) { - return nil, fmt.Errorf("the ConstraintTemplate's name %q is not equal to the lowercase of CRD's Kind: %q", templ.ObjectMeta.Name, strings.ToLower(templ.Spec.CRD.Spec.Names.Kind)) + + kind := templ.Spec.CRD.Spec.Names.Kind + if kind == "" { + return nil, fmt.Errorf("%w: ConstraintTemplate %q does not specify CRD Kind", + ErrInvalidConstraintTemplate, templ.GetName()) + } + + if !strings.EqualFold(templ.ObjectMeta.Name, kind) { + return nil, fmt.Errorf("%w: the ConstraintTemplate's name %q is not equal to the lowercase of CRD's Kind: %q", + ErrInvalidConstraintTemplate, templ.ObjectMeta.Name, strings.ToLower(kind)) } targetSpec, targetHandler, err := c.validateTargets(templ) @@ -258,12 +231,14 @@ func (c *Client) createBasicTemplateArtifacts(templ *templates.ConstraintTemplat if err != nil { return nil, err } + crd, err := c.backend.crd.createCRD(templ, sch) if err != nil { return nil, err } + if err = c.backend.crd.validateCRD(crd); err != nil { - return nil, err + return nil, fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) } entryPointPath := createTemplatePath(targetHandler.GetName(), templ.Spec.CRD.Spec.Names.Kind) @@ -292,55 +267,63 @@ func (c *Client) createTemplateArtifacts(templ *templates.ConstraintTemplate) (* } libPrefix := templateLibPrefix(artifacts.targetHandler.GetName(), artifacts.crd.Spec.Names.Kind) + rr, err := regorewriter.New( regorewriter.NewPackagePrefixer(libPrefix), []string{"data.lib"}, externs) if err != nil { - return nil, err + return nil, fmt.Errorf("creating rego rewriter: %w", err) } entryPoint, err := parseModule(artifacts.namePrefix, artifacts.targetSpec.Rego) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) } + if entryPoint == nil { - return nil, fmt.Errorf("failed to parse module for unknown reason") + return nil, fmt.Errorf("%w: failed to parse module for unknown reason", + ErrInvalidConstraintTemplate) } - if err := rewriteModulePackage(artifacts.namePrefix, entryPoint); err != nil { + if err = rewriteModulePackage(artifacts.namePrefix, entryPoint); err != nil { return nil, err } req := map[string]struct{}{"violation": {}} - if err := requireRulesModule(entryPoint, req); err != nil { - return nil, fmt.Errorf("invalid rego: %s", err) + if err = requireModuleRules(entryPoint, req); err != nil { + return nil, fmt.Errorf("%w: invalid rego: %v", + ErrInvalidConstraintTemplate, err) } rr.AddEntryPointModule(artifacts.namePrefix, entryPoint) for idx, libSrc := range artifacts.targetSpec.Libs { libPath := fmt.Sprintf(`%s["lib_%d"]`, libPrefix, idx) - if err := rr.AddLib(libPath, libSrc); err != nil { - return nil, err + if err = rr.AddLib(libPath, libSrc); err != nil { + return nil, fmt.Errorf("%w: %v", + ErrInvalidConstraintTemplate, err) } } sources, err := rr.Rewrite() if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %v", + ErrInvalidConstraintTemplate, err) } var mods []string - if err := sources.ForEachModule(func(m *regorewriter.Module) error { - content, err := m.Content() - if err != nil { - return err + err = sources.ForEachModule(func(m *regorewriter.Module) error { + content, err2 := m.Content() + if err2 != nil { + return err2 } mods = append(mods, string(content)) return nil - }); err != nil { - return nil, err + }) + if err != nil { + return nil, fmt.Errorf("%w: %v", + ErrInvalidConstraintTemplate, err) } return &ctArtifacts{ @@ -351,6 +334,11 @@ func (c *Client) createTemplateArtifacts(templ *templates.ConstraintTemplate) (* // CreateCRD creates a CRD from template. func (c *Client) CreateCRD(templ *templates.ConstraintTemplate) (*apiextensions.CustomResourceDefinition, error) { + if templ == nil { + return nil, fmt.Errorf("%w: got nil ConstraintTemplate", + ErrInvalidConstraintTemplate) + } + artifacts, err := c.createTemplateArtifacts(templ) if err != nil { return nil, err @@ -380,11 +368,11 @@ func (c *Client) AddTemplate(templ *templates.ConstraintTemplate) (*types.Respon return resp, err } - c.constraintsMux.Lock() - defer c.constraintsMux.Unlock() + c.mtx.Lock() + defer c.mtx.Unlock() - if err := c.backend.driver.PutModules(artifacts.namePrefix, artifacts.modules); err != nil { - return resp, err + if err = c.backend.driver.PutModules(artifacts.namePrefix, artifacts.modules); err != nil { + return resp, fmt.Errorf("%w: %v", local.ErrCompile, err) } cpy := templ.DeepCopy() @@ -413,12 +401,12 @@ func (c *Client) RemoveTemplate(ctx context.Context, templ *templates.Constraint return resp, err } - c.constraintsMux.Lock() - defer c.constraintsMux.Unlock() + c.mtx.Lock() + defer c.mtx.Unlock() - template, err := c.getTemplateNoLock(rawArtifacts) + template, err := c.getTemplateNoLock(rawArtifacts.Key()) if err != nil { - if IsMissingTemplateError(err) { + if errors.Is(err, ErrMissingConstraintTemplate) { return resp, nil } return resp, err @@ -456,16 +444,18 @@ func (c *Client) GetTemplate(templ *templates.ConstraintTemplate) (*templates.Co return nil, err } - c.constraintsMux.Lock() - defer c.constraintsMux.Unlock() - return c.getTemplateNoLock(artifacts) + c.mtx.RLock() + defer c.mtx.RUnlock() + return c.getTemplateNoLock(artifacts.Key()) } -func (c *Client) getTemplateNoLock(artifacts keyableArtifact) (*templates.ConstraintTemplate, error) { - t, ok := c.templates[artifacts.Key()] +func (c *Client) getTemplateNoLock(key templateKey) (*templates.ConstraintTemplate, error) { + t, ok := c.templates[key] if !ok { - return nil, NewMissingTemplateError(string(artifacts.Key())) + return nil, fmt.Errorf("%w: template for %q not found", + ErrMissingConstraintTemplate, key) } + ret := t.template.DeepCopy() return ret, nil } @@ -474,15 +464,20 @@ func (c *Client) getTemplateNoLock(artifacts keyableArtifact) (*templates.Constr // for each target: cluster.... func createConstraintSubPath(constraint *unstructured.Unstructured) (string, error) { if constraint.GetName() == "" { - return "", errors.New("invalid Constraint: missing name") + return "", fmt.Errorf("%w: missing name", ErrInvalidConstraint) } + gvk := constraint.GroupVersionKind() if gvk.Group == "" { - return "", fmt.Errorf("empty group for the constrant named %s", constraint.GetName()) + return "", fmt.Errorf("%w: empty group for constrant %q", + ErrInvalidConstraint, constraint.GetName()) } + if gvk.Kind == "" { - return "", fmt.Errorf("empty kind for the constraint named %s", constraint.GetName()) + return "", fmt.Errorf("%w: empty kind for constraint %q", + ErrInvalidConstraint, constraint.GetName()) } + return path.Join(createConstraintGKSubPath(gvk.GroupKind()), constraint.GetName()), nil } @@ -515,19 +510,31 @@ func constraintPathMerge(target, subpath string) string { func (c *Client) getTemplateEntry(constraint *unstructured.Unstructured, lock bool) (*templateEntry, error) { kind := constraint.GetKind() if kind == "" { - return nil, fmt.Errorf("kind missing from Constraint %q", constraint.GetName()) + return nil, fmt.Errorf("%w: kind missing from Constraint %q", + ErrInvalidConstraint, constraint.GetName()) } + if constraint.GroupVersionKind().Group != constraintGroup { - return nil, fmt.Errorf("wrong API Group for Constraint %q", constraint.GetName()) + return nil, fmt.Errorf("%w: wrong API Group for Constraint %q, need %q", + ErrInvalidConstraint, constraint.GetName(), constraintGroup) } + if lock { - c.constraintsMux.RLock() - defer c.constraintsMux.RUnlock() + c.mtx.RLock() + defer c.mtx.RUnlock() } + entry, ok := c.templates[templateKeyFromConstraint(constraint)] if !ok { - return nil, NewUnrecognizedConstraintError(kind) + var known []string + for k := range c.templates { + known = append(known, string(k)) + } + + return nil, fmt.Errorf("%w: Constraint kind %q is not recognized, known kinds %v", + ErrMissingConstraintTemplate, kind, known) } + return entry, nil } @@ -535,25 +542,30 @@ func (c *Client) getTemplateEntry(constraint *unstructured.Unstructured, lock bo // On error, the responses return value will still be populated so that // partial results can be analyzed. func (c *Client) AddConstraint(ctx context.Context, constraint *unstructured.Unstructured) (*types.Responses, error) { - c.constraintsMux.RLock() - defer c.constraintsMux.RUnlock() + c.mtx.Lock() + defer c.mtx.Unlock() + resp := types.NewResponses() errMap := make(ErrorMap) entry, err := c.getTemplateEntry(constraint, false) if err != nil { return resp, err } + subPath, err := createConstraintSubPath(constraint) if err != nil { - return resp, err + return resp, fmt.Errorf("creating Constraint subpath: %w", err) } + // return immediately if no change - if cached, err := c.getConstraintNoLock(constraint); err == nil && constraintlib.SemanticEqual(cached, constraint) { + cached, err := c.getConstraintNoLock(constraint) + if err == nil && constraintlib.SemanticEqual(cached, constraint) { for _, target := range entry.Targets { resp.Handled[target] = true } return resp, nil } + if err := c.validateConstraint(constraint, false); err != nil { return resp, err } @@ -571,18 +583,21 @@ func (c *Client) AddConstraint(ctx context.Context, constraint *unstructured.Uns } resp.Handled[target] = true } + if len(errMap) == 0 { c.constraints[constraint.GroupVersionKind().GroupKind()][subPath] = constraint.DeepCopy() return resp, nil } - return resp, errMap + + return resp, &errMap } // RemoveConstraint removes a constraint from OPA. On error, the responses // return value will still be populated so that partial results can be analyzed. func (c *Client) RemoveConstraint(ctx context.Context, constraint *unstructured.Unstructured) (*types.Responses, error) { - c.constraintsMux.RLock() - defer c.constraintsMux.RUnlock() + c.mtx.Lock() + defer c.mtx.Unlock() + return c.removeConstraintNoLock(ctx, constraint) } @@ -616,7 +631,7 @@ func (c *Client) removeConstraintNoLock(ctx context.Context, constraint *unstruc delete(c.constraints[constraint.GroupVersionKind().GroupKind()], subPath) return resp, nil } - return resp, errMap + return resp, &errMap } // getConstraintNoLock gets the currently recognized constraint without the lock. @@ -626,17 +641,20 @@ func (c *Client) getConstraintNoLock(constraint *unstructured.Unstructured) (*un return nil, err } - cstr, ok := c.constraints[constraint.GroupVersionKind().GroupKind()][subPath] + gk := constraint.GroupVersionKind().GroupKind() + cstr, ok := c.constraints[gk][subPath] if !ok { - return nil, NewMissingConstraintError(subPath) + return nil, fmt.Errorf("%w %v %q", + ErrMissingConstraint, gk, constraint.GetName()) } return cstr.DeepCopy(), nil } // GetConstraint gets the currently recognized constraint. func (c *Client) GetConstraint(constraint *unstructured.Unstructured) (*unstructured.Unstructured, error) { - c.constraintsMux.Lock() - defer c.constraintsMux.Unlock() + c.mtx.RLock() + defer c.mtx.RUnlock() + return c.getConstraintNoLock(constraint) } @@ -676,16 +694,18 @@ func (c *Client) init() error { return err } - moduleName := fmt.Sprintf("%s.hooks_builtin", hooks) - err := c.backend.driver.PutModule(moduleName, libBuiltin.String()) + builtinPath := fmt.Sprintf("%s.hooks_builtin", hooks) + err := c.backend.driver.PutModule(builtinPath, libBuiltin.String()) if err != nil { return err } libTempl := t.Library() if libTempl == nil { - return fmt.Errorf("target %q has no Rego library template", t.GetName()) + return fmt.Errorf("%w: target %q has no Rego library template", + ErrCreatingClient, t.GetName()) } + libBuf := &bytes.Buffer{} if err := libTempl.Execute(libBuf, map[string]string{ "ConstraintsRoot": fmt.Sprintf(`data.constraints["%s"].cluster["%s"]`, t.GetName(), constraintGroup), @@ -693,30 +713,41 @@ func (c *Client) init() error { }); err != nil { return err } + lib := libBuf.String() req := map[string]struct{}{ "autoreject_review": {}, "matching_reviews_and_constraints": {}, "matching_constraints": {}, } + modulePath := fmt.Sprintf("%s.library", hooks) libModule, err := parseModule(modulePath, lib) if err != nil { return fmt.Errorf("failed to parse module: %w", err) } - if err := requireRulesModule(libModule, req); err != nil { - return fmt.Errorf("problem with the below Rego for %q target:\n\n====%s\n====\n%s", t.GetName(), lib, err) + + err = requireModuleRules(libModule, req) + if err != nil { + return fmt.Errorf("problem with the below Rego for %q target:\n\n====%s\n====\n%w", + t.GetName(), lib, err) } + err = rewriteModulePackage(modulePath, libModule) if err != nil { return err } + src, err := format.Ast(libModule) if err != nil { - return fmt.Errorf("could not re-format Rego source: %v", err) + return fmt.Errorf("%w: could not re-format Rego source: %v", + ErrCreatingClient, err) } - if err := c.backend.driver.PutModule(modulePath, string(src)); err != nil { - return fmt.Errorf("error %s from compiled source:\n%s", err, src) + + err = c.backend.driver.PutModule(modulePath, string(src)) + if err != nil { + return fmt.Errorf("%w: error %s from compiled source:\n%s", + ErrCreatingClient, err, src) } } @@ -725,8 +756,9 @@ func (c *Client) init() error { // Reset the state of OPA. func (c *Client) Reset(ctx context.Context) error { - c.constraintsMux.Lock() - defer c.constraintsMux.Unlock() + c.mtx.Lock() + defer c.mtx.Unlock() + for name := range c.targets { if _, err := c.backend.driver.DeleteData(ctx, fmt.Sprintf("/external/%s", name)); err != nil { return err @@ -747,18 +779,6 @@ func (c *Client) Reset(ctx context.Context) error { return nil } -type queryCfg struct { - enableTracing bool -} - -type QueryOpt func(*queryCfg) - -func Tracing(enabled bool) QueryOpt { - return func(cfg *queryCfg) { - cfg.enableTracing = enabled - } -} - // Review makes sure the provided object satisfies all stored constraints. // On error, the responses return value will still be populated so that // partial results can be analyzed. @@ -798,7 +818,7 @@ TargetLoop: if len(errMap) == 0 { return responses, nil } - return responses, errMap + return responses, &errMap } // Audit makes sure the cached state of the system satisfies all stored constraints. @@ -831,10 +851,21 @@ TargetLoop: if len(errMap) == 0 { return responses, nil } - return responses, errMap + return responses, &errMap } // Dump dumps the state of OPA to aid in debugging. func (c *Client) Dump(ctx context.Context) (string, error) { return c.backend.driver.Dump(ctx) } + +// knownTargets returns a sorted list of currently-known target names. +func (c *Client) knownTargets() []string { + var knownTargets []string + for known := range c.targets { + knownTargets = append(knownTargets, known) + } + sort.Strings(knownTargets) + + return knownTargets +} diff --git a/constraint/pkg/client/client_test.go b/constraint/pkg/client/client_test.go index 5ff15366e..e1ab7e44c 100644 --- a/constraint/pkg/client/client_test.go +++ b/constraint/pkg/client/client_test.go @@ -3,17 +3,19 @@ package client import ( "context" "errors" - "reflect" "strings" "testing" "text/template" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" - constraintlib "github.com/open-policy-agent/frameworks/constraint/pkg/core/constraints" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" "github.com/open-policy-agent/frameworks/constraint/pkg/types" "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/utils/pointer" ) const badRego = `asd{` @@ -43,11 +45,10 @@ matching_reviews_and_constraints[[r,c]] {r = data.r; c = data.c}`)) } func (h *badHandler) MatchSchema() apiextensions.JSONSchemaProps { - trueBool := true - return apiextensions.JSONSchemaProps{XPreserveUnknownFields: &trueBool} + return apiextensions.JSONSchemaProps{XPreserveUnknownFields: pointer.Bool(true)} } -func (h *badHandler) ProcessData(obj interface{}) (bool, string, interface{}, error) { +func (h *badHandler) ProcessData(_ interface{}) (bool, string, interface{}, error) { if h.Errors { return false, "", nil, errors.New("some error") } @@ -57,245 +58,241 @@ func (h *badHandler) ProcessData(obj interface{}) (bool, string, interface{}, er return true, "projects/something", nil, nil } -func (h *badHandler) HandleReview(obj interface{}) (bool, interface{}, error) { +func (h *badHandler) HandleReview(_ interface{}) (bool, interface{}, error) { return false, "", nil } -func (h *badHandler) HandleViolation(result *types.Result) error { +func (h *badHandler) HandleViolation(_ *types.Result) error { return nil } -func (h *badHandler) ValidateConstraint(u *unstructured.Unstructured) error { +func (h *badHandler) ValidateConstraint(_ *unstructured.Unstructured) error { return nil } func TestInvalidTargetName(t *testing.T) { - tc := []struct { - Name string - Handler TargetHandler - ErrorExpected bool + tcs := []struct { + name string + handler TargetHandler + wantError error }{ { - Name: "Acceptable Name", - Handler: &badHandler{Name: "Hello8", HasLib: true}, - ErrorExpected: false, + name: "Acceptable name", + handler: &badHandler{Name: "Hello8", HasLib: true}, + wantError: nil, }, { - Name: "No Name", - Handler: &badHandler{Name: ""}, - ErrorExpected: true, + name: "No name", + handler: &badHandler{Name: ""}, + wantError: ErrCreatingClient, }, { - Name: "No Dots", - Handler: &badHandler{Name: "asdf.asdf"}, - ErrorExpected: true, + name: "Dots not allowed", + handler: &badHandler{Name: "asdf.asdf"}, + wantError: ErrCreatingClient, }, { - Name: "No Spaces", - Handler: &badHandler{Name: "asdf asdf"}, - ErrorExpected: true, + name: "Spaces not allowed", + handler: &badHandler{Name: "asdf asdf"}, + wantError: ErrCreatingClient, }, { - Name: "Must start with a letter", - Handler: &badHandler{Name: "8asdf"}, - ErrorExpected: true, + name: "Must start with a letter", + handler: &badHandler{Name: "8asdf"}, + wantError: ErrCreatingClient, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - _, err = b.NewClient(Targets(tt.Handler)) - if (err == nil) && tt.ErrorExpected { - t.Fatalf("err = nil; want non-nil") - } - if (err != nil) && !tt.ErrorExpected { - t.Fatalf("err = \"%s\"; want nil", err) + + _, err = b.NewClient(Targets(tc.handler)) + if !errors.Is(err, tc.wantError) { + t.Errorf("got NewClient() error = %v, want %v", + err, tc.wantError) } }) } } func TestAddData(t *testing.T) { - tc := []struct { - Name string - Handler1 TargetHandler - Handler2 TargetHandler - ErroredBy []string - HandledBy []string + tcs := []struct { + name string + handler1 TargetHandler + handler2 TargetHandler + wantHandled map[string]bool + wantError map[string]bool }{ { - Name: "Handled By Both", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, - HandledBy: []string{"h1", "h2"}, + name: "Handled By Both", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, + wantHandled: map[string]bool{"h1": true, "h2": true}, + wantError: nil, }, { - Name: "Handled By One", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, - HandledBy: []string{"h1"}, + name: "Handled By One", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, { - Name: "Errored By One", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - HandledBy: []string{"h1"}, - ErroredBy: []string{"h2"}, + name: "Errored By One", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + wantHandled: map[string]bool{"h1": true}, + wantError: map[string]bool{"h2": true}, }, { - Name: "Errored By Both", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - ErroredBy: []string{"h1", "h2"}, + name: "Errored By Both", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + wantError: map[string]bool{"h1": true, "h2": true}, }, { - Name: "Handled By None", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + name: "Handled By None", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + wantHandled: nil, + wantError: nil, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { - ctx := context.Background() + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tt.Handler1, tt.Handler2)) + c, err := b.NewClient(Targets(tc.handler1, tc.handler2)) if err != nil { t.Fatal(err) } - r, err := c.AddData(ctx, nil) - if err != nil && len(tt.ErroredBy) == 0 { + + r, err := c.AddData(context.Background(), nil) + if err != nil && len(tc.wantError) == 0 { t.Fatalf("err = %s; want nil", err) } - expectedErr := make(map[string]bool) - actualErr := make(map[string]bool) - for _, v := range tt.ErroredBy { - expectedErr[v] = true - } - if e, ok := err.(ErrorMap); ok { - for k := range e { - actualErr[k] = true + gotErrs := make(map[string]bool) + if e, ok := err.(*ErrorMap); ok { + for k := range *e { + gotErrs[k] = true } } - if !reflect.DeepEqual(actualErr, expectedErr) { - t.Errorf("errSet = %v; wanted %v", actualErr, expectedErr) - } - expectedHandled := make(map[string]bool) - for _, v := range tt.HandledBy { - expectedHandled[v] = true + + if diff := cmp.Diff(tc.wantError, gotErrs, cmpopts.EquateEmpty()); diff != "" { + t.Errorf(diff) } + if r == nil { t.Fatal("got AddTemplate() == nil, want non-nil") } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("handledSet = %v; wanted %v", r.Handled, expectedHandled) - } - if r.HandledCount() != len(expectedHandled) { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), len(expectedHandled)) + + if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } }) } } func TestRemoveData(t *testing.T) { - tc := []struct { - Name string - Handler1 TargetHandler - Handler2 TargetHandler - ErroredBy []string - HandledBy []string + tcs := []struct { + name string + handler1 TargetHandler + handler2 TargetHandler + wantHandled map[string]bool + wantError map[string]bool }{ { - Name: "Handled By Both", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, - HandledBy: []string{"h1", "h2"}, + name: "Handled By Both", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true}, + wantHandled: map[string]bool{"h1": true, "h2": true}, + wantError: nil, }, { - Name: "Handled By One", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, - HandledBy: []string{"h1"}, + name: "Handled By One", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, { - Name: "Errored By One", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - HandledBy: []string{"h1"}, - ErroredBy: []string{"h2"}, + name: "Errored By One", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + wantHandled: map[string]bool{"h1": true}, + wantError: map[string]bool{"h2": true}, }, { - Name: "Errored By Both", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, - ErroredBy: []string{"h1", "h2"}, + name: "Errored By Both", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: true, Errors: true}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: true, Errors: true}, + wantHandled: nil, + wantError: map[string]bool{"h1": true, "h2": true}, }, { - Name: "Handled By None", - Handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, - Handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + name: "Handled By None", + handler1: &badHandler{Name: "h1", HasLib: true, HandlesData: false}, + handler2: &badHandler{Name: "h2", HasLib: true, HandlesData: false}, + wantHandled: nil, + wantError: nil, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { - ctx := context.Background() + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tt.Handler1, tt.Handler2)) + c, err := b.NewClient(Targets(tc.handler1, tc.handler2)) if err != nil { t.Fatal(err) } - r, err := c.RemoveData(ctx, nil) - if err != nil && len(tt.ErroredBy) == 0 { + + r, err := c.RemoveData(context.Background(), nil) + if err != nil && len(tc.wantError) == 0 { t.Fatalf("err = %s; want nil", err) } - expectedErr := make(map[string]bool) - actualErr := make(map[string]bool) - for _, v := range tt.ErroredBy { - expectedErr[v] = true - } - if e, ok := err.(ErrorMap); ok { - for k := range e { - actualErr[k] = true + + gotErrs := make(map[string]bool) + if e, ok := err.(*ErrorMap); ok { + for k := range *e { + gotErrs[k] = true } } - if !reflect.DeepEqual(actualErr, expectedErr) { - t.Errorf("errSet = %v; wanted %v", actualErr, expectedErr) - } - expectedHandled := make(map[string]bool) - for _, v := range tt.HandledBy { - expectedHandled[v] = true + + if diff := cmp.Diff(tc.wantError, gotErrs, cmpopts.EquateEmpty()); diff != "" { + t.Errorf(diff) } if r == nil { t.Fatal("got RemoveData() == nil, want non-nil") } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("handledSet = %v; wanted %v", r.Handled, expectedHandled) - } - if r.HandledCount() != len(expectedHandled) { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), len(expectedHandled)) + + if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } }) } } -func TestAddTemplate(t *testing.T) { +func TestClient_AddTemplate(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego @@ -310,100 +307,97 @@ some_rule[r] { emptyRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) emptyRegoTempl.Spec.Targets[0].Rego = "" - tc := []struct { - Name string - Handler TargetHandler - Template *templates.ConstraintTemplate - ErrorExpected bool + tcs := []struct { + name string + handler TargetHandler + template *templates.ConstraintTemplate + wantHandled map[string]bool + wantError error }{ { - Name: "Good Template", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), - ErrorExpected: false, + name: "Good Template", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, { - Name: "Unknown Target", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - ErrorExpected: true, + name: "Unknown Target", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "Bad CRD", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fakes"), targets("h1")), - ErrorExpected: true, + name: "Bad CRD", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fakes"), targets("h1")), + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "No Name", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(crdNames("Fake"), targets("h1")), - ErrorExpected: true, + name: "No name", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(crdNames("Fake"), targets("h1")), + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "Bad Rego", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: badRegoTempl, - ErrorExpected: true, + name: "Bad Rego", + handler: &badHandler{Name: "h1", HasLib: true}, + template: badRegoTempl, + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "No Rego", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: emptyRegoTempl, - ErrorExpected: true, + name: "No Rego", + handler: &badHandler{Name: "h1", HasLib: true}, + template: emptyRegoTempl, + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "Missing Rule", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: missingRuleTempl, - ErrorExpected: true, + name: "Missing Rule", + handler: &badHandler{Name: "h1", HasLib: true}, + template: missingRuleTempl, + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { - ctx := context.Background() + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tt.Handler)) + c, err := b.NewClient(Targets(tc.handler)) if err != nil { t.Fatal(err) } - r, err := c.AddTemplate(tt.Template) - if err != nil && !tt.ErrorExpected { - t.Fatalf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") + r, err := c.AddTemplate(tc.template) + if !errors.Is(err, tc.wantError) { + t.Fatalf("got AddTemplate() error = %v, want %v", + err, tc.wantError) } - expectedCount := 0 - expectedHandled := make(map[string]bool) - if !tt.ErrorExpected { - expectedCount = 1 - expectedHandled = map[string]bool{"h1": true} - } if r == nil { t.Fatal("got AddTemplate() == nil, want non-nil") } - if r.HandledCount() != expectedCount { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) - } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) - } - cached, err := c.GetTemplate(tt.Template) - if err == nil && tt.ErrorExpected { - t.Fatal("retrieved template when error was expected") + if diff := cmp.Diff(r.Handled, tc.wantHandled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } - if tt.ErrorExpected { + cached, err := c.GetTemplate(tc.template) + if tc.wantError != nil { + if err == nil { + t.Fatalf("got GetTemplate() error = %v, want non-nil", err) + } return } @@ -411,17 +405,20 @@ some_rule[r] { t.Fatalf("could not retrieve template when error was expected: %v", err) } - if !cached.SemanticEqual(tt.Template) { + if !cached.SemanticEqual(tc.template) { t.Error("cached template does not equal stored template") } - r2, err := c.RemoveTemplate(ctx, tt.Template) + + r2, err := c.RemoveTemplate(context.Background(), tc.template) if err != nil { t.Fatal("could not remove template") } + if r2.HandledCount() != 1 { t.Error("more targets handled than expected") } - if _, err := c.GetTemplate(tt.Template); err == nil { + + if _, err := c.GetTemplate(tc.template); err == nil { t.Error("template not cleared from cache") } }) @@ -431,67 +428,62 @@ some_rule[r] { func TestRemoveTemplate(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - tc := []struct { - Name string - Handler TargetHandler - Template *templates.ConstraintTemplate - ErrorExpected bool + tcs := []struct { + name string + handler TargetHandler + template *templates.ConstraintTemplate + wantHandled map[string]bool + wantError error }{ { - Name: "Good Template", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - ErrorExpected: false, + name: "Good Template", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, { - Name: "Unknown Target", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - ErrorExpected: true, + name: "Unknown Target", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "Bad CRD", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), targets("h1")), - ErrorExpected: true, + name: "Bad CRD", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fake"), targets("h1")), + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { - ctx := context.Background() - + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tt.Handler)) + c, err := b.NewClient(Targets(tc.handler)) if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(tt.Template) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") + + _, err = c.AddTemplate(tc.template) + if !errors.Is(err, tc.wantError) { + t.Fatalf("got AddTemplate() error = %v, want %v", + err, tc.wantError) } - r, err := c.RemoveTemplate(ctx, tt.Template) + + r, err := c.RemoveTemplate(context.Background(), tc.template) if err != nil { t.Errorf("err = %v; want nil", err) } - expectedCount := 0 - expectedHandled := make(map[string]bool) - if !tt.ErrorExpected { - expectedCount = 1 - expectedHandled = map[string]bool{"h1": true} - } - if r.HandledCount() != expectedCount { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) - } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) + + if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } }) } @@ -500,69 +492,66 @@ func TestRemoveTemplate(t *testing.T) { func TestRemoveTemplateByNameOnly(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - tc := []struct { - Name string - Handler TargetHandler - Template *templates.ConstraintTemplate - ErrorExpected bool + tcs := []struct { + name string + handler TargetHandler + template *templates.ConstraintTemplate + wantHandled map[string]bool + wantError error }{ { - Name: "Good Template", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - ErrorExpected: false, + name: "Good Template", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, { - Name: "Unknown Target", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - ErrorExpected: true, + name: "Unknown Target", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "Bad CRD", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), targets("h1")), - ErrorExpected: true, + name: "Bad CRD", + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fake"), targets("h1")), + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { - ctx := context.Background() + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tt.Handler)) + c, err := b.NewClient(Targets(tc.handler)) if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(tt.Template) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") + + _, err = c.AddTemplate(tc.template) + if !errors.Is(err, tc.wantError) { + t.Fatalf("got AddTemplate() error = %v, want %v", + err, tc.wantError) } + sparseTemplate := &templates.ConstraintTemplate{} - sparseTemplate.Name = tt.Template.Name - r, err := c.RemoveTemplate(ctx, sparseTemplate) + sparseTemplate.Name = tc.template.Name + + r, err := c.RemoveTemplate(context.Background(), sparseTemplate) if err != nil { t.Errorf("err = %v; want nil", err) } - expectedCount := 0 - expectedHandled := make(map[string]bool) - if !tt.ErrorExpected { - expectedCount = 1 - expectedHandled = map[string]bool{"h1": true} - } - if r.HandledCount() != expectedCount { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) - } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) + + if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } }) } @@ -571,61 +560,69 @@ func TestRemoveTemplateByNameOnly(t *testing.T) { func TestGetTemplate(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - tc := []struct { - Name string - Handler TargetHandler - Template *templates.ConstraintTemplate - ErrorExpected bool + + tcs := []struct { + name string + handler TargetHandler + wantTemplate *templates.ConstraintTemplate + wantAddError error + wantGetError error }{ { - Name: "Good Template", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - ErrorExpected: false, + name: "Good Template", + handler: &badHandler{Name: "h1", HasLib: true}, + wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + wantAddError: nil, + wantGetError: nil, }, { - Name: "Unknown Target", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - ErrorExpected: true, + name: "Unknown Target", + handler: &badHandler{Name: "h1", HasLib: true}, + wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + wantAddError: ErrInvalidConstraintTemplate, + wantGetError: ErrMissingConstraintTemplate, }, { - Name: "Bad CRD", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), targets("h1")), - ErrorExpected: true, + name: "Bad CRD", + handler: &badHandler{Name: "h1", HasLib: true}, + wantTemplate: createTemplate(name("fake"), targets("h1")), + wantAddError: ErrInvalidConstraintTemplate, + wantGetError: ErrMissingConstraintTemplate, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tt.Handler)) + c, err := b.NewClient(Targets(tc.handler)) if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(tt.Template) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") + + _, err = c.AddTemplate(tc.wantTemplate) + if !errors.Is(err, tc.wantAddError) { + t.Fatalf("got AddTemplate() error = %v, want %v", + err, tc.wantAddError) } - tmpl, err := c.GetTemplate(tt.Template) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) + + gotTemplate, err := c.GetTemplate(tc.wantTemplate) + if !errors.Is(err, tc.wantGetError) { + t.Fatalf("got GetTemplate() error = %v, want %v", + err, tc.wantGetError) } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") + + if tc.wantAddError != nil { + return } - if !tt.ErrorExpected { - if !reflect.DeepEqual(tmpl, tt.Template) { - t.Error("Stored and retrieved template differ") - } + + if diff := cmp.Diff(tc.wantTemplate, gotTemplate); diff != "" { + t.Error(diff) } }) } @@ -634,71 +631,78 @@ func TestGetTemplate(t *testing.T) { func TestGetTemplateByNameOnly(t *testing.T) { badRegoTempl := createTemplate(name("fake"), crdNames("Fake"), targets("h1")) badRegoTempl.Spec.Targets[0].Rego = badRego - tc := []struct { - Name string - Handler TargetHandler - Template *templates.ConstraintTemplate - ErrorExpected bool + + tcs := []struct { + name string + handler TargetHandler + wantTemplate *templates.ConstraintTemplate + wantAddError error + wantGetError error }{ { - Name: "Good Template", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), - ErrorExpected: false, + name: "Good Template", + handler: &badHandler{Name: "h1", HasLib: true}, + wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h1")), + wantAddError: nil, + wantGetError: nil, }, { - Name: "Unknown Target", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), - ErrorExpected: true, + name: "Unknown Target", + handler: &badHandler{Name: "h1", HasLib: true}, + wantTemplate: createTemplate(name("fake"), crdNames("Fake"), targets("h2")), + wantAddError: ErrInvalidConstraintTemplate, + wantGetError: ErrMissingConstraintTemplate, }, { - Name: "Bad CRD", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fake"), targets("h1")), - ErrorExpected: true, + name: "Bad CRD", + handler: &badHandler{Name: "h1", HasLib: true}, + wantTemplate: createTemplate(name("fake"), targets("h1")), + wantAddError: ErrInvalidConstraintTemplate, + wantGetError: ErrMissingConstraintTemplate, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } - c, err := b.NewClient(Targets(tt.Handler)) + c, err := b.NewClient(Targets(tc.handler)) if err != nil { t.Fatal(err) } - _, err = c.AddTemplate(tt.Template) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") + + _, err = c.AddTemplate(tc.wantTemplate) + if !errors.Is(err, tc.wantAddError) { + t.Fatalf("got AddTemplate() error = %v, want %v", + err, tc.wantAddError) } + sparseTemplate := &templates.ConstraintTemplate{} - sparseTemplate.Name = tt.Template.Name - tmpl, err := c.GetTemplate(sparseTemplate) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) + sparseTemplate.Name = tc.wantTemplate.Name + + gotTemplate, err := c.GetTemplate(sparseTemplate) + if !errors.Is(err, tc.wantGetError) { + t.Fatalf("Got GetTemplate() error = %v, want %v", + err, tc.wantGetError) } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") + + if tc.wantGetError != nil { + return } - if !tt.ErrorExpected { - if !reflect.DeepEqual(tmpl, tt.Template) { - t.Error("Stored and retrieved template differ") - } + + if diff := cmp.Diff(tc.wantTemplate, gotTemplate); diff != "" { + t.Error(diff) } }) } } func TestTemplateCascadingDelete(t *testing.T) { - ctx := context.Background() - handler := &badHandler{Name: "h1", HasLib: true} d := local.New() @@ -711,17 +715,19 @@ func TestTemplateCascadingDelete(t *testing.T) { if err != nil { t.Fatal(err) } + templ := createTemplate(name("cascadingdelete"), crdNames("CascadingDelete"), targets("h1")) if _, err = c.AddTemplate(templ); err != nil { t.Errorf("err = %v; want nil", err) } cst1 := newConstraint("CascadingDelete", "cascadingdelete", nil, nil) - if _, err = c.AddConstraint(ctx, cst1); err != nil { + if _, err = c.AddConstraint(context.Background(), cst1); err != nil { t.Error("could not add first constraint") } + cst2 := newConstraint("CascadingDelete", "cascadingdelete2", nil, nil) - if _, err = c.AddConstraint(ctx, cst2); err != nil { + if _, err = c.AddConstraint(context.Background(), cst2); err != nil { t.Error("could not add second constraint") } @@ -731,15 +737,16 @@ func TestTemplateCascadingDelete(t *testing.T) { } cst3 := newConstraint("StillPersists", "stillpersists", nil, nil) - if _, err = c.AddConstraint(ctx, cst3); err != nil { + if _, err = c.AddConstraint(context.Background(), cst3); err != nil { t.Error("could not add third constraint") } + cst4 := newConstraint("StillPersists", "stillpersists2", nil, nil) - if _, err = c.AddConstraint(ctx, cst4); err != nil { + if _, err = c.AddConstraint(context.Background(), cst4); err != nil { t.Error("could not add fourth constraint") } - orig, err := c.Dump(ctx) + orig, err := c.Dump(context.Background()) if err != nil { t.Errorf("could not dump original state: %s", err) } @@ -755,21 +762,24 @@ func TestTemplateCascadingDelete(t *testing.T) { t.Errorf("preservation candidate not cached: %s", orig) } - if _, err = c.RemoveTemplate(ctx, templ); err != nil { + if _, err = c.RemoveTemplate(context.Background(), templ); err != nil { t.Error("could not remove template") } + if len(c.constraints) != 1 { t.Errorf("constraint cache expected to have only 1 entry: %+v", c.constraints) } - s, err := c.Dump(ctx) + s, err := c.Dump(context.Background()) if err != nil { t.Errorf("could not dump OPA cache") } + sLower := strings.ToLower(s) if strings.Contains(sLower, "cascadingdelete") { t.Errorf("Template not removed from cache: %s", s) } + finalPreserved := strings.Count(sLower, "stillpersists") if finalPreserved != origPreserved { t.Errorf("finalPreserved = %d, expected %d :: %s", finalPreserved, origPreserved, s) @@ -777,146 +787,169 @@ func TestTemplateCascadingDelete(t *testing.T) { } func TestAddConstraint(t *testing.T) { - tc := []struct { - Name string - Constraint *unstructured.Unstructured - OmitTemplate bool - ErrorExpected bool + handler := &badHandler{Name: "h1", HasLib: true} + + tcs := []struct { + name string + template *templates.ConstraintTemplate + constraint *unstructured.Unstructured + wantHandled map[string]bool + wantAddConstraintError error + wantGetConstraintError error }{ { - Name: "Good Constraint", - Constraint: newConstraint("Foos", "foo", nil, nil), + name: "Good Constraint", + template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), + constraint: newConstraint("Foos", "foo", nil, nil), + wantHandled: map[string]bool{"h1": true}, + wantAddConstraintError: nil, + wantGetConstraintError: nil, }, { - Name: "No Name", - Constraint: newConstraint("Foos", "", nil, nil), - ErrorExpected: true, + name: "No Name", + template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), + constraint: newConstraint("Foos", "", nil, nil), + wantHandled: nil, + wantAddConstraintError: ErrInvalidConstraint, + wantGetConstraintError: ErrInvalidConstraint, }, { - Name: "No Kind", - Constraint: newConstraint("", "foo", nil, nil), - ErrorExpected: true, + name: "No Kind", + template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), + constraint: newConstraint("", "foo", nil, nil), + wantHandled: nil, + wantAddConstraintError: ErrInvalidConstraint, + wantGetConstraintError: ErrInvalidConstraint, }, { - Name: "No Template", - Constraint: newConstraint("Foo", "foo", nil, nil), - OmitTemplate: true, - ErrorExpected: true, + name: "No Template", + template: nil, + constraint: newConstraint("Foo", "foo", nil, nil), + wantHandled: nil, + wantAddConstraintError: ErrMissingConstraintTemplate, + wantGetConstraintError: ErrMissingConstraint, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { - ctx := context.Background() + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { - t.Fatalf("Could not create backend: %s", err) + t.Fatal(err) } - handler := &badHandler{Name: "h1", HasLib: true} c, err := b.NewClient(Targets(handler)) if err != nil { t.Fatal(err) } - if !tt.OmitTemplate { - tmpl := createTemplate(name("foos"), crdNames("Foos"), targets("h1")) - _, err := c.AddTemplate(tmpl) + + if tc.template != nil { + _, err = c.AddTemplate(tc.template) if err != nil { t.Fatal(err) } } - r, err := c.AddConstraint(ctx, tt.Constraint) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") - } - expectedCount := 0 - expectedHandled := make(map[string]bool) - if !tt.ErrorExpected { - expectedCount = 1 - expectedHandled = map[string]bool{"h1": true} + + r, err := c.AddConstraint(context.Background(), tc.constraint) + if !errors.Is(err, tc.wantAddConstraintError) { + t.Fatalf("got AddConstraint() error = %v, want %v", + err, tc.wantAddConstraintError) } if r == nil { t.Fatal("got AddConstraint() == nil, want non-nil") } - if r.HandledCount() != expectedCount { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) + + if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) + + cached, err := c.GetConstraint(tc.constraint) + if !errors.Is(err, tc.wantGetConstraintError) { + t.Fatalf("got GetConstraint() error = %v, want %v", + err, tc.wantGetConstraintError) } - cached, err := c.GetConstraint(tt.Constraint) - if err == nil && tt.ErrorExpected { - t.Error("retrieved constraint when error was expected") + + if tc.wantGetConstraintError != nil { + return } - if err != nil && !tt.ErrorExpected { - t.Error("could not retrieve constraint when error was expected") + + if diff := cmp.Diff(tc.constraint.Object["spec"], cached.Object["spec"]); diff != "" { + t.Error("cached constraint does not equal stored constraint") } - if !tt.ErrorExpected { - if !constraintlib.SemanticEqual(cached, tt.Constraint) { - t.Error("cached constraint does not equal stored constraint") - } - r2, err := c.RemoveConstraint(ctx, tt.Constraint) - if err != nil { - t.Error("could not remove constraint") - } - if r2 == nil { - t.Fatal("got RemoveConstraint() == nil, want non-nil") - } - if r2.HandledCount() != 1 { - t.Error("more targets handled than expected") - } - if _, err := c.GetConstraint(tt.Constraint); err == nil { - t.Error("constraint not cleared from cache") - } + r2, err := c.RemoveConstraint(context.Background(), tc.constraint) + if err != nil { + t.Error("could not remove constraint") + } + + if r2 == nil { + t.Fatal("got RemoveConstraint() == nil, want non-nil") + } + + if r2.HandledCount() != 1 { + t.Error("more targets handled than expected") + } + + if _, err := c.GetConstraint(tc.constraint); err == nil { + t.Error("constraint not cleared from cache") } }) } } func TestRemoveConstraint(t *testing.T) { - tc := []struct { - Name string - Constraint *unstructured.Unstructured - OmitTemplate bool - ErrorExpected bool - ExpectedErrorType string + tcs := []struct { + name string + template *templates.ConstraintTemplate + constraint *unstructured.Unstructured + toRemove *unstructured.Unstructured + wantHandled map[string]bool + wantError error }{ { - Name: "Good Constraint", - Constraint: newConstraint("Foos", "foo", nil, nil), + name: "Good Constraint", + template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), + constraint: newConstraint("Foos", "foo", nil, nil), + toRemove: newConstraint("Foos", "foo", nil, nil), + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, { - Name: "No Name", - Constraint: newConstraint("Foos", "", nil, nil), - ErrorExpected: true, + name: "No name", + template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), + constraint: newConstraint("Foos", "foo", nil, nil), + toRemove: newConstraint("Foos", "", nil, nil), + wantHandled: nil, + wantError: ErrInvalidConstraint, }, { - Name: "No Kind", - Constraint: newConstraint("", "foo", nil, nil), - ErrorExpected: true, + name: "No Kind", + template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), + constraint: newConstraint("Foos", "foo", nil, nil), + toRemove: newConstraint("", "foo", nil, nil), + wantHandled: nil, + wantError: ErrInvalidConstraint, }, { - Name: "No Template", - Constraint: newConstraint("Foo", "foo", nil, nil), - OmitTemplate: true, - ErrorExpected: true, + name: "No Template", + toRemove: newConstraint("Foos", "foo", nil, nil), + wantHandled: nil, + wantError: ErrMissingConstraintTemplate, }, { - Name: "Unrecognized Constraint", - Constraint: newConstraint("Bar", "bar", nil, nil), - OmitTemplate: true, - ErrorExpected: true, - ExpectedErrorType: "*client.UnrecognizedConstraintError", + name: "No Constraint", + template: createTemplate(name("foos"), crdNames("Foos"), targets("h1")), + toRemove: newConstraint("Foos", "bar", nil, nil), + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { ctx := context.Background() d := local.New() @@ -930,38 +963,34 @@ func TestRemoveConstraint(t *testing.T) { if err != nil { t.Fatal(err) } - if !tt.OmitTemplate { - tmpl := createTemplate(name("foos"), crdNames("Foos"), targets("h1")) - _, err := c.AddTemplate(tmpl) + + if tc.template != nil { + _, err = c.AddTemplate(tc.template) if err != nil { t.Fatal(err) } } - r, err := c.RemoveConstraint(ctx, tt.Constraint) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") - } - if tt.ErrorExpected && tt.ExpectedErrorType != "" && reflect.TypeOf(err).String() != tt.ExpectedErrorType { - t.Errorf("err type = %s; want %s", reflect.TypeOf(err).String(), tt.ExpectedErrorType) + + if tc.constraint != nil { + _, err = c.AddConstraint(ctx, tc.constraint) + if err != nil { + t.Fatal(err) + } } - expectedCount := 0 - expectedHandled := make(map[string]bool) - if !tt.ErrorExpected { - expectedCount = 1 - expectedHandled = map[string]bool{"h1": true} + + r, err := c.RemoveConstraint(context.Background(), tc.toRemove) + + if !errors.Is(err, tc.wantError) { + t.Errorf("got RemoveConstraint error = %v, want %v", + err, tc.wantError) } if r == nil { t.Fatal("got RemoveConstraint() == nil, want non-nil") } - if r.HandledCount() != expectedCount { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) - } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) + + if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } }) } @@ -977,128 +1006,383 @@ violation[{"msg": "msg"}] { } ` - tc := []struct { - Name string - Handler TargetHandler - Template *templates.ConstraintTemplate - ErrorExpected bool - InvAllowed bool + tcs := []struct { + name string + allowedFields []string + handler TargetHandler + template *templates.ConstraintTemplate + wantHandled map[string]bool + wantError error }{ { - Name: "Inventory Not Used", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), - ErrorExpected: false, + name: "Inventory Not Used", + allowedFields: []string{}, + handler: &badHandler{Name: "h1", HasLib: true}, + template: createTemplate(name("fakes"), crdNames("Fakes"), targets("h1")), + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, { - Name: "Inventory Used", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: inventoryTempl, - ErrorExpected: true, + name: "Inventory used but not allowed", + allowedFields: []string{}, + handler: &badHandler{Name: "h1", HasLib: true}, + template: inventoryTempl, + wantHandled: nil, + wantError: ErrInvalidConstraintTemplate, }, { - Name: "Inventory Used But Allowed", - Handler: &badHandler{Name: "h1", HasLib: true}, - Template: inventoryTempl, - ErrorExpected: false, - InvAllowed: true, + name: "Inventory used and allowed", + allowedFields: []string{"inventory"}, + handler: &badHandler{Name: "h1", HasLib: true}, + template: inventoryTempl, + wantHandled: map[string]bool{"h1": true}, + wantError: nil, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - f := AllowedDataFields() - if tt.InvAllowed { - f = AllowedDataFields("inventory") + t.Fatal(err) } - c, err := b.NewClient(Targets(tt.Handler), f) + c, err := b.NewClient(Targets(tc.handler), AllowedDataFields(tc.allowedFields...)) if err != nil { t.Fatal(err) } - r, err := c.AddTemplate(tt.Template) - if err != nil && !tt.ErrorExpected { - t.Errorf("err = %v; want nil", err) - } - if err == nil && tt.ErrorExpected { - t.Error("err = nil; want non-nil") - } - expectedCount := 0 - expectedHandled := make(map[string]bool) - if !tt.ErrorExpected { - expectedCount = 1 - expectedHandled = map[string]bool{"h1": true} + + r, err := c.AddTemplate(tc.template) + if !errors.Is(err, tc.wantError) { + t.Fatalf("got AddTemplate() error = %v, want %v", + err, tc.wantError) } if r == nil { t.Fatal("got AddTemplate() == nil, want non-nil") } - if r.HandledCount() != expectedCount { - t.Errorf("HandledCount() = %v; want %v", r.HandledCount(), expectedCount) - } - if !reflect.DeepEqual(r.Handled, expectedHandled) { - t.Errorf("r.Handled = %v; want %v", r.Handled, expectedHandled) + + if diff := cmp.Diff(tc.wantHandled, r.Handled, cmpopts.EquateEmpty()); diff != "" { + t.Error(diff) } }) } } func TestAllowedDataFieldsIntersection(t *testing.T) { - tc := []struct { - Name string - Allowed Opt - Expected []string - wantError bool + tcs := []struct { + name string + allowed Opt + want []string + wantError error }{ { - Name: "No AllowedDataFields specified", - Expected: []string{"inventory"}, + name: "No AllowedDataFields specified", + want: []string{"inventory"}, }, { - Name: "Empty AllowedDataFields Used", - Allowed: AllowedDataFields(), - Expected: nil, + name: "Empty AllowedDataFields Used", + allowed: AllowedDataFields(), + want: nil, }, { - Name: "Inventory Used", - Allowed: AllowedDataFields("inventory"), - Expected: []string{"inventory"}, + name: "Inventory Used", + allowed: AllowedDataFields("inventory"), + want: []string{"inventory"}, }, { - Name: "Invalid Data Field", - Allowed: AllowedDataFields("no_overlap"), - Expected: []string{}, - wantError: true, + name: "Invalid Data Field", + allowed: AllowedDataFields("no_overlap"), + want: []string{}, + wantError: ErrCreatingClient, }, } - for _, tt := range tc { - t.Run(tt.Name, func(t *testing.T) { + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { d := local.New() + b, err := NewBackend(Driver(d)) if err != nil { t.Fatalf("Could not create backend: %s", err) } + opts := []Opt{Targets(&badHandler{Name: "h1", HasLib: true})} - if tt.Allowed != nil { - opts = append(opts, tt.Allowed) + if tc.allowed != nil { + opts = append(opts, tc.allowed) } c, err := b.NewClient(opts...) - if tt.wantError { - if err == nil { - t.Fatalf("Expectd error, got nil") - } + if !errors.Is(err, tc.wantError) { + t.Fatalf("got NewClient() error = %v, want %v", + err, tc.wantError) + } + + if tc.wantError != nil { return } + + if diff := cmp.Diff(tc.want, c.allowedDataFields); diff != "" { + t.Error(diff) + } + }) + } +} + +func TestClient_CreateCRD(t *testing.T) { + testCases := []struct { + name string + targets []TargetHandler + template *templates.ConstraintTemplate + want *apiextensions.CustomResourceDefinition + wantErr error + }{ + { + name: "nil", + targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, + template: nil, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "empty", + targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, + template: &templates.ConstraintTemplate{}, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "no CRD kind", + targets: []TargetHandler{&handler{}}, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + }, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "name-kind mismatch", + targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: "Bar", + }, + }, + }, + Targets: []templates.Target{{ + Target: "handler", + Rego: `package foo + +violation[msg] {msg := "always"}`, + }}, + }, + }, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "no targets", + targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: "Foo", + }, + }, + }, + }, + }, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "wrong target", + targets: []TargetHandler{&badHandler{Name: "handler.1", HasLib: true}}, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: "Foo", + }, + }, + }, + Targets: []templates.Target{{ + Target: "handler.2", + }}, + }, + }, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "no rego", + targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: "Foo", + }, + }, + }, + Targets: []templates.Target{{ + Target: "handler", + }}, + }, + }, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "empty rego package", + targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: "Foo", + }, + }, + }, + Targets: []templates.Target{{ + Target: "handler", + Rego: `package foo`, + }}, + }, + }, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "multiple targets", + targets: []TargetHandler{ + &badHandler{Name: "handler", HasLib: true}, + &badHandler{Name: "handler.2", HasLib: true}, + }, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: "Foo", + }, + }, + }, + Targets: []templates.Target{{ + Target: "handler", + Rego: `package foo + +violation[msg] {msg := "always"}`, + }, { + Target: "handler.2", + Rego: `package foo + +violation[msg] {msg := "always"}`, + }}, + }, + }, + want: nil, + wantErr: ErrInvalidConstraintTemplate, + }, + { + name: "minimal working", + targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, + template: &templates.ConstraintTemplate{ + ObjectMeta: v1.ObjectMeta{Name: "foo"}, + Spec: templates.ConstraintTemplateSpec{ + CRD: templates.CRD{ + Spec: templates.CRDSpec{ + Names: templates.Names{ + Kind: "Foo", + }, + }, + }, + Targets: []templates.Target{{ + Target: "handler", + Rego: `package foo + +violation[msg] {msg := "always"}`, + }}, + }, + }, + want: &apiextensions.CustomResourceDefinition{ + ObjectMeta: v1.ObjectMeta{ + Name: "foo.constraints.gatekeeper.sh", + Labels: map[string]string{"gatekeeper.sh/constraint": "yes"}, + }, + Spec: apiextensions.CustomResourceDefinitionSpec{ + Group: "constraints.gatekeeper.sh", + Version: "v1beta1", + Names: apiextensions.CustomResourceDefinitionNames{ + Plural: "foo", + Singular: "foo", + Kind: "Foo", + ListKind: "FooList", + Categories: []string{"constraint", "constraints"}, + }, + Scope: apiextensions.ClusterScoped, + Subresources: &apiextensions.CustomResourceSubresources{ + Status: &apiextensions.CustomResourceSubresourceStatus{}, + }, + Versions: []apiextensions.CustomResourceDefinitionVersion{{ + Name: "v1beta1", Served: true, Storage: true, + }, { + Name: "v1alpha1", Served: true, + }}, + Conversion: &apiextensions.CustomResourceConversion{ + Strategy: apiextensions.NoneConverter, + }, + PreserveUnknownFields: pointer.BoolPtr(false), + }, + Status: apiextensions.CustomResourceDefinitionStatus{ + StoredVersions: []string{"v1beta1"}, + }, + }, + wantErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + d := local.New() + + b, err := NewBackend(Driver(d)) + if err != nil { + t.Fatal(err) + } + + c, err := b.NewClient(Targets(tc.targets...)) if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(c.allowedDataFields, tt.Expected) { - t.Errorf("c.allowedDataFields = %v; want %v", c.allowedDataFields, tt.Expected) + + t.Log(c.targets) + + got, err := c.CreateCRD(tc.template) + + if !errors.Is(err, tc.wantErr) { + t.Fatalf("got CreateTemplate() error = %v, want %v", + err, tc.wantErr) + } + + if diff := cmp.Diff(tc.want, got, + cmpopts.IgnoreFields(apiextensions.CustomResourceDefinitionSpec{}, "Validation")); diff != "" { + t.Error(diff) } }) }