Skip to content

Commit

Permalink
Add max duration to token cache
Browse files Browse the repository at this point in the history
Signed-off-by: Matheus Pimenta <matheuscscp@gmail.com>
  • Loading branch information
matheuscscp committed Mar 4, 2025
1 parent ff04927 commit c9ffccd
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 8 deletions.
9 changes: 9 additions & 0 deletions cache/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type storeOptions struct {
interval time.Duration
registerer prometheus.Registerer
metricsPrefix string
maxDuration time.Duration
involvedObject *InvolvedObject
debugKey string
debugValueFunc func(any) any
Expand Down Expand Up @@ -88,6 +89,14 @@ func WithMetricsPrefix(prefix string) Options {
}
}

// WithMaxDuration sets the maximum duration for the cache items.
func WithMaxDuration(duration time.Duration) Options {
return func(o *storeOptions) error {
o.maxDuration = duration
return nil
}
}

// WithInvolvedObject sets the involved object for the cache metrics.
func WithInvolvedObject(kind, name, namespace string) Options {
return func(o *storeOptions) error {
Expand Down
22 changes: 20 additions & 2 deletions cache/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ import (
"time"
)

// TokenMaxDuration is the maximum duration that a token can have in the
// TokenCache. This is used to cap the duration of tokens to avoid storing
// tokens that are valid for too long.
const TokenMaxDuration = time.Hour

// Token is an interface that represents an access token that can be used
// to authenticate with a cloud provider. The only common method is to get the
// duration of the token, because different providers may have different ways to
Expand All @@ -45,7 +50,8 @@ type Token interface {
// lifetime, which is the same strategy used by kubelet for rotating
// ServiceAccount tokens.
type TokenCache struct {
cache *LRU[*tokenItem]
cache *LRU[*tokenItem]
maxDuration time.Duration
}

type tokenItem struct {
Expand All @@ -56,8 +62,16 @@ type tokenItem struct {

// NewTokenCache returns a new TokenCache with the given capacity.
func NewTokenCache(capacity int, opts ...Options) *TokenCache {
o := storeOptions{maxDuration: TokenMaxDuration}
o.apply(opts...)

if o.maxDuration > TokenMaxDuration {
o.maxDuration = TokenMaxDuration
}

cache, _ := NewLRU[*tokenItem](capacity, opts...)
return &TokenCache{cache: cache}

return &TokenCache{cache, o.maxDuration}
}

// GetOrSet returns the token for the given key if present and not expired, or
Expand Down Expand Up @@ -112,6 +126,10 @@ func (c *TokenCache) newItem(token Token) *tokenItem {
// Ref: https://github.com/kubernetes/kubernetes/blob/4032177faf21ae2f99a2012634167def2376b370/pkg/kubelet/token/token_manager.go#L172-L174
d := (token.GetDuration() * 8) / 10

if m := c.maxDuration; d > m {
d = m
}

mono := time.Now().Add(d)
unix := time.Unix(mono.Unix(), 0)

Expand Down
101 changes: 95 additions & 6 deletions cache/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package cache_test

import (
"context"
"fmt"
"testing"
"time"

Expand All @@ -35,6 +36,8 @@ func (t *testToken) GetDuration() time.Duration {
}

func TestTokenCache_Lifecycle(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()
Expand All @@ -48,19 +51,105 @@ func TestTokenCache_Lifecycle(t *testing.T) {
g.Expect(retrieved).To(BeFalse())
g.Expect(err).To(BeNil())

time.Sleep(4 * time.Second)
token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 2 * time.Second}))
g.Expect(retrieved).To(BeTrue())

time.Sleep(2 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 100 * time.Second}, nil
return &testToken{duration: time.Hour}, nil
})
g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second}))
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())
g.Expect(err).To(BeNil())
}

time.Sleep(2 * time.Second)
func TestTokenCache_80PercentLifetime(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc := cache.NewTokenCache(1)

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 5 * time.Second}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second}))
g.Expect(retrieved).To(BeFalse())

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })
g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second}))

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second}))
g.Expect(retrieved).To(BeTrue())
g.Expect(err).To(BeNil())

time.Sleep(4 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: time.Hour}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())
}

func TestTokenCache_MaxDuration(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc := cache.NewTokenCache(1, cache.WithMaxDuration(time.Second))

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: time.Hour}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeTrue())

time.Sleep(2 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 10 * time.Millisecond}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 10 * time.Millisecond}))
g.Expect(retrieved).To(BeFalse())
}

func TestTokenCache_GetOrSet_Error(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc := cache.NewTokenCache(1)

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return nil, fmt.Errorf("failed")
})

g.Expect(err).To(HaveOccurred())
g.Expect(err).To(MatchError("failed"))
g.Expect(token).To(BeNil())
g.Expect(retrieved).To(BeFalse())
}

0 comments on commit c9ffccd

Please sign in to comment.