diff --git a/auth/github/client_test.go b/auth/github/client_test.go index 5dcdf6a2..d845ee45 100644 --- a/auth/github/client_test.go +++ b/auth/github/client_test.go @@ -156,6 +156,8 @@ func TestClient_Options(t *testing.T) { } func TestClient_GetToken(t *testing.T) { + g := NewWithT(t) + expiresAt := time.Now().UTC().Add(time.Hour) tests := []struct { name string @@ -180,14 +182,17 @@ func TestClient_GetToken(t *testing.T) { { name: "Get cached token", opts: []OptFunc{func(client *Client) { - c := cache.NewTokenCache(1) - c.GetOrSet(context.Background(), client.buildCacheKey(), func(context.Context) (cache.Token, error) { + c, err := cache.NewTokenCache(1) + g.Expect(err).NotTo(HaveOccurred()) + _, ok, err := c.GetOrSet(context.Background(), client.buildCacheKey(), func(context.Context) (cache.Token, error) { return &AppToken{ Token: "access-token", ExpiresAt: expiresAt, }, nil }) - client.cache = c + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(ok).To(BeFalse()) + WithCache(c, "", "", "")(client) }}, statusCode: http.StatusInternalServerError, // error status code to make the test fail if the token is not cached wantAppToken: &AppToken{ diff --git a/auth/go.mod b/auth/go.mod index 5f1e25b9..7d5fcb93 100644 --- a/auth/go.mod +++ b/auth/go.mod @@ -11,7 +11,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.2 github.com/bradleyfalzon/ghinstallation/v2 v2.14.0 - github.com/fluxcd/pkg/cache v0.4.0 + github.com/fluxcd/pkg/cache v0.5.0 github.com/fluxcd/pkg/ssh v0.17.0 github.com/onsi/gomega v1.36.2 golang.org/x/net v0.35.0 diff --git a/cache/store.go b/cache/store.go index 699402d2..b4f10bd6 100644 --- a/cache/store.go +++ b/cache/store.go @@ -47,6 +47,7 @@ type storeOptions struct { interval time.Duration registerer prometheus.Registerer metricsPrefix string + maxDuration time.Duration involvedObject *InvolvedObject debugKey string debugValueFunc func(any) any @@ -88,6 +89,14 @@ func WithMetricsPrefix(prefix string) Options { } } +// WithMaxDuration sets the maximum duration for the cache items. +func WithMaxDuration(duration time.Duration) Options { + return func(o *storeOptions) error { + o.maxDuration = duration + return nil + } +} + // WithInvolvedObject sets the involved object for the cache metrics. func WithInvolvedObject(kind, name, namespace string) Options { return func(o *storeOptions) error { diff --git a/cache/token.go b/cache/token.go index aeadd19e..c961f098 100644 --- a/cache/token.go +++ b/cache/token.go @@ -21,6 +21,11 @@ import ( "time" ) +// TokenMaxDuration is the maximum duration that a token can have in the +// TokenCache. This is used to cap the duration of tokens to avoid storing +// tokens that are valid for too long. +const TokenMaxDuration = time.Hour + // Token is an interface that represents an access token that can be used // to authenticate with a cloud provider. The only common method is to get the // duration of the token, because different providers may have different ways to @@ -45,7 +50,8 @@ type Token interface { // lifetime, which is the same strategy used by kubelet for rotating // ServiceAccount tokens. type TokenCache struct { - cache *LRU[*tokenItem] + cache *LRU[*tokenItem] + maxDuration time.Duration } type tokenItem struct { @@ -55,9 +61,20 @@ type tokenItem struct { } // NewTokenCache returns a new TokenCache with the given capacity. -func NewTokenCache(capacity int, opts ...Options) *TokenCache { - cache, _ := NewLRU[*tokenItem](capacity, opts...) - return &TokenCache{cache: cache} +func NewTokenCache(capacity int, opts ...Options) (*TokenCache, error) { + o := storeOptions{maxDuration: TokenMaxDuration} + o.apply(opts...) + + if o.maxDuration > TokenMaxDuration { + o.maxDuration = TokenMaxDuration + } + + cache, err := NewLRU[*tokenItem](capacity, opts...) + if err != nil { + return nil, err + } + + return &TokenCache{cache, o.maxDuration}, nil } // GetOrSet returns the token for the given key if present and not expired, or @@ -112,6 +129,10 @@ func (c *TokenCache) newItem(token Token) *tokenItem { // Ref: https://github.com/kubernetes/kubernetes/blob/4032177faf21ae2f99a2012634167def2376b370/pkg/kubelet/token/token_manager.go#L172-L174 d := (token.GetDuration() * 8) / 10 + if m := c.maxDuration; d > m { + d = m + } + mono := time.Now().Add(d) unix := time.Unix(mono.Unix(), 0) diff --git a/cache/token_test.go b/cache/token_test.go index e2243756..53cb2663 100644 --- a/cache/token_test.go +++ b/cache/token_test.go @@ -18,6 +18,7 @@ package cache_test import ( "context" + "fmt" "testing" "time" @@ -35,11 +36,14 @@ func (t *testToken) GetDuration() time.Duration { } func TestTokenCache_Lifecycle(t *testing.T) { + t.Parallel() + g := NewWithT(t) ctx := context.Background() - tc := cache.NewTokenCache(1) + tc, err := cache.NewTokenCache(1) + g.Expect(err).NotTo(HaveOccurred()) token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return &testToken{duration: 2 * time.Second}, nil @@ -48,19 +52,108 @@ func TestTokenCache_Lifecycle(t *testing.T) { g.Expect(retrieved).To(BeFalse()) g.Expect(err).To(BeNil()) - time.Sleep(4 * time.Second) + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 2 * time.Second})) + g.Expect(retrieved).To(BeTrue()) + + time.Sleep(2 * time.Second) token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { - return &testToken{duration: 100 * time.Second}, nil + return &testToken{duration: time.Hour}, nil }) - g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second})) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) g.Expect(retrieved).To(BeFalse()) g.Expect(err).To(BeNil()) +} - time.Sleep(2 * time.Second) +func TestTokenCache_80PercentLifetime(t *testing.T) { + t.Parallel() + + g := NewWithT(t) + + ctx := context.Background() + + tc, err := cache.NewTokenCache(1) + g.Expect(err).NotTo(HaveOccurred()) + + token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: 5 * time.Second}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second})) + g.Expect(retrieved).To(BeFalse()) token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil }) - g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second})) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second})) g.Expect(retrieved).To(BeTrue()) - g.Expect(err).To(BeNil()) + + time.Sleep(4 * time.Second) + + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: time.Hour}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) + g.Expect(retrieved).To(BeFalse()) +} + +func TestTokenCache_MaxDuration(t *testing.T) { + t.Parallel() + + g := NewWithT(t) + + ctx := context.Background() + + tc, err := cache.NewTokenCache(1, cache.WithMaxDuration(time.Second)) + g.Expect(err).NotTo(HaveOccurred()) + + token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: time.Hour}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) + g.Expect(retrieved).To(BeFalse()) + + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) + g.Expect(retrieved).To(BeTrue()) + + time.Sleep(2 * time.Second) + + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: 10 * time.Millisecond}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 10 * time.Millisecond})) + g.Expect(retrieved).To(BeFalse()) +} + +func TestTokenCache_GetOrSet_Error(t *testing.T) { + t.Parallel() + + g := NewWithT(t) + + ctx := context.Background() + + tc, err := cache.NewTokenCache(1) + g.Expect(err).NotTo(HaveOccurred()) + + token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return nil, fmt.Errorf("failed") + }) + + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError("failed")) + g.Expect(token).To(BeNil()) + g.Expect(retrieved).To(BeFalse()) }