From 8d34f11435591682c1d540c5a13a49383b03d5ea Mon Sep 17 00:00:00 2001 From: iambenkay Date: Thu, 17 Dec 2020 16:17:53 +0100 Subject: [PATCH] Added last seen stats for visitor --- middleware/rate_limiter.go | 26 +++++--- middleware/rate_limiter_test.go | 106 ++++++++++++++++++++++---------- 2 files changed, 91 insertions(+), 41 deletions(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index f1c4c032e..4094ed1fa 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -5,6 +5,7 @@ import ( "golang.org/x/time/rate" "net/http" "sync" + "time" ) // RateLimiterStore is the interface to be implemented by custom stores. @@ -73,27 +74,34 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { } // RateLimiterMemoryStore is the built-in store implementation for RateLimiter -type RateLimiterMemoryStore struct { - visitors map[string]*rate.Limiter - mutex sync.Mutex - rate rate.Limit - burst int -} +type ( + RateLimiterMemoryStore struct { + visitors map[string]visitor + mutex sync.Mutex + rate rate.Limit + burst int + } + visitor struct { + *rate.Limiter + lastSeen time.Time + } +) // Allow implements RateLimiterStore.Allow func (store *RateLimiterMemoryStore) Allow(identifier string) bool { store.mutex.Lock() if store.visitors == nil { - store.visitors = make(map[string]*rate.Limiter) + store.visitors = make(map[string]visitor) } limiter, exists := store.visitors[identifier] if !exists { - limiter = rate.NewLimiter(store.rate, store.burst) + limiter.Limiter = rate.NewLimiter(store.rate, store.burst) + limiter.lastSeen = time.Now() store.visitors[identifier] = limiter } - + limiter.lastSeen = time.Now() store.mutex.Unlock() return limiter.Allow() } diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index c0ac2d77d..8c727c6b7 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -119,46 +119,88 @@ func TestRateLimiter(t *testing.T) { } func TestRateLimiterWithConfig(t *testing.T) { - var inMemoryStore = new(RateLimiterMemoryStore) - inMemoryStore.rate = 1 - inMemoryStore.burst = 3 - - e := echo.New() + { + var inMemoryStore = new(RateLimiterMemoryStore) + inMemoryStore.rate = 1 + inMemoryStore.burst = 3 - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", 200}, + {"127.0.0.1", 200}, + {"127.0.0.1", 200}, + {"127.0.0.1", 429}, + {"127.0.0.1", 429}, + {"127.0.0.1", 429}, + {"127.0.0.1", 429}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + mw := RateLimiterWithConfig(RateLimiterConfig{ + SourceFunc: func(c echo.Context) string { + return c.RealIP() + }, + Store: inMemoryStore, + }) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } } + { + var inMemoryStore = new(RateLimiterMemoryStore) + inMemoryStore.rate = 1 + inMemoryStore.burst = 3 - testCases := []struct { - id string - code int - }{ - {"127.0.0.1", 200}, - {"127.0.0.1", 200}, - {"127.0.0.1", 200}, - {"127.0.0.1", 429}, - {"127.0.0.1", 429}, - {"127.0.0.1", 429}, - {"127.0.0.1", 429}, - } + e := echo.New() - for _, tc := range testCases { - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Add(echo.HeaderXRealIP, tc.id) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } - rec := httptest.NewRecorder() + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", 200}, + {"127.0.0.1", 200}, + {"127.0.0.1", 200}, + {"127.0.0.1", 429}, + {"127.0.0.1", 429}, + {"127.0.0.1", 429}, + {"127.0.0.1", 429}, + } - c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ - SourceFunc: func(c echo.Context) string { - return c.RealIP() - }, - Store: inMemoryStore, - }) + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) - _ = mw(handler)(c) + rec := httptest.NewRecorder() - assert.Equal(t, tc.code, rec.Code) + c := e.NewContext(req, rec) + mw := RateLimiterWithConfig(RateLimiterConfig{ + Store: inMemoryStore, + }) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } } }