Skip to content

Commit

Permalink
Always look for quota (#72)
Browse files Browse the repository at this point in the history
* Always look for quota

* Update to latest AuthControl

* remove else

* change order
  • Loading branch information
klaidliadon authored Nov 19, 2024
1 parent b676583 commit 5efa2f5
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 45 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.23.2
// replace github.com/0xsequence/authcontrol => ../authcontrol

require (
github.com/0xsequence/authcontrol v0.2.1
github.com/0xsequence/authcontrol v0.3.1
github.com/0xsequence/go-sequence v0.43.0
github.com/alicebob/miniredis/v2 v2.33.0
github.com/go-chi/chi/v5 v5.1.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/0xsequence/authcontrol v0.2.1 h1:pu38FQqDb+RVUTxHMKw6dHgND7dqOQZ/Ur+sQ0vDp6w=
github.com/0xsequence/authcontrol v0.2.1/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8=
github.com/0xsequence/authcontrol v0.3.1 h1:2GS5WcmhtXrS3+qkAyGwQpnglijJZ7ZNnW6DS2qS6y0=
github.com/0xsequence/authcontrol v0.3.1/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8=
github.com/0xsequence/ethkit v1.28.0 h1:11p4UXXvYnixQk01+qmAcOF71N9DlSeMcEMbaCPtjaY=
github.com/0xsequence/ethkit v1.28.0/go.mod h1:rv0FAIyEyN0hhwGefbduAz4ujmyjyJXhCd6a0/yF3tk=
github.com/0xsequence/go-sequence v0.43.0 h1:PErMuTg4PeaamJutEJ6tAjrFBA8z0t6lvT9LOVC5RMs=
Expand Down
4 changes: 4 additions & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ cloud.google.com/go/monitoring v1.21.0/go.mod h1:tuJ+KNDdJbetSsbSGTqnaBvbauS5kr3
cloud.google.com/go/storage v1.44.0/go.mod h1:wpPblkIuMP5jCB/E48Pz9zIo2S/zD8g+ITmxKkPCITE=
github.com/0xsequence/authcontrol v0.2.0 h1:7enDdLZSP3ngtMD7P6R6/yN1avgUFeHia0J8J8RKV5A=
github.com/0xsequence/authcontrol v0.2.0/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8=
github.com/0xsequence/authcontrol v0.3.0 h1:/dWx7sV2jds0imVAdFL4ixcQmm9GT0UDXPFRaVFHTe0=
github.com/0xsequence/authcontrol v0.3.0/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8=
github.com/0xsequence/authcontrol v0.3.1 h1:2GS5WcmhtXrS3+qkAyGwQpnglijJZ7ZNnW6DS2qS6y0=
github.com/0xsequence/authcontrol v0.3.1/go.mod h1:wicKJcmJYJU6jn07JBMqwxcAzyFm2BtfAzWpnL/40g8=
github.com/0xsequence/go-ethauth v0.13.0/go.mod h1:f3kx39S9F+W+qvZEB6bkKKbpUstmyB7goUntO3wvlhg=
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.24.1/go.mod h1:itPGVDKf9cC/ov4MdvJ2QZ0khw4bfoo9jzwTJlaxy2k=
Expand Down
29 changes: 22 additions & 7 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) {
limitCounter := quotacontrol.NewLimitCounter(cfg.Redis, logger)

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(authOptions))
r.Use(authcontrol.Session(authOptions))
r.Use(middleware.VerifyQuota(client, quotaOptions))
r.Use(addCost(_credits * 2))
Expand Down Expand Up @@ -388,6 +389,7 @@ func TestJWT(t *testing.T) {
}

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(authOptions))
r.Use(authcontrol.Session(authOptions))
r.Use(middleware.VerifyQuota(client, quotaOptions))
r.Use(middleware.EnsureUsage(client, quotaOptions))
Expand Down Expand Up @@ -467,6 +469,7 @@ func TestJWTAccess(t *testing.T) {
quotaOptions := middleware.Options{}

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(authOptions))
r.Use(authcontrol.Session(authOptions))
r.Use(middleware.VerifyQuota(client, quotaOptions))
r.Use(middleware.RateLimit(cfg.RateLimiter, limitCounter, quotaOptions))
Expand Down Expand Up @@ -573,17 +576,18 @@ func TestSession(t *testing.T) {
limitCounter := quotacontrol.NewLimitCounter(cfg.Redis, logger)

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(authOptions))
r.Use(authcontrol.Session(authOptions))
r.Use(middleware.VerifyQuota(client, quotaOptions))
r.Use(authcontrol.AccessControl(ACL, authOptions))
r.Use(middleware.VerifyQuota(client, quotaOptions))
r.Use(middleware.RateLimit(cfg.RateLimiter, limitCounter, quotaOptions))

r.Handle("/*", &counter)

ctx := context.Background()
limit := proto.Limit{RateLimit: 100, FreeWarn: 5, FreeMax: 5, OverWarn: 7, OverMax: 10}
server.Store.AddUser(ctx, UserAddress, false)
server.Store.AddProject(ctx, ProjectID)
server.Store.AddProject(ctx, ProjectID, nil)
server.Store.SetAccessLimit(ctx, ProjectID, &limit)
server.Store.SetUserPermission(ctx, ProjectID, WalletAddress, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: ProjectID})
server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: AccessKey, ProjectID: ProjectID})
Expand Down Expand Up @@ -664,10 +668,14 @@ func TestSession(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, quotaRPM, rateLimit)
assert.Equal(t, quotaLimit, h.Get(middleware.HeaderQuotaLimit))
case proto.SessionType_Wallet, proto.SessionType_Admin, proto.SessionType_User:
case proto.SessionType_Wallet, proto.SessionType_User:
assert.True(t, ok)
assert.NoError(t, err)
assert.Equal(t, accountRPM, rateLimit)
limit := accountRPM
if tc.AccessKey != "" {
limit = quotaRPM
}
assert.Equal(t, limit, rateLimit)
case proto.SessionType_InternalService:
assert.True(t, ok)
assert.NoError(t, err)
Expand Down Expand Up @@ -700,17 +708,18 @@ func TestSessionDisabled(t *testing.T) {
limitCounter := quotacontrol.NewLimitCounter(cfg.Redis, logger)

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(authOptions))
r.Use(authcontrol.Session(authOptions))
r.Use(authcontrol.AccessControl(ACL, authOptions))
r.Use(middleware.VerifyQuota(client, quotaOptions))
r.Use(middleware.RateLimit(cfg.RateLimiter, limitCounter, quotaOptions))
r.Use(authcontrol.AccessControl(ACL, authOptions))

r.Handle("/*", &counter)

ctx := context.Background()
limit := proto.Limit{RateLimit: 100, FreeWarn: 5, FreeMax: 5, OverWarn: 7, OverMax: 10}
server.Store.AddUser(ctx, UserAddress, false)
server.Store.AddProject(ctx, ProjectID)
server.Store.AddProject(ctx, ProjectID, nil)
server.Store.SetAccessLimit(ctx, ProjectID, &limit)
server.Store.SetUserPermission(ctx, ProjectID, WalletAddress, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: ProjectID})
server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: AccessKey, ProjectID: ProjectID})
Expand Down Expand Up @@ -778,8 +787,14 @@ func TestSessionDisabled(t *testing.T) {
case proto.SessionType_AccessKey, proto.SessionType_Project:
assert.Equal(t, quotaRPM, rateLimit)
assert.Equal(t, quotaLimit, h.Get(middleware.HeaderQuotaLimit))
case proto.SessionType_Wallet, proto.SessionType_Admin, proto.SessionType_User:
case proto.SessionType_Wallet, proto.SessionType_User:
assert.Equal(t, accountRPM, rateLimit)
case proto.SessionType_Admin:
limit := accountRPM
if tc.AccessKey != "" {
limit = quotaRPM
}
assert.Equal(t, limit, rateLimit)
case proto.SessionType_InternalService:
assert.Equal(t, serviceRPM, rateLimit)
}
Expand Down
53 changes: 26 additions & 27 deletions middleware/middleware_quota.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/0xsequence/quotacontrol/proto"
)

// VerifyQuota middleware fetches and verify the quota from access key or project ID.
func VerifyQuota(client Client, o Options) func(next http.Handler) http.Handler {
o.ApplyDefaults()

Expand Down Expand Up @@ -58,39 +59,37 @@ func VerifyQuota(client Client, o Options) func(next http.Handler) http.Handler
}

// fetch and verify access key quota
if session.Is(proto.SessionType_AccessKey, proto.SessionType_Project) {
accessKey, ok := authcontrol.GetAccessKey(ctx)
if !ok && session == proto.SessionType_AccessKey {
o.ErrHandler(r, w, proto.ErrUnauthorizedUser.WithCausef("verify quota: no access key found in context"))
return
}
accessKey, ok := authcontrol.GetAccessKey(ctx)
if !ok && session == proto.SessionType_AccessKey {
o.ErrHandler(r, w, proto.ErrUnauthorizedUser.WithCausef("verify quota: no access key found in context"))
return
}

if ok {
// check that project ID matches
if projectID != 0 {
if v, _ := proto.GetProjectID(accessKey); v != projectID {
o.ErrHandler(r, w, proto.ErrAccessKeyMismatch)
return
}
if ok {
// check that project ID matches
if projectID != 0 {
if v, _ := proto.GetProjectID(accessKey); v != projectID {
o.ErrHandler(r, w, proto.ErrAccessKeyMismatch)
return
}
}

// fetch and verify access key quota
q, err := client.FetchKeyQuota(ctx, accessKey, r.Header.Get(HeaderOrigin), now)
if err != nil {
o.ErrHandler(r, w, err)
// fetch and verify access key quota
q, err := client.FetchKeyQuota(ctx, accessKey, r.Header.Get(HeaderOrigin), now)
if err != nil {
o.ErrHandler(r, w, err)
return
}
if q != nil {
if !q.IsActive() {
o.ErrHandler(r, w, proto.ErrAccessKeyNotFound)
return
}
if q != nil {
if !q.IsActive() {
o.ErrHandler(r, w, proto.ErrAccessKeyNotFound)
return
}
if quota != nil && quota.AccessKey.ProjectID != q.AccessKey.ProjectID {
o.ErrHandler(r, w, proto.ErrAccessKeyMismatch)
return
}
quota = q
if quota != nil && quota.AccessKey.ProjectID != q.AccessKey.ProjectID {
o.ErrHandler(r, w, proto.ErrAccessKeyMismatch)
return
}
quota = q
}
}

Expand Down
1 change: 1 addition & 0 deletions middleware/middleware_ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func (r RateLimitConfig) GetRateLimit(ctx context.Context, baseRequestCost int)
return r.PublicRPM * baseRequestCost
}

// RateLimit is a middleware that limits the number of requests per minute.
func RateLimit(cfg RateLimitConfig, counter httprate.LimitCounter, o Options) func(next http.Handler) http.Handler {
if !cfg.Enabled {
return func(next http.Handler) http.Handler {
Expand Down
17 changes: 9 additions & 8 deletions tests/mock/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"sync"
"time"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/quotacontrol/internal/store"
"github.com/0xsequence/quotacontrol/internal/usage"
"github.com/0xsequence/quotacontrol/proto"
Expand All @@ -17,7 +18,7 @@ func NewMemoryStore() *MemoryStore {
accessKeys: map[string]proto.AccessKey{},
usage: usage.NewRecord(),
users: map[string]bool{},
projects: map[uint64]struct{}{},
projects: map[uint64]*authcontrol.Auth{},
permissions: map[uint64]map[string]userPermission{},
}
}
Expand All @@ -35,7 +36,7 @@ type MemoryStore struct {
accessKeys map[string]proto.AccessKey
usage usage.Record
users map[string]bool
projects map[uint64]struct{}
projects map[uint64]*authcontrol.Auth
permissions map[uint64]map[string]userPermission
}

Expand Down Expand Up @@ -187,21 +188,21 @@ func (m *MemoryStore) GetUser(ctx context.Context, userID string) (any, bool, er
return struct{}{}, v, nil
}

func (m *MemoryStore) AddProject(ctx context.Context, projectID uint64) error {
func (m *MemoryStore) AddProject(ctx context.Context, projectID uint64, auth *authcontrol.Auth) error {
m.Lock()
m.projects[projectID] = struct{}{}
m.projects[projectID] = auth
m.Unlock()
return nil
}

func (m *MemoryStore) GetProject(ctx context.Context, projectID uint64) (any, error) {
func (m *MemoryStore) GetProject(ctx context.Context, projectID uint64) (any, *authcontrol.Auth, error) {
m.Lock()
_, ok := m.projects[projectID]
auth, ok := m.projects[projectID]
m.Unlock()
if !ok {
return nil, nil
return nil, nil, nil
}
return struct{}{}, nil
return struct{}{}, auth, nil
}

func (m *MemoryStore) GetUserPermission(ctx context.Context, projectID uint64, userID string) (proto.UserPermission, *proto.ResourceAccess, error) {
Expand Down

0 comments on commit 5efa2f5

Please sign in to comment.