Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
orestisfl committed Nov 27, 2023
1 parent 0929dbd commit 14bf2b6
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 6 deletions.
14 changes: 8 additions & 6 deletions resources/providers/azurelib/governance/management_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package governance

// TODO: tests

import (
"context"
"fmt"
Expand Down Expand Up @@ -67,13 +65,15 @@ type provider struct {

func NewProvider(log *logp.Logger, client inventory.ProviderAPI) ProviderAPI {
return &provider{
log: log.Named("governance"),
client: client,
log: log.Named("governance"),
client: client,
lastSequence: -1,
}
}

func (p *provider) GetSubscriptions(ctx context.Context, cycle fetching.CycleMetadata) (map[string]Subscription, error) {
return p.cachedSubscriptions, p.maybeScan(ctx, cycle)
err := p.maybeScan(ctx, cycle)
return p.cachedSubscriptions, err
}

func (p *provider) maybeScan(ctx context.Context, cycle fetching.CycleMetadata) error {
Expand All @@ -84,15 +84,17 @@ func (p *provider) maybeScan(ctx context.Context, cycle fetching.CycleMetadata)
return nil
}

p.lastSequence = cycle.Sequence
if err := p.scan(ctx); err != nil {
if p.cachedSubscriptions == nil {
return fmt.Errorf("failed to scan subscriptions: %w", err)
}

p.lastSequence = cycle.Sequence
p.log.Errorf("Failed to scan subscriptions, re-using cached values: %v", err)
return nil
}

p.lastSequence = cycle.Sequence
return nil
}

Expand Down
197 changes: 197 additions & 0 deletions resources/providers/azurelib/governance/management_group_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package governance

import (
"context"
"errors"
"fmt"
"math/rand"
"sync"
"testing"
"time"

"github.com/elastic/elastic-agent-libs/atomic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/elastic/cloudbeat/resources/fetching"
"github.com/elastic/cloudbeat/resources/providers/azurelib/inventory"
"github.com/elastic/cloudbeat/resources/utils/testhelper"
)

func Test_provider_GetSubscriptions(t *testing.T) {
ctx := context.Background()
err1 := errors.New("some error 1")
err2 := errors.New("some error 2")
assets := []inventory.AzureAsset{
generateManagementGroupAsset(1),
generateManagementGroupAsset(2),
generateManagementGroupAsset(3),
generateManagementGroupAsset(4),
generateManagementGroupAsset(5),
generateSubscriptionAsset(1, 1),
generateSubscriptionAsset(2, 1),
generateSubscriptionAsset(3, 4),
generateSubscriptionAsset(4, 5),
generateSubscriptionAsset(5, 4),
generateSubscriptionAsset(6, 2),
}
rand.Shuffle(len(assets), func(i, j int) { // shuffle assets, as order shouldn't matter
assets[i], assets[j] = assets[j], assets[i]
})
expectedSubscriptions := map[string]Subscription{
"sub-id-1": generateSubscription(1, 1),
"sub-id-2": generateSubscription(2, 1),
"sub-id-3": generateSubscription(3, 4),
"sub-id-4": generateSubscription(4, 5),
"sub-id-5": generateSubscription(5, 4),
"sub-id-6": generateSubscription(6, 2),
}

t.Run("no assets", func(t *testing.T) {
p := NewProvider(testhelper.NewLogger(t), mockClient(t, nil, nil))
subs, err := p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 1})
require.NoError(t, err)
assert.Equal(t, map[string]Subscription{}, subs)

subsSame, err := p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 1})
require.NoError(t, err)
assert.Equal(t, subs, subsSame)
})

t.Run("error on first call", func(t *testing.T) {
p := NewProvider(testhelper.NewLogger(t), mockClient(t, nil, err1))
_, err := p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 10})
require.ErrorIs(t, err, err1)

p.(*provider).client = mockClient(t, nil, err2)
_, err = p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 1})
require.ErrorIs(t, err, err2)

p.(*provider).client = mockClient(t, assets, nil)
got, err := p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 1})
require.NoError(t, err)
assert.Equal(t, expectedSubscriptions, got)
})

t.Run("error on later call", func(t *testing.T) {
p := NewProvider(testhelper.NewLogger(t), mockClient(t, assets, nil))
got, err := p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 1})
require.NoError(t, err)
assert.Equal(t, expectedSubscriptions, got)

p.(*provider).client = mockClient(t, nil, err1)
got, err = p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 10})
require.NoError(t, err)
assert.Equal(t, expectedSubscriptions, got)
})

t.Run("lock", func(t *testing.T) {
firstRun := atomic.NewBool(false)
m := inventory.NewMockProviderAPI(t)
m.EXPECT().
ListAllAssetTypesByName(mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, s string, strings []string) ([]inventory.AzureAsset, error) {
if firstRun.CAS(false, true) {
time.Sleep(50 * time.Millisecond)
return assets, nil
}
return nil, err1
})
p := NewProvider(testhelper.NewLogger(t), m)

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

got, err := p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 1})
require.NoError(t, err)
assert.Equal(t, expectedSubscriptions, got)
}()

got, err := p.GetSubscriptions(ctx, fetching.CycleMetadata{Sequence: 1})
require.NoError(t, err)
assert.Equal(t, expectedSubscriptions, got)

wg.Wait()
})
}

func mockClient(t *testing.T, assets []inventory.AzureAsset, err error) *inventory.MockProviderAPI {
t.Helper()
client := inventory.NewMockProviderAPI(t)
client.EXPECT().
ListAllAssetTypesByName(mock.Anything, mock.Anything, mock.Anything).
Return(assets, err).
Once()
return client
}

func generateManagementGroup(id int) ManagementGroup {
return ManagementGroup{
ID: fmtField("mg-id", id),
DisplayName: fmtField("mg-display-name", id),
}
}

func generateSubscription(id int, parentId int) Subscription {
return Subscription{
ID: fmtField("sub-id", id),
DisplayName: fmtField("sub-display-name", id),
MG: generateManagementGroup(parentId),
}
}

func generateManagementGroupAsset(id int) inventory.AzureAsset {
return inventory.AzureAsset{
Id: fmtField("mg-id", id),
Name: fmtField("mg-name", id),
DisplayName: fmtField("mg-display-name", id),
Location: "location",
Properties: nil,
TenantId: "tenant-id",
Type: "microsoft.management/managementgroups",
}
}

func generateSubscriptionAsset(id int, parentId int) inventory.AzureAsset {
subId := fmtField("sub-id", id)
return inventory.AzureAsset{
Id: subId,
Name: fmtField("sub-display-name", id),
Location: "location",
Properties: map[string]any{
"managementGroupAncestorsChain": []any{
map[string]any{
"displayName": fmtField("mg-display-name", parentId),
"name": fmtField("mg-name", parentId),
},
},
},
SubscriptionId: subId,
TenantId: "tenant-id",
Type: "microsoft.resources/subscriptions",
}
}

func fmtField(s string, i int) string {
return fmt.Sprintf("%s-%d", s, i)
}

0 comments on commit 14bf2b6

Please sign in to comment.