diff --git a/httputil/fake.go b/httputil/fake.go index 4cc316e8..2081add3 100644 --- a/httputil/fake.go +++ b/httputil/fake.go @@ -18,38 +18,48 @@ func NewFakeTransport() *FakeTransport { } } -// AddResponse stores a fake HTTP response for the given URL. -func (ft *FakeTransport) AddResponse(url string, status int, body string, headers map[string]string) { +func (ft *FakeTransport) responseCollection(url string) *responseCollection { if _, ok := ft.responses[url]; !ok { ft.responses[url] = &responseCollection{} } + return ft.responses[url] +} + +// AddResponse stores a fake HTTP response for the given URL. +func (ft *FakeTransport) AddResponse(url string, status int, body string, headers map[string]string) { + ft.responseCollection(url).Add(createResponse(status, body, headers), nil) +} + +// AddResponse stores a error for the given URL. +func (ft *FakeTransport) AddError(url string, err error) { + ft.responseCollection(url).Add(nil, err) - ft.responses[url].Add(createResponse(status, body, headers)) } // RoundTrip returns a prerecorded response to the given request, if one exists. Otherwise its response indicates 404 - not found. func (ft *FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { if responses, ok := ft.responses[req.URL.String()]; ok { - return responses.Next(), nil + return responses.Next() } return notFound(), nil } type responseCollection struct { - all []*http.Response + all []responseError next int } -func (rc *responseCollection) Add(resp *http.Response) { - rc.all = append(rc.all, resp) +func (rc *responseCollection) Add(resp *http.Response, err error) { + rc.all = append(rc.all, responseError{resp: resp, err: err}) } -func (rc *responseCollection) Next() *http.Response { +func (rc *responseCollection) Next() (*http.Response, error) { if rc.next >= len(rc.all) { - return notFound() + return notFound(), nil } rc.next++ - return rc.all[rc.next-1] + next := rc.all[rc.next-1] + return next.resp, next.err } func createResponse(status int, body string, headers map[string]string) *http.Response { @@ -71,3 +81,8 @@ func transformHeaders(original map[string]string) http.Header { func notFound() *http.Response { return createResponse(http.StatusNotFound, "", nil) } + +type responseError struct { + resp *http.Response + err error +} diff --git a/httputil/httputil.go b/httputil/httputil.go index 07eb8c59..804b8e52 100644 --- a/httputil/httputil.go +++ b/httputil/httputil.go @@ -89,16 +89,19 @@ func get(url, auth string) (*http.Response, error) { } client := &http.Client{Transport: DefaultTransport} deadline := RetryClock.Now().Add(MaxRequestDuration) - lastStatus := 0 + var lastFailure string for attempt := 0; attempt <= MaxRetries; attempt++ { res, err := client.Do(req) - // Do not retry on success and permanent/fatal errors - if err != nil || !shouldRetry(res) { + if !shouldRetry(res, err) { return res, err } - lastStatus = res.StatusCode - waitFor, err := getWaitPeriod(res, attempt) + if err == nil { + lastFailure = fmt.Sprintf("HTTP %d", res.StatusCode) + } else { + lastFailure = err.Error() + } + waitFor, err := getWaitPeriod(res, err, attempt) if err != nil { return nil, err } @@ -111,18 +114,25 @@ func get(url, auth string) (*http.Response, error) { RetryClock.Sleep(waitFor) } } - return nil, fmt.Errorf("unable to complete request to %s after %d retries. Most recent status: %d", url, MaxRetries, lastStatus) + return nil, fmt.Errorf("unable to complete request to %s after %d retries. Most recent failure: %s", url, MaxRetries, lastFailure) } -func shouldRetry(res *http.Response) bool { +func shouldRetry(res *http.Response, err error) bool { + // Retry if the client failed to speak HTTP. + if err != nil { + return true + } + // For HTTP: only retry on permanent/fatal errors. return res.StatusCode == 429 || (500 <= res.StatusCode && res.StatusCode <= 504) } -func getWaitPeriod(res *http.Response, attempt int) (time.Duration, error) { - // Check if the server told us when to retry - for _, header := range retryHeaders { - if value := res.Header[header]; len(value) > 0 { - return parseRetryHeader(value[0]) +func getWaitPeriod(res *http.Response, err error, attempt int) (time.Duration, error) { + if err == nil { + // If HTTP works, check if the server told us when to retry + for _, header := range retryHeaders { + if value := res.Header[header]; len(value) > 0 { + return parseRetryHeader(value[0]) + } } } // Let's just use exponential backoff: 1s + d1, 2s + d2, 4s + d3, 8s + d4 with dx being a random value in [0ms, 500ms] diff --git a/httputil/httputil_test.go b/httputil/httputil_test.go index c8f2257b..9ceb9fb8 100644 --- a/httputil/httputil_test.go +++ b/httputil/httputil_test.go @@ -1,6 +1,7 @@ package httputil import ( + "errors" "net/http" "strconv" "testing" @@ -93,6 +94,29 @@ func TestSuccessOnRetry(t *testing.T) { } } +func TestSuccessOnRetryNonHTTPError(t *testing.T) { + transport, clock := setUp() + + url := "http://foo" + want := "the_body" + transport.AddError(url, errors.New("boom!")) + transport.AddResponse(url, 200, want, nil) + body, _, err := ReadRemoteFile(url, "") + + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + got := string(body) + if got != want { + t.Fatalf("Expected body %q, but got %q", want, got) + } + + if clock.TimesSlept() != 1 { + t.Fatalf("Expected a single retry, not %d", clock.TimesSlept()) + } +} + func TestAllTriesFail(t *testing.T) { MaxRequestDuration = 100 * time.Second @@ -106,7 +130,7 @@ func TestAllTriesFail(t *testing.T) { } reason := err.Error() - expected := "could not fetch http://bar: unable to complete request to http://bar after 5 retries. Most recent status: 502" + expected := "could not fetch http://bar: unable to complete request to http://bar after 5 retries. Most recent failure: HTTP 502" if reason != expected { t.Fatalf("Expected request to fail with %q, but got %q", expected, reason) }