-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
http: rate limit index report requests
Added a basic rate limiter that will keep concurrent connections limited and return a 429 if they go above. Signed-off-by: crozzy <joseph.crosland@gmail.com>
- Loading branch information
Showing
4 changed files
with
132 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package rate | ||
|
||
import ( | ||
"net/http" | ||
|
||
"github.com/prometheus/client_golang/prometheus" | ||
"github.com/prometheus/client_golang/prometheus/promauto" | ||
"github.com/quay/claircore/pkg/jsonerr" | ||
) | ||
|
||
var ( | ||
rateLimitedCounter = promauto.NewCounterVec( | ||
prometheus.CounterOpts{ | ||
Namespace: "clair", | ||
Subsystem: "http", | ||
Name: "ratelimited_total", | ||
Help: "Total number of requests that have been rate limited.", | ||
}, | ||
[]string{"endpoint"}, | ||
) | ||
) | ||
|
||
type RateLimitMiddleware struct { | ||
ctlChan chan struct{} | ||
} | ||
|
||
func NewRateLimitMiddleware(maxConcurrent int) *RateLimitMiddleware { | ||
rlm := &RateLimitMiddleware{} | ||
if maxConcurrent > 0 { | ||
rlm.ctlChan = make(chan struct{}, maxConcurrent) | ||
} | ||
return rlm | ||
} | ||
|
||
func (rlm *RateLimitMiddleware) Handler(endpoint string, next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// If the channel is never initialized just carry on. | ||
if rlm.ctlChan == nil { | ||
next.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
// See if we have any space for a new request. | ||
select { | ||
case rlm.ctlChan <- struct{}{}: | ||
defer func() { <-rlm.ctlChan }() | ||
default: | ||
rateLimitedCounter.WithLabelValues(endpoint).Add(1) | ||
resp := &jsonerr.Response{ | ||
Code: "too-many-requests", | ||
Message: "server handling too many requests", | ||
} | ||
jsonerr.Error(w, resp, http.StatusTooManyRequests) | ||
return | ||
} | ||
next.ServeHTTP(w, r) | ||
// Bug (crozzy): this approach is a little rough and | ||
// will lock-out indexing for pre-indexed images, which | ||
// is not resource intensive. | ||
}) | ||
} | ||
|
||
func (rlm *RateLimitMiddleware) Close() { | ||
close(rlm.ctlChan) | ||
rlm.ctlChan = nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package rate | ||
|
||
import ( | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
|
||
"golang.org/x/sync/errgroup" | ||
) | ||
|
||
func noOpHandler() http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// mimic some work | ||
time.Sleep(100 * time.Millisecond) | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
} | ||
|
||
func TestRateLimitMiddleWare(t *testing.T) { | ||
req, err := http.NewRequest("POST", "/indexer/api/v1/index_report", nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
rr := httptest.NewRecorder() | ||
rlm := NewRateLimitMiddleware(1) | ||
handler := rlm.Handler("", noOpHandler()) | ||
|
||
handler.ServeHTTP(rr, req) | ||
|
||
if rr.Code != http.StatusOK { | ||
t.Errorf("handler did not return successful response as expected") | ||
} | ||
|
||
rr.Flush() | ||
var failures uint64 | ||
|
||
eg := errgroup.Group{} | ||
for i := 0; i < 2; i++ { | ||
eg.Go(func() error { | ||
rr := httptest.NewRecorder() | ||
handler.ServeHTTP(rr, req) | ||
if rr.Code != http.StatusOK { | ||
atomic.AddUint64(&failures, 1) | ||
return fmt.Errorf("Got status code %d", rr.Code) | ||
} | ||
return nil | ||
}) | ||
} | ||
if err := eg.Wait(); err == nil || failures != 1 { | ||
t.Fatalf("test failed: expected one failure") | ||
} | ||
} |