Skip to content

Commit

Permalink
httpbp: Fix Retries middleware
Browse files Browse the repository at this point in the history
When we set GetBody in http.Request, it's expected that Body is also
set, add special handling in Retries to make sure we also set Body when
retrying when GetBody is also set before each retry attempt.

Also always clone the request before each retry attempt to avoid some
subtle errors, and skip the Retries middleware altogether if Body is set
but GetBody is not.
  • Loading branch information
fishy committed Apr 18, 2024
1 parent 5eb5e90 commit fb57fd8
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 42 deletions.
30 changes: 26 additions & 4 deletions httpbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package httpbp

import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"sync"
Expand Down Expand Up @@ -200,16 +202,36 @@ func CircuitBreaker(config breakerbp.Config) ClientMiddleware {
// Retries provides a retry middleware by ensuring certain HTTP responses are
// wrapped in errors. Retries wraps the ClientErrorWrapper middleware, e.g. if
// you are using Retries there is no need to also use ClientErrorWrapper.
func Retries(limit int, retryOptions ...retry.Option) ClientMiddleware {
func Retries(maxErrorReadAhead int, retryOptions ...retry.Option) ClientMiddleware {
if len(retryOptions) == 0 {
retryOptions = []retry.Option{retry.Attempts(1)}
}
return func(next http.RoundTripper) http.RoundTripper {
// include ClientErrorWrapper to ensure retry is applied for some HTTP 5xx
// responses
next = ClientErrorWrapper(maxErrorReadAhead)(next)

return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
if req.Body != nil && req.Body != http.NoBody && req.GetBody == nil {
slog.WarnContext(
req.Context(),
"Request comes with a Body but nil GetBody cannot be retried. httpbp.Retries middleware skipped.",
"req", req,
)
return next.RoundTrip(req)
}

err = retrybp.Do(req.Context(), func() error {
// include ClientErrorWrapper to ensure retry is applied for
// some HTTP 5xx responses
resp, err = ClientErrorWrapper(limit)(next).RoundTrip(req)
req = req.Clone(req.Context())
if req.GetBody != nil {
body, err := req.GetBody()
if err != nil {
return fmt.Errorf("httpbp.Retries: GetBody returned error: %w", err)
}
req.Body = body
}

resp, err = next.RoundTrip(req)
if err != nil {
return err
}
Expand Down
127 changes: 89 additions & 38 deletions httpbp/client_middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -256,11 +257,22 @@ func TestClientErrorWrapper(t *testing.T) {
})
}

func unwrapRetryErrors(err error) []error {
var errs interface {
error

Unwrap() []error
}
if errors.As(err, &errs) {
return errs.Unwrap()
}
return []error{err}
}

func TestRetry(t *testing.T) {
t.Run("retry for timeout", func(t *testing.T) {
const timeout = time.Millisecond * 10
t.Run("retry for HTTP 500", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(timeout * 10)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()

Expand All @@ -274,78 +286,117 @@ func TestRetry(t *testing.T) {
attempts = n + 1
}),
)(http.DefaultTransport),
Timeout: timeout,
}
_, err := client.Get(server.URL)
u, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("Failed to parse url %q: %v", server.URL, err)
}
req := &http.Request{
Method: http.MethodPost,
URL: u,

// Explicitly set Body to http.NoBody and GetBody to nil,
// This request should not cause Retries middleware to be skipped.
Body: http.NoBody,
GetBody: nil,
}
_, err = client.Do(req)
if err == nil {
t.Fatalf("expected error to be non-nil")
}
expected := uint(1)
expected := uint(2)
if attempts != expected {
t.Errorf("expected %d, actual: %d", expected, attempts)
}
errs := unwrapRetryErrors(err)
if len(errs) != int(expected) {
t.Errorf("Expected %d retry erros, got %+v", expected, errs)
}
for i, err := range errs {
var ce *ClientError
if errors.As(err, &ce) {
if got, want := ce.StatusCode, http.StatusInternalServerError; got != want {
t.Errorf("#%d: status got %d want %d", i, got, want)
}
} else {
t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err)
}
}
})

t.Run("retry for HTTP 500", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Run("retry POST+HTTPS request", func(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
expected := "{}"
got := string(b)
if got != expected {
t.Errorf("expected %q, got: %q", expected, got)
}
t.Logf("Full body: %q", got)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()

var attempts uint
client := &http.Client{
Transport: Retries(
DefaultMaxErrorReadAhead,
retry.Attempts(2),
retry.OnRetry(func(n uint, err error) {
// set number of attempts to check if retries were attempted
attempts = n + 1
}),
)(http.DefaultTransport),
}
_, err := client.Get(server.URL)
t.Log(server.URL)
client := server.Client()
client.Transport = Retries(
DefaultMaxErrorReadAhead,
retry.Attempts(2),
retry.OnRetry(func(n uint, err error) {
// set number of attempts to check if retries were attempted
attempts = n + 1
}),
)(client.Transport)
_, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}"))
if err == nil {
t.Fatalf("expected error to be non-nil")
}
expected := uint(2)
if attempts != expected {
t.Errorf("expected %d, actual: %d", expected, attempts)
}
errs := unwrapRetryErrors(err)
if len(errs) != int(expected) {
t.Errorf("Expected %d retry erros, got %+v", expected, errs)
}
for i, err := range errs {
var ce *ClientError
if errors.As(err, &ce) {
if got, want := ce.StatusCode, http.StatusInternalServerError; got != want {
t.Errorf("#%d: status got %d want %d", i, got, want)
}
} else {
t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err)
}
}
})

t.Run("retry POST request", func(t *testing.T) {
t.Run("skip retry for wrongly constructed request", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
expected := "{}"
got := string(b)
if got != expected {
t.Errorf("expected %q, got: %q", expected, got)
}
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()

var attempts uint
client := &http.Client{
Transport: Retries(
DefaultMaxErrorReadAhead,
retry.Attempts(2),
retry.OnRetry(func(n uint, err error) {
// set number of attempts to check if retries were attempted
attempts = n + 1
t.Errorf("Retry not skipped. OnRetry called with (%d, %v)", n, err)
}),
)(http.DefaultTransport),
}
_, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}"))
if err == nil {
t.Fatalf("expected error to be non-nil")
req, err := http.NewRequest(http.MethodGet, server.URL, bytes.NewBufferString("{}"))
if err != nil {
t.Fatalf("Failed to create http request: %v", err)
}
expected := uint(2)
if attempts != expected {
t.Errorf("expected %d, actual: %d", expected, attempts)
req.GetBody = nil
if _, err := client.Do(req); err == nil {
t.Fatalf("expected error to be non-nil")
}
})
}
Expand Down

0 comments on commit fb57fd8

Please sign in to comment.