diff --git a/go.mod b/go.mod index 14df794..b8a4739 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/dapr/kit go 1.20 require ( + github.com/alphadose/haxmap v1.3.1 github.com/cenkalti/backoff/v4 v4.2.1 github.com/fsnotify/fsnotify v1.7.0 github.com/lestrrat-go/httprc v1.0.4 diff --git a/go.sum b/go.sum index 27db2a2..481af75 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/alphadose/haxmap v1.3.1 h1:KmZh75duO1tC8pt3LmUwoTYiZ9sh4K52FX8p7/yrlqU= +github.com/alphadose/haxmap v1.3.1/go.mod h1:rjHw1IAqbxm0S3U5tD16GoKsiAd8FWx5BJ2IYqXwgmM= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/ttlcache/ttlcache.go b/ttlcache/ttlcache.go new file mode 100644 index 0000000..d6f751f --- /dev/null +++ b/ttlcache/ttlcache.go @@ -0,0 +1,175 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ttlcache + +import ( + "sync/atomic" + "time" + + "github.com/alphadose/haxmap" + kclock "k8s.io/utils/clock" +) + +// Cache is an efficient cache with a TTL. +type Cache[V any] struct { + m *haxmap.Map[string, cacheEntry[V]] + clock kclock.WithTicker + stopped atomic.Bool + runningCh chan struct{} + stopCh chan struct{} + maxTTL int64 +} + +// CacheOptions are options for NewCache. +type CacheOptions struct { + // Initial size for the cache. + // This is optional, and if empty will be left to the underlying library to decide. + InitialSize int32 + + // Interval to perform garbage collection. + // This is optional, and defaults to 150s (2.5 minutes). + CleanupInterval time.Duration + + // Maximum TTL value in seconds, if greater than 0 + MaxTTL int64 + + // Internal clock property, used for testing + clock kclock.WithTicker +} + +// NewCache returns a new cache with a TTL. +func NewCache[V any](opts CacheOptions) *Cache[V] { + var m *haxmap.Map[string, cacheEntry[V]] + if opts.InitialSize > 0 { + m = haxmap.New[string, cacheEntry[V]](uintptr(opts.InitialSize)) + } else { + m = haxmap.New[string, cacheEntry[V]]() + } + + if opts.CleanupInterval <= 0 { + opts.CleanupInterval = 150 * time.Second + } + + if opts.clock == nil { + opts.clock = kclock.RealClock{} + } + + c := &Cache[V]{ + m: m, + clock: opts.clock, + maxTTL: opts.MaxTTL, + stopCh: make(chan struct{}), + } + c.startBackgroundCleanup(opts.CleanupInterval) + return c +} + +// Get returns an item from the cache. +// Items that have expired are not returned. +func (c *Cache[V]) Get(key string) (v V, ok bool) { + val, ok := c.m.Get(key) + if !ok || !val.exp.After(c.clock.Now()) { + return v, false + } + return val.val, true +} + +// Set an item in the cache. +func (c *Cache[V]) Set(key string, val V, ttl int64) { + if ttl <= 0 { + panic("invalid TTL: must be > 0") + } + + if c.maxTTL > 0 && ttl > c.maxTTL { + ttl = c.maxTTL + } + + exp := c.clock.Now().Add(time.Duration(ttl) * time.Second) + c.m.Set(key, cacheEntry[V]{ + val: val, + exp: exp, + }) +} + +// Delete an item from the cache +func (c *Cache[V]) Delete(key string) { + c.m.Del(key) +} + +// Cleanup removes all expired entries from the cache. +func (c *Cache[V]) Cleanup() { + now := c.clock.Now() + + // Look for all expired keys and then remove them in bulk + // This is more efficient than removing keys one-by-one + // However, this could lead to a race condition where keys that are updated after ForEach ends are deleted nevertheless. + // This is considered acceptable in this case as this is just a cache. + keys := make([]string, 0, c.m.Len()) + c.m.ForEach(func(k string, v cacheEntry[V]) bool { + if v.exp.Before(now) { + keys = append(keys, k) + } + return true + }) + + c.m.Del(keys...) +} + +// Reset removes all entries from the cache. +func (c *Cache[V]) Reset() { + // Look for all keys and then remove them in bulk + // This is more efficient than removing keys one-by-one + // However, this could lead to a race condition where keys that are updated after ForEach ends are deleted nevertheless. + // This is considered acceptable in this case as this is just a cache. + keys := make([]string, 0, c.m.Len()) + c.m.ForEach(func(k string, v cacheEntry[V]) bool { + keys = append(keys, k) + return true + }) + + c.m.Del(keys...) +} + +func (c *Cache[V]) startBackgroundCleanup(d time.Duration) { + c.runningCh = make(chan struct{}) + go func() { + defer close(c.runningCh) + + t := c.clock.NewTicker(d) + defer t.Stop() + for { + select { + case <-c.stopCh: + // Stop the background goroutine + return + case <-t.C(): + c.Cleanup() + } + } + }() +} + +// Stop the cache, stopping the background garbage collection process. +func (c *Cache[V]) Stop() { + if c.stopped.CompareAndSwap(false, true) { + close(c.stopCh) + } + <-c.runningCh +} + +// Each item in the cache is stored in a cacheEntry, which includes the value as well as its expiration time. +type cacheEntry[V any] struct { + val V + exp time.Time +} diff --git a/ttlcache/ttlcache_test.go b/ttlcache/ttlcache_test.go new file mode 100644 index 0000000..08aa623 --- /dev/null +++ b/ttlcache/ttlcache_test.go @@ -0,0 +1,101 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ttlcache + +import ( + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" +) + +func TestCache(t *testing.T) { + clock := &clocktesting.FakeClock{} + clock.SetTime(time.Now()) + + cache := NewCache[string](CacheOptions{ + InitialSize: 10, + CleanupInterval: 20 * time.Second, + MaxTTL: 15, + clock: clock, + }) + defer cache.Stop() + + // Set values in the cache + cache.Set("key1", "val1", 2) + cache.Set("key2", "val2", 5) + cache.Set("key3", "val3", 30) // Max TTL is 15s + cache.Set("key4", "val4", 5) + + // Retrieve values + for i := 0; i < 16; i++ { + v, ok := cache.Get("key1") + if i < 2 { + require.True(t, ok) + require.Equal(t, "val1", v) + } else { + require.False(t, ok) + } + + v, ok = cache.Get("key2") + if i < 5 { + require.True(t, ok) + require.Equal(t, "val2", v) + } else { + require.False(t, ok) + } + + v, ok = cache.Get("key3") + if i < 15 { + require.True(t, ok) + require.Equal(t, "val3", v) + } else { + require.False(t, ok) + } + + v, ok = cache.Get("key4") + if i < 1 { + require.True(t, ok) + require.Equal(t, "val4", v) + + // Delete from the cache + cache.Delete("key4") + } else { + require.False(t, ok) + } + + // Advance the clock + clock.Step(time.Second) + runtime.Gosched() + time.Sleep(20 * time.Millisecond) + } + + // Values should still be in the cache as they haven't been cleaned up yet + require.EqualValues(t, 3, cache.m.Len()) + + // Advance the clock a bit more to make sure the cleanup runs + clock.Step(5 * time.Second) + + runtime.Gosched() + time.Sleep(20 * time.Millisecond) + + require.EventuallyWithT(t, func(c *assert.CollectT) { + if !assert.EqualValues(c, 0, cache.m.Len()) { + runtime.Gosched() + } + }, time.Second, 50*time.Millisecond) +}