Skip to content

Commit

Permalink
Add max duration to token cache
Browse files Browse the repository at this point in the history
Signed-off-by: Matheus Pimenta <matheuscscp@gmail.com>
  • Loading branch information
matheuscscp committed Mar 4, 2025
1 parent ff04927 commit f993410
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 15 deletions.
11 changes: 8 additions & 3 deletions auth/github/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ func TestClient_Options(t *testing.T) {
}

func TestClient_GetToken(t *testing.T) {
g := NewWithT(t)

expiresAt := time.Now().UTC().Add(time.Hour)
tests := []struct {
name string
Expand All @@ -180,14 +182,17 @@ func TestClient_GetToken(t *testing.T) {
{
name: "Get cached token",
opts: []OptFunc{func(client *Client) {
c := cache.NewTokenCache(1)
c.GetOrSet(context.Background(), client.buildCacheKey(), func(context.Context) (cache.Token, error) {
c, err := cache.NewTokenCache(1)
g.Expect(err).NotTo(HaveOccurred())
_, ok, err := c.GetOrSet(context.Background(), client.buildCacheKey(), func(context.Context) (cache.Token, error) {
return &AppToken{
Token: "access-token",
ExpiresAt: expiresAt,
}, nil
})
client.cache = c
g.Expect(err).NotTo(HaveOccurred())
g.Expect(ok).To(BeFalse())
WithCache(c, "", "", "")(client)
}},
statusCode: http.StatusInternalServerError, // error status code to make the test fail if the token is not cached
wantAppToken: &AppToken{
Expand Down
2 changes: 1 addition & 1 deletion auth/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.2
github.com/bradleyfalzon/ghinstallation/v2 v2.14.0
github.com/fluxcd/pkg/cache v0.4.0
github.com/fluxcd/pkg/cache v0.5.0
github.com/fluxcd/pkg/ssh v0.17.0
github.com/onsi/gomega v1.36.2
golang.org/x/net v0.35.0
Expand Down
9 changes: 9 additions & 0 deletions cache/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type storeOptions struct {
interval time.Duration
registerer prometheus.Registerer
metricsPrefix string
maxDuration time.Duration
involvedObject *InvolvedObject
debugKey string
debugValueFunc func(any) any
Expand Down Expand Up @@ -88,6 +89,14 @@ func WithMetricsPrefix(prefix string) Options {
}
}

// WithMaxDuration sets the maximum duration for the cache items.
func WithMaxDuration(duration time.Duration) Options {
return func(o *storeOptions) error {
o.maxDuration = duration
return nil
}
}

// WithInvolvedObject sets the involved object for the cache metrics.
func WithInvolvedObject(kind, name, namespace string) Options {
return func(o *storeOptions) error {
Expand Down
29 changes: 25 additions & 4 deletions cache/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ import (
"time"
)

// TokenMaxDuration is the maximum duration that a token can have in the
// TokenCache. This is used to cap the duration of tokens to avoid storing
// tokens that are valid for too long.
const TokenMaxDuration = time.Hour

// Token is an interface that represents an access token that can be used
// to authenticate with a cloud provider. The only common method is to get the
// duration of the token, because different providers may have different ways to
Expand All @@ -45,7 +50,8 @@ type Token interface {
// lifetime, which is the same strategy used by kubelet for rotating
// ServiceAccount tokens.
type TokenCache struct {
cache *LRU[*tokenItem]
cache *LRU[*tokenItem]
maxDuration time.Duration
}

type tokenItem struct {
Expand All @@ -55,9 +61,20 @@ type tokenItem struct {
}

// NewTokenCache returns a new TokenCache with the given capacity.
func NewTokenCache(capacity int, opts ...Options) *TokenCache {
cache, _ := NewLRU[*tokenItem](capacity, opts...)
return &TokenCache{cache: cache}
func NewTokenCache(capacity int, opts ...Options) (*TokenCache, error) {
o := storeOptions{maxDuration: TokenMaxDuration}
o.apply(opts...)

if o.maxDuration > TokenMaxDuration {
o.maxDuration = TokenMaxDuration
}

cache, err := NewLRU[*tokenItem](capacity, opts...)
if err != nil {
return nil, err
}

return &TokenCache{cache, o.maxDuration}, nil
}

// GetOrSet returns the token for the given key if present and not expired, or
Expand Down Expand Up @@ -112,6 +129,10 @@ func (c *TokenCache) newItem(token Token) *tokenItem {
// Ref: https://github.com/kubernetes/kubernetes/blob/4032177faf21ae2f99a2012634167def2376b370/pkg/kubelet/token/token_manager.go#L172-L174
d := (token.GetDuration() * 8) / 10

if m := c.maxDuration; d > m {
d = m
}

mono := time.Now().Add(d)
unix := time.Unix(mono.Unix(), 0)

Expand Down
107 changes: 100 additions & 7 deletions cache/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package cache_test

import (
"context"
"fmt"
"testing"
"time"

Expand All @@ -35,11 +36,14 @@ func (t *testToken) GetDuration() time.Duration {
}

func TestTokenCache_Lifecycle(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc := cache.NewTokenCache(1)
tc, err := cache.NewTokenCache(1)
g.Expect(err).NotTo(HaveOccurred())

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 2 * time.Second}, nil
Expand All @@ -48,19 +52,108 @@ func TestTokenCache_Lifecycle(t *testing.T) {
g.Expect(retrieved).To(BeFalse())
g.Expect(err).To(BeNil())

time.Sleep(4 * time.Second)
token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 2 * time.Second}))
g.Expect(retrieved).To(BeTrue())

time.Sleep(2 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 100 * time.Second}, nil
return &testToken{duration: time.Hour}, nil
})
g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second}))
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())
g.Expect(err).To(BeNil())
}

time.Sleep(2 * time.Second)
func TestTokenCache_80PercentLifetime(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc, err := cache.NewTokenCache(1)
g.Expect(err).NotTo(HaveOccurred())

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 5 * time.Second}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second}))
g.Expect(retrieved).To(BeFalse())

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })
g.Expect(token).To(Equal(&testToken{duration: 100 * time.Second}))

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 5 * time.Second}))
g.Expect(retrieved).To(BeTrue())
g.Expect(err).To(BeNil())

time.Sleep(4 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: time.Hour}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())
}

func TestTokenCache_MaxDuration(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc, err := cache.NewTokenCache(1, cache.WithMaxDuration(time.Second))
g.Expect(err).NotTo(HaveOccurred())

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: time.Hour}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeFalse())

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) { return nil, nil })

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: time.Hour}))
g.Expect(retrieved).To(BeTrue())

time.Sleep(2 * time.Second)

token, retrieved, err = tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return &testToken{duration: 10 * time.Millisecond}, nil
})

g.Expect(err).NotTo(HaveOccurred())
g.Expect(token).To(Equal(&testToken{duration: 10 * time.Millisecond}))
g.Expect(retrieved).To(BeFalse())
}

func TestTokenCache_GetOrSet_Error(t *testing.T) {
t.Parallel()

g := NewWithT(t)

ctx := context.Background()

tc, err := cache.NewTokenCache(1)
g.Expect(err).NotTo(HaveOccurred())

token, retrieved, err := tc.GetOrSet(ctx, "test", func(context.Context) (cache.Token, error) {
return nil, fmt.Errorf("failed")
})

g.Expect(err).To(HaveOccurred())
g.Expect(err).To(MatchError("failed"))
g.Expect(token).To(BeNil())
g.Expect(retrieved).To(BeFalse())
}

0 comments on commit f993410

Please sign in to comment.