Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve visibility of Redis counter #57

Merged
merged 15 commits into from
Oct 29, 2024
39 changes: 37 additions & 2 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"

"github.com/0xsequence/quotacontrol/proto"
"github.com/go-chi/httprate"
httprateredis "github.com/go-chi/httprate-redis"
"github.com/goware/logger"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/redis/go-redis/v9"
)
Expand Down Expand Up @@ -47,7 +51,38 @@ var _ QuotaCache = (*RedisCache)(nil)
var _ QuotaCache = (*LRU)(nil)
var _ UsageCache = (*RedisCache)(nil)

func NewLimitCounter(cfg RedisConfig, logger logger.Logger) httprate.LimitCounter {
if !cfg.Enabled {
return nil
}
return httprateredis.NewCounter(&httprateredis.Config{
Host: cfg.Host,
Port: cfg.Port,
MaxIdle: cfg.MaxIdle,
MaxActive: cfg.MaxActive,
DBIndex: cfg.DBIndex,
OnError: func(err error) {
if logger != nil {
logger.Error("redis counter error", slog.Any("error", err))
}
},
OnFallbackChange: func(fallback bool) {
if logger != nil {
logger.Warn("redis counter fallback", slog.Bool("fallback", fallback))
}
},
})
}

const (
defaultExpRedis = time.Hour
defaultExpLRU = time.Minute
)

func NewRedisCache(redisClient *redis.Client, ttl time.Duration) *RedisCache {
if ttl <= 0 {
ttl = defaultExpRedis
}
return &RedisCache{
client: redisClient,
ttl: ttl,
Expand Down Expand Up @@ -220,8 +255,8 @@ type LRU struct {
}

func NewLRU(cacheBackend QuotaCache, size int, ttl time.Duration) *LRU {
if ttl == 0 {
ttl = time.Minute * 5
if ttl <= 0 {
ttl = defaultExpLRU
}
lruCache := expirable.NewLRU[string, *proto.AccessQuota](size, nil, ttl)
return &LRU{
Expand Down
73 changes: 31 additions & 42 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/quotacontrol/internal/usage"
"github.com/0xsequence/quotacontrol/middleware"
"github.com/0xsequence/quotacontrol/proto"
"github.com/goware/logger"
Expand All @@ -27,19 +28,15 @@ func NewClient(logger logger.Logger, service proto.Service, cfg Config, qc proto
MaxIdleConns: cfg.Redis.MaxIdle,
}

redisExpiration := time.Hour
if cfg.Redis.KeyTTL > 0 {
redisExpiration = cfg.Redis.KeyTTL
backend := NewRedisCache(redis.NewClient(&options), cfg.Redis.KeyTTL)
cache := Cache{
UsageCache: backend,
QuotaCache: backend,
PermissionCache: backend,
}
cache := NewRedisCache(redis.NewClient(&options), redisExpiration)

quotaCache := QuotaCache(cache)
// LRU cache for Quota
if cfg.LRUSize > 0 {
lruExpiration := time.Minute
if cfg.LRUExpiration.Duration > 0 {
lruExpiration = cfg.LRUExpiration.Duration
}
quotaCache = NewLRU(quotaCache, cfg.LRUSize, lruExpiration)
cache.QuotaCache = NewLRU(backend, cfg.LRUSize, cfg.LRUExpiration)
}

if qc == nil {
Expand All @@ -48,22 +45,18 @@ func NewClient(logger logger.Logger, service proto.Service, cfg Config, qc proto
})
}

var ticker *time.Ticker
if cfg.UpdateFreq.Duration > 0 {
ticker = time.NewTicker(cfg.UpdateFreq.Duration)
tick := time.Minute * 5
if cfg.UpdateFreq > 0 {
tick = cfg.UpdateFreq
}

return &Client{
cfg: cfg,
service: service,
usage: &usageTracker{
Usage: make(map[time.Time]usageRecord),
},
usageCache: cache,
quotaCache: quotaCache,
permCache: PermissionCache(cache),
cfg: cfg,
service: service,
usage: usage.NewTracker(),
cache: cache,
quotaClient: qc,
ticker: ticker,
ticker: time.NewTicker(tick),
logger: logger.With(slog.String("service", "quotacontrol")),
}
}
Expand All @@ -73,10 +66,8 @@ type Client struct {
logger logger.Logger

service proto.Service
usage *usageTracker
usageCache UsageCache
quotaCache QuotaCache
permCache PermissionCache
usage *usage.Tracker
cache Cache
quotaClient proto.QuotaControl

running int32
Expand Down Expand Up @@ -107,7 +98,7 @@ func (c *Client) GetDefaultUsage() int64 {
// FetchProjectQuota fetches the project quota from cache or from the quota server.
func (c *Client) FetchProjectQuota(ctx context.Context, projectID uint64, now time.Time) (*proto.AccessQuota, error) {
// fetch access quota
quota, err := c.quotaCache.GetProjectQuota(ctx, projectID)
quota, err := c.cache.QuotaCache.GetProjectQuota(ctx, projectID)
if err != nil {
logger := c.logger.With(
slog.String("op", "fetch_project_quota"),
Expand Down Expand Up @@ -135,7 +126,7 @@ func (c *Client) FetchKeyQuota(ctx context.Context, accessKey, origin string, no
slog.String("access_key", accessKey),
)
// fetch access quota
quota, err := c.quotaCache.GetAccessQuota(ctx, accessKey)
quota, err := c.cache.QuotaCache.GetAccessQuota(ctx, accessKey)
if err != nil {
if !errors.Is(err, proto.ErrAccessKeyNotFound) {
logger.Warn("unexpected cache error", slog.Any("error", err))
Expand Down Expand Up @@ -167,13 +158,13 @@ func (c *Client) FetchUsage(ctx context.Context, quota *proto.AccessQuota, now t
)

for i := range 3 {
usage, err := c.usageCache.PeekUsage(ctx, key)
usage, err := c.cache.UsageCache.PeekUsage(ctx, key)
if err != nil {
// ping the server to prepare usage
if errors.Is(err, ErrCachePing) {
if _, err := c.quotaClient.PrepareUsage(ctx, quota.AccessKey.ProjectID, quota.Cycle, now); err != nil {
logger.Error("unexpected client error", slog.Any("error", err))
if _, err := c.usageCache.ClearUsage(ctx, key); err != nil {
if _, err := c.cache.UsageCache.ClearUsage(ctx, key); err != nil {
logger.Error("unexpected cache error", slog.Any("error", err))
}
return 0, nil
Expand Down Expand Up @@ -218,7 +209,7 @@ func (c *Client) FetchPermission(ctx context.Context, projectID uint64) (proto.U
slog.String("user_id", userID),
)
// Check short-lived cache if requested. Note using the cache TTL from config (default 1m).
perm, access, err := c.permCache.GetUserPermission(ctx, projectID, userID)
perm, access, err := c.cache.PermissionCache.GetUserPermission(ctx, projectID, userID)
if err != nil {
// log the error, but don't stop
logger.Error("unexpected cache error", slog.Any("error", err))
Expand Down Expand Up @@ -255,7 +246,7 @@ func (c *Client) SpendQuota(ctx context.Context, quota *proto.AccessQuota, cost
key := getQuotaKey(quota.AccessKey.ProjectID, quota.Cycle, now)

for i := range 3 {
total, err := c.usageCache.SpendUsage(ctx, key, cost, cfg.OverMax)
total, err := c.cache.UsageCache.SpendUsage(ctx, key, cost, cfg.OverMax)
if err != nil {
// limit exceeded
if errors.Is(err, proto.ErrLimitExceeded) {
Expand All @@ -266,7 +257,7 @@ func (c *Client) SpendQuota(ctx context.Context, quota *proto.AccessQuota, cost
if errors.Is(err, ErrCachePing) {
if _, err := c.quotaClient.PrepareUsage(ctx, quota.AccessKey.ProjectID, quota.Cycle, now); err != nil {
logger.Error("unexpected client error", slog.Any("error", err))
if _, err := c.usageCache.ClearUsage(ctx, key); err != nil {
if _, err := c.cache.UsageCache.ClearUsage(ctx, key); err != nil {
logger.Error("unexpected cache error", slog.Any("error", err))
}
return false, 0, nil
Expand Down Expand Up @@ -306,11 +297,11 @@ func (c *Client) SpendQuota(ctx context.Context, quota *proto.AccessQuota, cost
}

func (c *Client) ClearQuotaCacheByProjectID(ctx context.Context, projectID uint64) error {
return c.quotaCache.DeleteProjectQuota(ctx, projectID)
return c.cache.QuotaCache.DeleteProjectQuota(ctx, projectID)
}

func (c *Client) ClearQuotaCacheByAccessKey(ctx context.Context, accessKey string) error {
return c.quotaCache.DeleteAccessQuota(ctx, accessKey)
return c.cache.QuotaCache.DeleteAccessQuota(ctx, accessKey)
}

func (c *Client) validateAccessKey(access *proto.AccessKey, origin string) (err error) {
Expand Down Expand Up @@ -342,9 +333,7 @@ func (c *Client) Run(ctx context.Context) error {
<-ctx.Done()
c.Stop(context.Background())
}()
if c.ticker == nil {
return nil
}

// Start the sync
for range c.ticker.C {
if err := c.usage.SyncUsage(ctx, c.quotaClient, c.service); err != nil {
Expand All @@ -365,9 +354,9 @@ func (c *Client) Stop(timeoutCtx context.Context) {
logger := c.logger.With("op", "stop")

logger.Info("stopping...")
if c.ticker != nil {
c.ticker.Stop()
}

c.ticker.Stop()

if err := c.usage.SyncUsage(timeoutCtx, c.quotaClient, c.service); err != nil {
logger.Error("sync usage", slog.Any("error", err))
}
Expand Down
32 changes: 0 additions & 32 deletions common.go

This file was deleted.

94 changes: 0 additions & 94 deletions common_test.go

This file was deleted.

Loading
Loading