Skip to content

Commit

Permalink
clean up jwt cache (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey authored Jan 13, 2025
1 parent 56cbb05 commit 8723bbb
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 32 deletions.
59 changes: 39 additions & 20 deletions jwt/jwtcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/go-jose/go-jose/v3"
"github.com/martinlindhe/base36"
"github.com/rs/zerolog/log"

"github.com/pomerium/cli/internal/cache"
"github.com/pomerium/pomerium/pkg/cryptutil"
Expand All @@ -23,20 +24,38 @@ var (
ErrNotFound = errors.New("not found")
)

// A JWTCache loads and stores JWTs.
type JWTCache interface {
// A Cache loads and stores JWTs.
type Cache interface {
DeleteJWT(key string) error
LoadJWT(key string) (rawJWT string, err error)
StoreJWT(key string, rawJWT string) error
}

// A LocalJWTCache stores files in the user's cache directory.
type LocalJWTCache struct {
var (
globalCacheOnce sync.Once
globalCache Cache
)

// GetCache gets the Cache. Either a local one is used or if that's not possible an in-memory one is used.
func GetCache() Cache {
globalCacheOnce.Do(func() {
if c, err := NewLocalCache(); err == nil {
globalCache = c
} else {
log.Error().Err(err).Msg("error creating local JWT cache, using in-memory JWT cache")
globalCache = NewMemoryCache()
}
})
return globalCache
}

// A LocalCache stores files in the user's cache directory.
type LocalCache struct {
dir string
}

// NewLocalJWTCache creates a new LocalJWTCache.
func NewLocalJWTCache() (*LocalJWTCache, error) {
// NewLocalCache creates a new LocalCache.
func NewLocalCache() (*LocalCache, error) {
dir, err := cache.JWTsPath()
if err != nil {
return nil, err
Expand All @@ -47,13 +66,13 @@ func NewLocalJWTCache() (*LocalJWTCache, error) {
return nil, fmt.Errorf("error creating user cache directory: %w", err)
}

return &LocalJWTCache{
return &LocalCache{
dir: dir,
}, nil
}

// DeleteJWT deletes a raw JWT from the local cache.
func (cache *LocalJWTCache) DeleteJWT(key string) error {
func (cache *LocalCache) DeleteJWT(key string) error {
path := filepath.Join(cache.dir, cache.fileName(key))
err := os.Remove(path)
if os.IsNotExist(err) {
Expand All @@ -63,7 +82,7 @@ func (cache *LocalJWTCache) DeleteJWT(key string) error {
}

// LoadJWT loads a raw JWT from the local cache.
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
func (cache *LocalCache) LoadJWT(key string) (rawJWT string, err error) {
path := filepath.Join(cache.dir, cache.fileName(key))
rawBS, err := os.ReadFile(path)
if os.IsNotExist(err) {
Expand All @@ -77,7 +96,7 @@ func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
}

// StoreJWT stores a raw JWT in the local cache.
func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
func (cache *LocalCache) StoreJWT(key string, rawJWT string) error {
path := filepath.Join(cache.dir, cache.fileName(key))
err := os.WriteFile(path, []byte(rawJWT), 0o600)
if err != nil {
Expand All @@ -87,28 +106,28 @@ func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
return nil
}

func (cache *LocalJWTCache) hash(str string) string {
func (cache *LocalCache) hash(str string) string {
h := cryptutil.Hash("LocalJWTCache", []byte(str))
return base36.EncodeBytes(h)
}

func (cache *LocalJWTCache) fileName(key string) string {
func (cache *LocalCache) fileName(key string) string {
return cache.hash(key) + ".jwt"
}

// A MemoryJWTCache stores JWTs in an in-memory map.
type MemoryJWTCache struct {
// A MemoryCache stores JWTs in an in-memory map.
type MemoryCache struct {
mu sync.Mutex
entries map[string]string
}

// NewMemoryJWTCache creates a new in-memory JWT cache.
func NewMemoryJWTCache() *MemoryJWTCache {
return &MemoryJWTCache{entries: make(map[string]string)}
// NewMemoryCache creates a new in-memory JWT cache.
func NewMemoryCache() *MemoryCache {
return &MemoryCache{entries: make(map[string]string)}
}

// DeleteJWT deletes a JWT from the in-memory map.
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
func (cache *MemoryCache) DeleteJWT(key string) error {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand All @@ -117,7 +136,7 @@ func (cache *MemoryJWTCache) DeleteJWT(key string) error {
}

// LoadJWT loads a JWT from the in-memory map.
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
func (cache *MemoryCache) LoadJWT(key string) (rawJWT string, err error) {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand All @@ -130,7 +149,7 @@ func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
}

// StoreJWT stores a JWT in the in-memory map.
func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error {
func (cache *MemoryCache) StoreJWT(key string, rawJWT string) error {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand Down
4 changes: 2 additions & 2 deletions jwt/jwtcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
"github.com/stretchr/testify/assert"
)

func TestLocalJWTCache(t *testing.T) {
c := &LocalJWTCache{
func TestLocalCache(t *testing.T) {
c := &LocalCache{
dir: filepath.Join(os.TempDir(), uuid.New().String()),
}

Expand Down
13 changes: 3 additions & 10 deletions tunnel/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ package tunnel
import (
"crypto/tls"

"github.com/rs/zerolog/log"

"github.com/pomerium/cli/jwt"
)

type config struct {
jwtCache jwt.JWTCache
jwtCache jwt.Cache
dstHost string
proxyHost string
serviceAccount string
Expand All @@ -20,12 +18,7 @@ type config struct {

func getConfig(options ...Option) *config {
cfg := new(config)
if jwtCache, err := jwt.NewLocalJWTCache(); err == nil {
WithJWTCache(jwtCache)(cfg)
} else {
log.Error().Err(err).Msg("error creating local JWT cache, using in-memory JWT cache")
WithJWTCache(jwt.NewMemoryJWTCache())(cfg)
}
WithJWTCache(jwt.GetCache())(cfg)
for _, o := range options {
o(cfg)
}
Expand All @@ -50,7 +43,7 @@ func WithDestinationHost(dstHost string) Option {
}

// WithJWTCache returns an option to configure the jwt cache.
func WithJWTCache(jwtCache jwt.JWTCache) Option {
func WithJWTCache(jwtCache jwt.Cache) Option {
return func(cfg *config) {
cfg.jwtCache = jwtCache
}
Expand Down

0 comments on commit 8723bbb

Please sign in to comment.