diff --git a/client.go b/client.go index fdc83a0..7e4f976 100644 --- a/client.go +++ b/client.go @@ -553,6 +553,11 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo // Do wraps calling an HTTP method with retries. func (c *Client) Do(req *Request) (*http.Response, error) { + return c.DoWithResponseHandler(req, nil) +} + +// DoWithResponseHandler wraps calling an HTTP method plus a response handler with retries. +func (c *Client) DoWithResponseHandler(req *Request, handler func(*http.Response) (shouldRetry bool)) (*http.Response, error) { c.clientInit.Do(func() { if c.HTTPClient == nil { c.HTTPClient = cleanhttp.DefaultPooledClient() @@ -606,9 +611,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) { // Attempt the request resp, doErr = c.HTTPClient.Do(req.Request) - // Check if we should continue with retries. - shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) - if doErr != nil { switch v := logger.(type) { case LeveledLogger: @@ -632,6 +634,13 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } } + // Check if we should continue with retries. + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) + + successSoFar := !shouldRetry && doErr == nil && checkErr == nil + if successSoFar && handler != nil { + shouldRetry = handler(resp) + } if !shouldRetry { break } @@ -739,6 +748,16 @@ func (c *Client) Get(url string) (*http.Response, error) { return c.Do(req) } +// GetWithResponseHandler is a helper for doing a GET request followed by a function on the response. +// The intention is for this to be used when errors in the response handling should also be retried. +func (c *Client) GetWithResponseHandler(url string, handler func(*http.Response) (shouldRetry bool)) (*http.Response, error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.DoWithResponseHandler(req, handler) +} + // Head is a shortcut for doing a HEAD request without making a new client. func Head(url string) (*http.Response, error) { return defaultClient.Head(url) diff --git a/client_test.go b/client_test.go index 0c80e4c..04f3f5e 100644 --- a/client_test.go +++ b/client_test.go @@ -254,6 +254,83 @@ func testClientDo(t *testing.T, body interface{}) { } } +func TestClient_DoWithHandler(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.RetryMax = 2 + + var attempts int + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + alternatingBool := false + tests := []struct { + name string + handler func(*http.Response) bool + expectedAttempts int + err string + }{ + { + name: "nil handler", + handler: nil, + expectedAttempts: 1, + }, + { + name: "handler never should retry", + handler: func(*http.Response) bool { return false }, + expectedAttempts: 1, + }, + { + name: "handler alternates should retry", + handler: func(*http.Response) bool { + alternatingBool = !alternatingBool + return alternatingBool + }, + expectedAttempts: 2, + }, + { + name: "handler always should retry", + handler: func(*http.Response) bool { return true }, + expectedAttempts: 3, + err: "giving up after 3 attempt(s)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + attempts = 0 + // Create the request + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Send the request. + _, err = client.DoWithResponseHandler(req, tt.handler) + if err != nil && !strings.Contains(err.Error(), tt.err) { + t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) + } + if err == nil && tt.err != "" { + t.Fatalf("no error, expected: %s", tt.err) + } + + if attempts != tt.expectedAttempts { + t.Fatalf("expected %d attempts, got %d attempts", tt.expectedAttempts, attempts) + } + }) + } +} + func TestClient_Do_fails(t *testing.T) { // Mock server which always responds 500. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {