Skip to content

Commit

Permalink
httptransport: add concurrency limiter
Browse files Browse the repository at this point in the history
Signed-off-by: Hank Donnay <hdonnay@redhat.com>
  • Loading branch information
hdonnay committed Jan 14, 2022
1 parent 85b851e commit 29ad2ba
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
46 changes: 46 additions & 0 deletions httptransport/concurrentlimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package httptransport

import (
"net/http"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/sync/semaphore"
)

var concurrentLimitedCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricNamespace,
Subsystem: metricSubsystem,
Name: "concurrencylimited_total",
Help: "Total number of requests that have been concurrency limited.",
},
[]string{"endpoint", "method"},
)

// LimitHandler is a wrapper to help with concurrency limiting. This is slightly
// more complicated than the naive approach to allow for filtering on multiple
// aspects of the request.
//
// "Check" and "Next" need to be populated.
type limitHandler struct {
// The Check func inspects the request, and returns the semaphore to use and
// the endpoint to use in metrics. If a nil is returned, the request is
// allowed.
Check func(*http.Request) (*semaphore.Weighted, string)
// Next is the Handler to forward requests to, if allowed.
Next http.Handler
}

func (l *limitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sem, endpt := l.Check(r)
if sem != nil {
if !sem.TryAcquire(1) {
concurrentLimitedCounter.WithLabelValues(endpt, r.Method).Add(1)
apiError(w, http.StatusTooManyRequests, "server handling too many requests")
return
}
defer sem.Release(1)
}
l.Next.ServeHTTP(w, r)
}
76 changes: 76 additions & 0 deletions httptransport/concurrentlimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package httptransport

import (
"context"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"

"golang.org/x/sync/semaphore"
)

func TestConcurrentRequests(t *testing.T) {
sem := semaphore.NewWeighted(1)
// Ret controls when the http server returns.
// Ready is strobed once the first request is seen.
ret, ready := make(chan struct{}), make(chan struct{})
ct := new(int64)
var once sync.Once
srv := httptest.NewServer(&limitHandler{
Check: func(_ *http.Request) (*semaphore.Weighted, string) {
return sem, ""
},
Next: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
atomic.AddInt64(ct, 1)
once.Do(func() { close(ready) })
<-ret
w.WriteHeader(http.StatusNoContent)
}),
})
defer srv.Close()
c := srv.Client()

ctx := context.Background()
done := make(chan struct{})
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
if err != nil {
t.Fatal(err)
}
// Long-poll goroutine.
go func() {
defer close(done)
res, err := c.Do(req)
if err != nil {
t.Error(err)
return
}
defer res.Body.Close()
if got, want := res.StatusCode, http.StatusNoContent; got != want {
t.Errorf("got: %d, want: %d", got, want)
}
}()

// Wait for the above goroutine to hit the handler.
<-ready
for i := 0; i < 10; i++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
if err != nil {
t.Errorf("%d: %v", i, err)
}
res, err := c.Do(req)
if err != nil {
t.Errorf("%d: %v", i, err)
}
res.Body.Close()
if got, want := res.StatusCode, http.StatusTooManyRequests; got != want {
t.Errorf("got: %d, want: %d", got, want)
}
}
close(ret)
<-done
if got, want := *ct, int64(1); got != want {
t.Errorf("got: %d requests, want: %d requests", got, want)
}
}
2 changes: 2 additions & 0 deletions httptransport/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func apiError(w http.ResponseWriter, code int, f string, v ...interface{}) {
buf.WriteString("method-not-allowed")
case http.StatusNotFound:
buf.WriteString("not-found")
case http.StatusTooManyRequests:
buf.WriteString("too-many-requests")
default:
buf.WriteString("internal-error")
}
Expand Down

0 comments on commit 29ad2ba

Please sign in to comment.