From c9ffccd02b8681381939b26d4fe039c5b07c3ace Mon Sep 17 00:00:00 2001 From: Matheus Pimenta Date: Sat, 1 Mar 2025 02:32:22 +0000 Subject: [PATCH] Add max duration to token cache Signed-off-by: Matheus Pimenta --- cache/store.go | 9 ++++ cache/token.go | 22 +++++++++- cache/token_test.go | 101 +++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 124 insertions(+), 8 deletions(-) diff --git a/cache/store.go b/cache/store.go index 699402d2..b4f10bd6 100644 --- a/cache/store.go +++ b/cache/store.go @@ -47,6 +47,7 @@ type storeOptions struct { interval time.Duration registerer prometheus.Registerer metricsPrefix string + maxDuration time.Duration involvedObject *InvolvedObject debugKey string debugValueFunc func(any) any @@ -88,6 +89,14 @@ func WithMetricsPrefix(prefix string) Options { } } +// WithMaxDuration sets the maximum duration for the cache items. +func WithMaxDuration(duration time.Duration) Options { + return func(o *storeOptions) error { + o.maxDuration = duration + return nil + } +} + // WithInvolvedObject sets the involved object for the cache metrics. func WithInvolvedObject(kind, name, namespace string) Options { return func(o *storeOptions) error { diff --git a/cache/token.go b/cache/token.go index aeadd19e..f2004af7 100644 --- a/cache/token.go +++ b/cache/token.go @@ -21,6 +21,11 @@ import ( "time" ) +// TokenMaxDuration is the maximum duration that a token can have in the +// TokenCache. This is used to cap the duration of tokens to avoid storing +// tokens that are valid for too long. +const TokenMaxDuration = time.Hour + // Token is an interface that represents an access token that can be used // to authenticate with a cloud provider. The only common method is to get the // duration of the token, because different providers may have different ways to @@ -45,7 +50,8 @@ type Token interface { // lifetime, which is the same strategy used by kubelet for rotating // ServiceAccount tokens. type TokenCache struct { - cache *LRU[*tokenItem] + cache *LRU[*tokenItem] + maxDuration time.Duration } type tokenItem struct { @@ -56,8 +62,16 @@ type tokenItem struct { // NewTokenCache returns a new TokenCache with the given capacity. func NewTokenCache(capacity int, opts ...Options) *TokenCache { + o := storeOptions{maxDuration: TokenMaxDuration} + o.apply(opts...) + + if o.maxDuration > TokenMaxDuration { + o.maxDuration = TokenMaxDuration + } + cache, _ := NewLRU[*tokenItem](capacity, opts...) - return &TokenCache{cache: cache} + + return &TokenCache{cache, o.maxDuration} } // GetOrSet returns the token for the given key if present and not expired, or @@ -112,6 +126,10 @@ func (c *TokenCache) newItem(token Token) *tokenItem { // Ref: https://github.com/kubernetes/kubernetes/blob/4032177faf21ae2f99a2012634167def2376b370/pkg/kubelet/token/token_manager.go#L172-L174 d := (token.GetDuration() * 8) / 10 + if m := c.maxDuration; d > m { + d = m + } + mono := time.Now().Add(d) unix := time.Unix(mono.Unix(), 0) diff --git a/cache/token_test.go b/cache/token_test.go index e2243756..1162bb25 100644 --- a/cache/token_test.go +++ b/cache/token_test.go @@ -18,6 +18,7 @@ package cache_test import ( "context" + "fmt" "testing" "time" @@ -35,6 +36,8 @@ func (t *testToken) GetDuration() time.Duration { } func TestTokenCache_Lifecycle(t *testing.T) { + t.Parallel() + g := NewWithT(t) ctx := context.Background() @@ -48,19 +51,105 @@ func TestTokenCache_Lifecycle(t *testing.T) { g.Expect(retrieved).To(BeFalse()) g.Expect(err).To(BeNil()) - time.Sleep(4 * time.Second) + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 2 * time.Second})) + g.Expect(retrieved).To(BeTrue()) + + time.Sleep(2 * time.Second) token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { - return &testToken{duration: 100 * time.Second}, nil + return &testToken{duration: time.Hour}, nil }) - g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second})) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) g.Expect(retrieved).To(BeFalse()) g.Expect(err).To(BeNil()) +} - time.Sleep(2 * time.Second) +func TestTokenCache_80PercentLifetime(t *testing.T) { + t.Parallel() + + g := NewWithT(t) + + ctx := context.Background() + + tc := cache.NewTokenCache(1) + + token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: 5 * time.Second}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second})) + g.Expect(retrieved).To(BeFalse()) token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil }) - g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second})) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second})) g.Expect(retrieved).To(BeTrue()) - g.Expect(err).To(BeNil()) + + time.Sleep(4 * time.Second) + + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: time.Hour}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) + g.Expect(retrieved).To(BeFalse()) +} + +func TestTokenCache_MaxDuration(t *testing.T) { + t.Parallel() + + g := NewWithT(t) + + ctx := context.Background() + + tc := cache.NewTokenCache(1, cache.WithMaxDuration(time.Second)) + + token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: time.Hour}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) + g.Expect(retrieved).To(BeFalse()) + + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: time.Hour})) + g.Expect(retrieved).To(BeTrue()) + + time.Sleep(2 * time.Second) + + token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return &testToken{duration: 10 * time.Millisecond}, nil + }) + + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(token).To(Equal(&testToken{duration: 10 * time.Millisecond})) + g.Expect(retrieved).To(BeFalse()) +} + +func TestTokenCache_GetOrSet_Error(t *testing.T) { + t.Parallel() + + g := NewWithT(t) + + ctx := context.Background() + + tc := cache.NewTokenCache(1) + + token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { + return nil, fmt.Errorf("failed") + }) + + g.Expect(err).To(HaveOccurred()) + g.Expect(err).To(MatchError("failed")) + g.Expect(token).To(BeNil()) + g.Expect(retrieved).To(BeFalse()) }