Skip to content

Commit

Permalink
Fix flaky TestAWSOIDCRequiredVPCSHelper (#51121)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoandredinis authored Jan 20, 2025
1 parent 1b76f97 commit dd1ca21
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions lib/web/integrations_awsoidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,6 @@ func TestBuildListDatabasesConfigureIAMScript(t *testing.T) {
func TestAWSOIDCRequiredVPCSHelper(t *testing.T) {
t.Parallel()
ctx := context.Background()
env := newWebPack(t, 1)
clt := env.proxies[0].client

matchRegion := "us-east-1"
matchAccountId := "123456789012"
Expand All @@ -735,7 +733,7 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) {
AccountID: matchAccountId,
}

upsertDbSvcFn := func(vpcId string, matcher []*types.DatabaseResourceMatcher) {
dbServiceFor := func(vpcId string, matcher []*types.DatabaseResourceMatcher) *types.DatabaseServiceV1 {
if matcher == nil {
matcher = []*types.DatabaseResourceMatcher{
{
Expand All @@ -754,8 +752,7 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) {
ResourceMatchers: matcher,
})
require.NoError(t, err)
_, err = env.server.Auth().UpsertDatabaseService(ctx, svc)
require.NoError(t, err)
return svc
}

extractKeysFn := func(resp *ui.AWSOIDCRequiredVPCSResponse) []string {
Expand All @@ -777,24 +774,23 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) {
}

// Double check we start with 0 db svcs.
s, err := env.server.Auth().ListResources(ctx, proto.ListResourcesRequest{
ResourceType: types.KindDatabaseService,
})
require.NoError(t, err)
require.Empty(t, s.Resources)
clt := &mockGetResources{
databaseServices: &proto.ListResourcesResponse{},
}

// All vpc's required.
resp, err := awsOIDCRequiredVPCSHelper(ctx, clt, req, rdss)
require.NoError(t, err)
require.Len(t, resp.VPCMapOfSubnets, 5)
require.ElementsMatch(t, vpcs, extractKeysFn(resp))

// Insert two valid database services.
upsertDbSvcFn("vpc-1", nil)
upsertDbSvcFn("vpc-5", nil)
// Add some database services.
// Two valid database services.
validDBServiceVPC1 := dbServiceFor("vpc-1", nil)
validDBServiceVPC5 := dbServiceFor("vpc-5", nil)

// Insert two invalid database services.
upsertDbSvcFn("vpc-2", []*types.DatabaseResourceMatcher{
// Two invalid database services.
invalidDBServiceVPC2 := dbServiceFor("vpc-2", []*types.DatabaseResourceMatcher{
{
Labels: &types.Labels{
types.DiscoveryLabelAccountID: []string{matchAccountId},
Expand All @@ -803,7 +799,7 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) {
},
},
})
upsertDbSvcFn("vpc-2a", []*types.DatabaseResourceMatcher{
invalidDBServiceVPC2a := dbServiceFor("vpc-2a", []*types.DatabaseResourceMatcher{
{
Labels: &types.Labels{
types.DiscoveryLabelAccountID: []string{matchAccountId},
Expand All @@ -814,22 +810,27 @@ func TestAWSOIDCRequiredVPCSHelper(t *testing.T) {
},
})

// Double check services were created.
s, err = env.server.Auth().ListResources(ctx, proto.ListResourcesRequest{
ResourceType: types.KindDatabaseService,
})
require.NoError(t, err)
require.Len(t, s.Resources, 4)
clt.databaseServices.Resources = append(clt.databaseServices.Resources,
&proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC1}},
&proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC5}},
&proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: invalidDBServiceVPC2}},
&proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: invalidDBServiceVPC2a}},
)

// Test that only 3 vpcs are required.
resp, err = awsOIDCRequiredVPCSHelper(ctx, clt, req, rdss)
require.NoError(t, err)
require.ElementsMatch(t, []string{"vpc-2", "vpc-3", "vpc-4"}, extractKeysFn(resp))

// Insert the rest of db services
upsertDbSvcFn("vpc-2", nil)
upsertDbSvcFn("vpc-3", nil)
upsertDbSvcFn("vpc-4", nil)
validDBServiceVPC2 := dbServiceFor("vpc-2", nil)
validDBServiceVPC3 := dbServiceFor("vpc-3", nil)
validDBServiceVPC4 := dbServiceFor("vpc-4", nil)
clt.databaseServices.Resources = append(clt.databaseServices.Resources,
&proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC2}},
&proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC3}},
&proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseService{DatabaseService: validDBServiceVPC4}},
)

// Test no required vpcs.
resp, err = awsOIDCRequiredVPCSHelper(ctx, clt, req, rdss)
Expand Down Expand Up @@ -866,9 +867,16 @@ func TestAWSOIDCRequiredVPCSHelper_CombinedSubnetsForAVpcID(t *testing.T) {
}

type mockGetResources struct {
databaseServices *proto.ListResourcesResponse
}

func (m *mockGetResources) GetResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) {
switch req.ResourceType {
case types.KindDatabaseService:
if m.databaseServices != nil {
return m.databaseServices, nil
}
}
return &proto.ListResourcesResponse{}, nil
}

Expand Down

0 comments on commit dd1ca21

Please sign in to comment.