diff --git a/go.mod b/go.mod index 3f5a9bb..ed6fd8f 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.2 // replace github.com/0xsequence/authcontrol => ../authcontrol require ( - github.com/0xsequence/authcontrol v0.2.1 + github.com/0xsequence/authcontrol v0.3.1 github.com/0xsequence/go-sequence v0.43.0 github.com/alicebob/miniredis/v2 v2.33.0 github.com/go-chi/chi/v5 v5.1.0 diff --git a/go.sum b/go.sum index 5c88f35..d883079 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/0xsequence/authcontrol v0.2.1 h1:pu38FQqDb+RVUTxHMKw6dHgND7dqOQZ/Ur+sQ0vDp6w= -github.com/0xsequence/authcontrol v0.2.1/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8= +github.com/0xsequence/authcontrol v0.3.1 h1:2GS5WcmhtXrS3+qkAyGwQpnglijJZ7ZNnW6DS2qS6y0= +github.com/0xsequence/authcontrol v0.3.1/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8= github.com/0xsequence/ethkit v1.28.0 h1:11p4UXXvYnixQk01+qmAcOF71N9DlSeMcEMbaCPtjaY= github.com/0xsequence/ethkit v1.28.0/go.mod h1:rv0FAIyEyN0hhwGefbduAz4ujmyjyJXhCd6a0/yF3tk= github.com/0xsequence/go-sequence v0.43.0 h1:PErMuTg4PeaamJutEJ6tAjrFBA8z0t6lvT9LOVC5RMs= diff --git a/go.work.sum b/go.work.sum index b18f766..3f73a78 100644 --- a/go.work.sum +++ b/go.work.sum @@ -9,6 +9,10 @@ cloud.google.com/go/monitoring v1.21.0/go.mod h1:tuJ+KNDdJbetSsbSGTqnaBvbauS5kr3 cloud.google.com/go/storage v1.44.0/go.mod h1:wpPblkIuMP5jCB/E48Pz9zIo2S/zD8g+ITmxKkPCITE= github.com/0xsequence/authcontrol v0.2.0 h1:7enDdLZSP3ngtMD7P6R6/yN1avgUFeHia0J8J8RKV5A= github.com/0xsequence/authcontrol v0.2.0/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8= +github.com/0xsequence/authcontrol v0.3.0 h1:/dWx7sV2jds0imVAdFL4ixcQmm9GT0UDXPFRaVFHTe0= +github.com/0xsequence/authcontrol v0.3.0/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8= +github.com/0xsequence/authcontrol v0.3.1 h1:2GS5WcmhtXrS3+qkAyGwQpnglijJZ7ZNnW6DS2qS6y0= +github.com/0xsequence/authcontrol v0.3.1/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8= github.com/0xsequence/go-ethauth v0.13.0/go.mod h1:f3kx39S9F+W+qvZEB6bkKKbpUstmyB7goUntO3wvlhg= github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.24.1/go.mod h1:itPGVDKf9cC/ov4MdvJ2QZ0khw4bfoo9jzwTJlaxy2k= diff --git a/handler_test.go b/handler_test.go index edfc5c5..d89f801 100644 --- a/handler_test.go +++ b/handler_test.go @@ -79,6 +79,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { limitCounter := quotacontrol.NewLimitCounter(cfg.Redis, logger) r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(authOptions)) r.Use(authcontrol.Session(authOptions)) r.Use(middleware.VerifyQuota(client, quotaOptions)) r.Use(addCost(_credits * 2)) @@ -388,6 +389,7 @@ func TestJWT(t *testing.T) { } r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(authOptions)) r.Use(authcontrol.Session(authOptions)) r.Use(middleware.VerifyQuota(client, quotaOptions)) r.Use(middleware.EnsureUsage(client, quotaOptions)) @@ -467,6 +469,7 @@ func TestJWTAccess(t *testing.T) { quotaOptions := middleware.Options{} r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(authOptions)) r.Use(authcontrol.Session(authOptions)) r.Use(middleware.VerifyQuota(client, quotaOptions)) r.Use(middleware.RateLimit(cfg.RateLimiter, limitCounter, quotaOptions)) @@ -573,9 +576,10 @@ func TestSession(t *testing.T) { limitCounter := quotacontrol.NewLimitCounter(cfg.Redis, logger) r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(authOptions)) r.Use(authcontrol.Session(authOptions)) - r.Use(middleware.VerifyQuota(client, quotaOptions)) r.Use(authcontrol.AccessControl(ACL, authOptions)) + r.Use(middleware.VerifyQuota(client, quotaOptions)) r.Use(middleware.RateLimit(cfg.RateLimiter, limitCounter, quotaOptions)) r.Handle("/*", &counter) @@ -583,7 +587,7 @@ func TestSession(t *testing.T) { ctx := context.Background() limit := proto.Limit{RateLimit: 100, FreeWarn: 5, FreeMax: 5, OverWarn: 7, OverMax: 10} server.Store.AddUser(ctx, UserAddress, false) - server.Store.AddProject(ctx, ProjectID) + server.Store.AddProject(ctx, ProjectID, nil) server.Store.SetAccessLimit(ctx, ProjectID, &limit) server.Store.SetUserPermission(ctx, ProjectID, WalletAddress, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: ProjectID}) server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: AccessKey, ProjectID: ProjectID}) @@ -664,10 +668,14 @@ func TestSession(t *testing.T) { assert.NoError(t, err) assert.Equal(t, quotaRPM, rateLimit) assert.Equal(t, quotaLimit, h.Get(middleware.HeaderQuotaLimit)) - case proto.SessionType_Wallet, proto.SessionType_Admin, proto.SessionType_User: + case proto.SessionType_Wallet, proto.SessionType_User: assert.True(t, ok) assert.NoError(t, err) - assert.Equal(t, accountRPM, rateLimit) + limit := accountRPM + if tc.AccessKey != "" { + limit = quotaRPM + } + assert.Equal(t, limit, rateLimit) case proto.SessionType_InternalService: assert.True(t, ok) assert.NoError(t, err) @@ -700,17 +708,18 @@ func TestSessionDisabled(t *testing.T) { limitCounter := quotacontrol.NewLimitCounter(cfg.Redis, logger) r := chi.NewRouter() + r.Use(authcontrol.VerifyToken(authOptions)) r.Use(authcontrol.Session(authOptions)) + r.Use(authcontrol.AccessControl(ACL, authOptions)) r.Use(middleware.VerifyQuota(client, quotaOptions)) r.Use(middleware.RateLimit(cfg.RateLimiter, limitCounter, quotaOptions)) - r.Use(authcontrol.AccessControl(ACL, authOptions)) r.Handle("/*", &counter) ctx := context.Background() limit := proto.Limit{RateLimit: 100, FreeWarn: 5, FreeMax: 5, OverWarn: 7, OverMax: 10} server.Store.AddUser(ctx, UserAddress, false) - server.Store.AddProject(ctx, ProjectID) + server.Store.AddProject(ctx, ProjectID, nil) server.Store.SetAccessLimit(ctx, ProjectID, &limit) server.Store.SetUserPermission(ctx, ProjectID, WalletAddress, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: ProjectID}) server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: AccessKey, ProjectID: ProjectID}) @@ -778,8 +787,14 @@ func TestSessionDisabled(t *testing.T) { case proto.SessionType_AccessKey, proto.SessionType_Project: assert.Equal(t, quotaRPM, rateLimit) assert.Equal(t, quotaLimit, h.Get(middleware.HeaderQuotaLimit)) - case proto.SessionType_Wallet, proto.SessionType_Admin, proto.SessionType_User: + case proto.SessionType_Wallet, proto.SessionType_User: assert.Equal(t, accountRPM, rateLimit) + case proto.SessionType_Admin: + limit := accountRPM + if tc.AccessKey != "" { + limit = quotaRPM + } + assert.Equal(t, limit, rateLimit) case proto.SessionType_InternalService: assert.Equal(t, serviceRPM, rateLimit) } diff --git a/middleware/middleware_quota.go b/middleware/middleware_quota.go index b9a25c5..f3cee9c 100644 --- a/middleware/middleware_quota.go +++ b/middleware/middleware_quota.go @@ -8,6 +8,7 @@ import ( "github.com/0xsequence/quotacontrol/proto" ) +// VerifyQuota middleware fetches and verify the quota from access key or project ID. func VerifyQuota(client Client, o Options) func(next http.Handler) http.Handler { o.ApplyDefaults() @@ -58,39 +59,37 @@ func VerifyQuota(client Client, o Options) func(next http.Handler) http.Handler } // fetch and verify access key quota - if session.Is(proto.SessionType_AccessKey, proto.SessionType_Project) { - accessKey, ok := authcontrol.GetAccessKey(ctx) - if !ok && session == proto.SessionType_AccessKey { - o.ErrHandler(r, w, proto.ErrUnauthorizedUser.WithCausef("verify quota: no access key found in context")) - return - } + accessKey, ok := authcontrol.GetAccessKey(ctx) + if !ok && session == proto.SessionType_AccessKey { + o.ErrHandler(r, w, proto.ErrUnauthorizedUser.WithCausef("verify quota: no access key found in context")) + return + } - if ok { - // check that project ID matches - if projectID != 0 { - if v, _ := proto.GetProjectID(accessKey); v != projectID { - o.ErrHandler(r, w, proto.ErrAccessKeyMismatch) - return - } + if ok { + // check that project ID matches + if projectID != 0 { + if v, _ := proto.GetProjectID(accessKey); v != projectID { + o.ErrHandler(r, w, proto.ErrAccessKeyMismatch) + return } + } - // fetch and verify access key quota - q, err := client.FetchKeyQuota(ctx, accessKey, r.Header.Get(HeaderOrigin), now) - if err != nil { - o.ErrHandler(r, w, err) + // fetch and verify access key quota + q, err := client.FetchKeyQuota(ctx, accessKey, r.Header.Get(HeaderOrigin), now) + if err != nil { + o.ErrHandler(r, w, err) + return + } + if q != nil { + if !q.IsActive() { + o.ErrHandler(r, w, proto.ErrAccessKeyNotFound) return } - if q != nil { - if !q.IsActive() { - o.ErrHandler(r, w, proto.ErrAccessKeyNotFound) - return - } - if quota != nil && quota.AccessKey.ProjectID != q.AccessKey.ProjectID { - o.ErrHandler(r, w, proto.ErrAccessKeyMismatch) - return - } - quota = q + if quota != nil && quota.AccessKey.ProjectID != q.AccessKey.ProjectID { + o.ErrHandler(r, w, proto.ErrAccessKeyMismatch) + return } + quota = q } } diff --git a/middleware/middleware_ratelimit.go b/middleware/middleware_ratelimit.go index ca30c40..2e01883 100644 --- a/middleware/middleware_ratelimit.go +++ b/middleware/middleware_ratelimit.go @@ -54,6 +54,7 @@ func (r RateLimitConfig) GetRateLimit(ctx context.Context, baseRequestCost int) return r.PublicRPM * baseRequestCost } +// RateLimit is a middleware that limits the number of requests per minute. func RateLimit(cfg RateLimitConfig, counter httprate.LimitCounter, o Options) func(next http.Handler) http.Handler { if !cfg.Enabled { return func(next http.Handler) http.Handler { diff --git a/tests/mock/mem.go b/tests/mock/mem.go index 9676994..fc83f0e 100644 --- a/tests/mock/mem.go +++ b/tests/mock/mem.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "github.com/0xsequence/authcontrol" "github.com/0xsequence/quotacontrol/internal/store" "github.com/0xsequence/quotacontrol/internal/usage" "github.com/0xsequence/quotacontrol/proto" @@ -17,7 +18,7 @@ func NewMemoryStore() *MemoryStore { accessKeys: map[string]proto.AccessKey{}, usage: usage.NewRecord(), users: map[string]bool{}, - projects: map[uint64]struct{}{}, + projects: map[uint64]*authcontrol.Auth{}, permissions: map[uint64]map[string]userPermission{}, } } @@ -35,7 +36,7 @@ type MemoryStore struct { accessKeys map[string]proto.AccessKey usage usage.Record users map[string]bool - projects map[uint64]struct{} + projects map[uint64]*authcontrol.Auth permissions map[uint64]map[string]userPermission } @@ -187,21 +188,21 @@ func (m *MemoryStore) GetUser(ctx context.Context, userID string) (any, bool, er return struct{}{}, v, nil } -func (m *MemoryStore) AddProject(ctx context.Context, projectID uint64) error { +func (m *MemoryStore) AddProject(ctx context.Context, projectID uint64, auth *authcontrol.Auth) error { m.Lock() - m.projects[projectID] = struct{}{} + m.projects[projectID] = auth m.Unlock() return nil } -func (m *MemoryStore) GetProject(ctx context.Context, projectID uint64) (any, error) { +func (m *MemoryStore) GetProject(ctx context.Context, projectID uint64) (any, *authcontrol.Auth, error) { m.Lock() - _, ok := m.projects[projectID] + auth, ok := m.projects[projectID] m.Unlock() if !ok { - return nil, nil + return nil, nil, nil } - return struct{}{}, nil + return struct{}{}, auth, nil } func (m *MemoryStore) GetUserPermission(ctx context.Context, projectID uint64, userID string) (proto.UserPermission, *proto.ResourceAccess, error) {