From fb57fd8eb06a354099445d944638c161489e632c Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Wed, 17 Apr 2024 09:24:05 -0700 Subject: [PATCH] httpbp: Fix Retries middleware When we set GetBody in http.Request, it's expected that Body is also set, add special handling in Retries to make sure we also set Body when retrying when GetBody is also set before each retry attempt. Also always clone the request before each retry attempt to avoid some subtle errors, and skip the Retries middleware altogether if Body is set but GetBody is not. --- httpbp/client_middlewares.go | 30 ++++++- httpbp/client_middlewares_test.go | 127 +++++++++++++++++++++--------- 2 files changed, 115 insertions(+), 42 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index f621a826c..3944a1fab 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -2,7 +2,9 @@ package httpbp import ( "errors" + "fmt" "io" + "log/slog" "net/http" "strconv" "sync" @@ -200,16 +202,36 @@ func CircuitBreaker(config breakerbp.Config) ClientMiddleware { // Retries provides a retry middleware by ensuring certain HTTP responses are // wrapped in errors. Retries wraps the ClientErrorWrapper middleware, e.g. if // you are using Retries there is no need to also use ClientErrorWrapper. -func Retries(limit int, retryOptions ...retry.Option) ClientMiddleware { +func Retries(maxErrorReadAhead int, retryOptions ...retry.Option) ClientMiddleware { if len(retryOptions) == 0 { retryOptions = []retry.Option{retry.Attempts(1)} } return func(next http.RoundTripper) http.RoundTripper { + // include ClientErrorWrapper to ensure retry is applied for some HTTP 5xx + // responses + next = ClientErrorWrapper(maxErrorReadAhead)(next) + return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { + if req.Body != nil && req.Body != http.NoBody && req.GetBody == nil { + slog.WarnContext( + req.Context(), + "Request comes with a Body but nil GetBody cannot be retried. httpbp.Retries middleware skipped.", + "req", req, + ) + return next.RoundTrip(req) + } + err = retrybp.Do(req.Context(), func() error { - // include ClientErrorWrapper to ensure retry is applied for - // some HTTP 5xx responses - resp, err = ClientErrorWrapper(limit)(next).RoundTrip(req) + req = req.Clone(req.Context()) + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return fmt.Errorf("httpbp.Retries: GetBody returned error: %w", err) + } + req.Body = body + } + + resp, err = next.RoundTrip(req) if err != nil { return err } diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 426e74e35..bfd9efc59 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "sync" "sync/atomic" "testing" @@ -256,11 +257,22 @@ func TestClientErrorWrapper(t *testing.T) { }) } +func unwrapRetryErrors(err error) []error { + var errs interface { + error + + Unwrap() []error + } + if errors.As(err, &errs) { + return errs.Unwrap() + } + return []error{err} +} + func TestRetry(t *testing.T) { - t.Run("retry for timeout", func(t *testing.T) { - const timeout = time.Millisecond * 10 + t.Run("retry for HTTP 500", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(timeout * 10) + w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() @@ -274,36 +286,72 @@ func TestRetry(t *testing.T) { attempts = n + 1 }), )(http.DefaultTransport), - Timeout: timeout, } - _, err := client.Get(server.URL) + u, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse url %q: %v", server.URL, err) + } + req := &http.Request{ + Method: http.MethodPost, + URL: u, + + // Explicitly set Body to http.NoBody and GetBody to nil, + // This request should not cause Retries middleware to be skipped. + Body: http.NoBody, + GetBody: nil, + } + _, err = client.Do(req) if err == nil { t.Fatalf("expected error to be non-nil") } - expected := uint(1) + expected := uint(2) if attempts != expected { t.Errorf("expected %d, actual: %d", expected, attempts) } + errs := unwrapRetryErrors(err) + if len(errs) != int(expected) { + t.Errorf("Expected %d retry erros, got %+v", expected, errs) + } + for i, err := range errs { + var ce *ClientError + if errors.As(err, &ce) { + if got, want := ce.StatusCode, http.StatusInternalServerError; got != want { + t.Errorf("#%d: status got %d want %d", i, got, want) + } + } else { + t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err) + } + } }) - t.Run("retry for HTTP 500", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Run("retry POST+HTTPS request", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + expected := "{}" + got := string(b) + if got != expected { + t.Errorf("expected %q, got: %q", expected, got) + } + t.Logf("Full body: %q", got) w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() var attempts uint - client := &http.Client{ - Transport: Retries( - DefaultMaxErrorReadAhead, - retry.Attempts(2), - retry.OnRetry(func(n uint, err error) { - // set number of attempts to check if retries were attempted - attempts = n + 1 - }), - )(http.DefaultTransport), - } - _, err := client.Get(server.URL) + t.Log(server.URL) + client := server.Client() + client.Transport = Retries( + DefaultMaxErrorReadAhead, + retry.Attempts(2), + retry.OnRetry(func(n uint, err error) { + // set number of attempts to check if retries were attempted + attempts = n + 1 + }), + )(client.Transport) + _, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}")) if err == nil { t.Fatalf("expected error to be non-nil") } @@ -311,41 +359,44 @@ func TestRetry(t *testing.T) { if attempts != expected { t.Errorf("expected %d, actual: %d", expected, attempts) } + errs := unwrapRetryErrors(err) + if len(errs) != int(expected) { + t.Errorf("Expected %d retry erros, got %+v", expected, errs) + } + for i, err := range errs { + var ce *ClientError + if errors.As(err, &ce) { + if got, want := ce.StatusCode, http.StatusInternalServerError; got != want { + t.Errorf("#%d: status got %d want %d", i, got, want) + } + } else { + t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err) + } + } }) - t.Run("retry POST request", func(t *testing.T) { + t.Run("skip retry for wrongly constructed request", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b, err := io.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - } - expected := "{}" - got := string(b) - if got != expected { - t.Errorf("expected %q, got: %q", expected, got) - } w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() - var attempts uint client := &http.Client{ Transport: Retries( DefaultMaxErrorReadAhead, retry.Attempts(2), retry.OnRetry(func(n uint, err error) { - // set number of attempts to check if retries were attempted - attempts = n + 1 + t.Errorf("Retry not skipped. OnRetry called with (%d, %v)", n, err) }), )(http.DefaultTransport), } - _, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}")) - if err == nil { - t.Fatalf("expected error to be non-nil") + req, err := http.NewRequest(http.MethodGet, server.URL, bytes.NewBufferString("{}")) + if err != nil { + t.Fatalf("Failed to create http request: %v", err) } - expected := uint(2) - if attempts != expected { - t.Errorf("expected %d, actual: %d", expected, attempts) + req.GetBody = nil + if _, err := client.Do(req); err == nil { + t.Fatalf("expected error to be non-nil") } }) }