Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mvbrock committed Jan 23, 2025
1 parent 64ca538 commit 9bb0b60
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 120 deletions.
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,6 @@ require (
software.sslmate.com/src/go-pkcs12 v0.5.0
)

require github.com/xanzy/go-gitlab v0.115.0 // indirect

require (
cel.dev/expr v0.19.1 // indirect
cloud.google.com/go v0.117.0 // indirect
Expand Down Expand Up @@ -524,6 +522,7 @@ require (
github.com/vbatts/tar-split v0.11.5 // indirect
github.com/weppos/publicsuffix-go v0.30.3-0.20240510084413-5f1d03393b3d // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/xanzy/go-gitlab v0.115.0 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
Expand Down
8 changes: 0 additions & 8 deletions lib/msgraph/paginated.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,6 @@ func (c *Client) IterateServicePrincipals(ctx context.Context, f func(principal
return iterateSimple(c, ctx, "servicePrincipals", f)
}

// IterateUserMembership lists all group memberships for a given user ID as directory objects.
// `f` will be called for each directory object in the result set.
// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop).
// Ref: [https://learn.microsoft.com/en-us/graph/api/group-list-memberof].
func (c *Client) IterateUserMembership(ctx context.Context, userID string, f func(object *DirectoryObject) bool) error {
return iterateSimple(c, ctx, path.Join("users", userID, "memberOf"), f)
}

// IterateGroupMembers lists all members for the given Entra ID group using pagination.
// `f` will be called for each object in the result set.
// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop).
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/discovery/access_graph_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ func grpcCredentials(config AccessGraphConfig, getCert func() (*tls.Certificate,
return grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), nil
}

func (s *Server) initAccessGraphWatchers(ctx context.Context, cfg *Config) error {
fetchers, err := s.accessGraphFetchersFromMatchers(ctx, cfg.Matchers, "" /* discoveryConfigName */)
func (s *Server) initTAGAWSWatchers(ctx context.Context, cfg *Config) error {
fetchers, err := s.accessGraphAWSFetchersFromMatchers(ctx, cfg.Matchers, "" /* discoveryConfigName */)
if err != nil {
s.Log.ErrorContext(ctx, "Error initializing access graph fetchers", "error", err)
}
Expand Down Expand Up @@ -482,8 +482,8 @@ func (s *Server) initAccessGraphWatchers(ctx context.Context, cfg *Config) error
return nil
}

// accessGraphFetchersFromMatchers converts Matchers into a set of AWS Sync Fetchers.
func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers Matchers, discoveryConfigName string) ([]aws_sync.AWSSync, error) {
// accessGraphAWSFetchersFromMatchers converts Matchers into a set of AWS Sync Fetchers.
func (s *Server) accessGraphAWSFetchersFromMatchers(ctx context.Context, matchers Matchers, discoveryConfigName string) ([]aws_sync.AWSSync, error) {
var fetchers []aws_sync.AWSSync
var errs []error
if matchers.AccessGraph == nil {
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ func New(ctx context.Context, cfg *Config) (*Server, error) {
return nil, trace.Wrap(err)
}

if err := s.initAccessGraphWatchers(s.ctx, cfg); err != nil {
if err := s.initTAGAWSWatchers(s.ctx, cfg); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -1822,7 +1822,7 @@ func (s *Server) upsertDynamicMatchers(ctx context.Context, dc *discoveryconfig.
s.dynamicDatabaseFetchers[dc.GetName()] = databaseFetchers
s.muDynamicDatabaseFetchers.Unlock()

awsSyncMatchers, err := s.accessGraphFetchersFromMatchers(
awsSyncMatchers, err := s.accessGraphAWSFetchersFromMatchers(
ctx, matchers, dc.GetName(),
)
if err != nil {
Expand Down
14 changes: 0 additions & 14 deletions lib/utils/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,3 @@ func FromSlice[T any](r []T, key func(T) string) map[string]T {

return out
}

// DeduplicateKey returns a deduplicated slice by comparing key values from the key function
func DeduplicateKey[T any](s []T, key func(T) string) []T {
out := make([]T, 0, len(s))
seen := make(map[string]struct{})
for _, v := range s {
if _, ok := seen[key(v)]; ok {
continue
}
seen[key(v)] = struct{}{}
out = append(out, v)
}
return out
}
90 changes: 0 additions & 90 deletions lib/utils/slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package utils

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -42,92 +41,3 @@ func TestSlice(t *testing.T) {
pool.Put(slice)
}
}

// TestDuplicateKey tests slice deduplication via key function
func TestDeduplicateKey(t *testing.T) {
t.Parallel()

stringTests := []struct {
name string
slice []string
keyFn func(string) string
expected []string
}{
{
name: "EmptyStringSlice",
slice: []string{},
keyFn: func(s string) string { return s },
expected: []string{},
},
{
name: "NoStringDuplicates",
slice: []string{"foo", "bar", "baz"},
keyFn: func(s string) string { return s },
expected: []string{"foo", "bar", "baz"},
},
{
name: "StringDuplicates",
slice: []string{"foo", "bar", "bar", "bar", "baz", "baz"},
keyFn: func(s string) string { return s },
expected: []string{"foo", "bar", "baz"},
},
{
name: "StringDuplicatesWeirdKeyFn",
slice: []string{"foo", "bar", "bar", "bar", "baz", "baz"},
keyFn: func(s string) string { return "huh" },
expected: []string{"foo"},
},
}
for _, tt := range stringTests {
t.Run(tt.name, func(t *testing.T) {
res := DeduplicateKey(tt.slice, tt.keyFn)
require.Equal(t, tt.expected, res)
})
}

type dedupeStruct struct {
a string
b int
c bool
}
dedupeStructKeyFn := func(d dedupeStruct) string { return fmt.Sprintf("%s:%d:%v", d.a, d.b, d.c) }
structTests := []struct {
name string
slice []dedupeStruct
keyFn func(d dedupeStruct) string
expected []dedupeStruct
}{
{
name: "EmptySlice",
slice: []dedupeStruct{},
keyFn: dedupeStructKeyFn,
expected: []dedupeStruct{},
},
{
name: "NoStructDuplicates",
slice: []dedupeStruct{
{a: "foo", b: 1, c: true},
{a: "foo", b: 1, c: false},
{a: "foo", b: 2, c: true},
{a: "bar", b: 1, c: true},
{a: "bar", b: 1, c: false},
{a: "bar", b: 2, c: true},
},
keyFn: dedupeStructKeyFn,
expected: []dedupeStruct{
{a: "foo", b: 1, c: true},
{a: "foo", b: 1, c: false},
{a: "foo", b: 2, c: true},
{a: "bar", b: 1, c: true},
{a: "bar", b: 1, c: false},
{a: "bar", b: 2, c: true},
},
},
}
for _, tt := range structTests {
t.Run(tt.name, func(t *testing.T) {
res := DeduplicateKey(tt.slice, tt.keyFn)
require.Equal(t, tt.expected, res)
})
}
}

0 comments on commit 9bb0b60

Please sign in to comment.