diff --git a/internal/auth/oidc_clients_store.go b/internal/auth/oidc_clients_store.go index 5e998d5..291a3b0 100644 --- a/internal/auth/oidc_clients_store.go +++ b/internal/auth/oidc_clients_store.go @@ -2,6 +2,7 @@ package auth import ( "strings" + "sync" ) type OIDCClientConfig struct { @@ -18,6 +19,8 @@ type OIDCClientsStore interface { type StaticOIDCClientsStore struct { clients map[string]OIDCClientConfig + + mtx sync.RWMutex } func NewStaticOIDCClientStore(config map[string]OIDCClientConfig) *StaticOIDCClientsStore { @@ -29,6 +32,9 @@ func NewEmptyStaticOIDCClientStore() *StaticOIDCClientsStore { } func (ocf *StaticOIDCClientsStore) GetClient(domain string) (*OIDCClientConfig, error) { + ocf.mtx.RLock() + defer ocf.mtx.RUnlock() + if config, ok := ocf.clients[domain]; ok { return &config, nil } @@ -36,11 +42,14 @@ func (ocf *StaticOIDCClientsStore) GetClient(domain string) (*OIDCClientConfig, } func (ocf *StaticOIDCClientsStore) AddClient(domain string, clientid string, clientsecret string, redirecturl string) { + ocf.mtx.Lock() + defer ocf.mtx.Unlock() + if _, ok := ocf.clients[domain]; !ok { - ocf.clients[strings.Clone(domain)] = OIDCClientConfig { - ClientID: strings.Clone(clientid), + ocf.clients[strings.Clone(domain)] = OIDCClientConfig{ + ClientID: strings.Clone(clientid), ClientSecret: strings.Clone(clientsecret), - RedirectURL: strings.Clone(redirecturl), + RedirectURL: strings.Clone(redirecturl), } } } diff --git a/internal/auth/oidc_clients_store_test.go b/internal/auth/oidc_clients_store_test.go new file mode 100644 index 0000000..29a6519 --- /dev/null +++ b/internal/auth/oidc_clients_store_test.go @@ -0,0 +1,52 @@ +package auth + +import ( + "strconv" + "sync" + "testing" +) + +func TestStaticOIDCClientsStoreRace(t *testing.T) { + var wg = &sync.WaitGroup{} + var expectedValue OIDCClientConfig + var store = NewEmptyStaticOIDCClientStore() + const steps = 10000 + + // One Goroutine changes the store state while 2 other try to read from it. + wg.Add(1) + go func() { + for i := 0; i < steps; i++ { + // Something comes with requests for a new and a valid domain, + // so it is being added to the store. + expectedValue.ClientID = "client-id" + expectedValue.ClientSecret = "client-secret" + expectedValue.RedirectURL = "https://" + strconv.Itoa(i) + ".example.com/redirect" + domain := strconv.Itoa(i) + ".example.com" + + store.AddClient(domain, expectedValue.ClientID, expectedValue.ClientSecret, expectedValue.RedirectURL) + } + + wg.Done() + }() + + // Read and compare. + var found bool + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + for i := 0; i < steps; i++ { + _, err := store.GetClient("100000.example.com") + if err == nil { + found = true + break + } + } + + wg.Done() + }() + } + + if found { + t.Fatal("Received a value while should get ErrOIDCClientConfigNotFound") + } +}