diff --git a/cert_error_go119.go b/cert_error_go119.go new file mode 100644 index 0000000..b2b27e8 --- /dev/null +++ b/cert_error_go119.go @@ -0,0 +1,14 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build !go1.20 +// +build !go1.20 + +package retryablehttp + +import "crypto/x509" + +func isCertError(err error) bool { + _, ok := err.(x509.UnknownAuthorityError) + return ok +} diff --git a/cert_error_go120.go b/cert_error_go120.go new file mode 100644 index 0000000..a3cd315 --- /dev/null +++ b/cert_error_go120.go @@ -0,0 +1,14 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build go1.20 +// +build go1.20 + +package retryablehttp + +import "crypto/tls" + +func isCertError(err error) bool { + _, ok := err.(*tls.CertificateVerificationError) + return ok +} diff --git a/client.go b/client.go index f69063c..12ac50b 100644 --- a/client.go +++ b/client.go @@ -27,7 +27,6 @@ package retryablehttp import ( "bytes" "context" - "crypto/x509" "fmt" "io" "log" @@ -62,6 +61,10 @@ var ( // limit the size we consume to respReadLimit. respReadLimit = int64(4096) + // timeNow sets the function that returns the current time. + // This defaults to time.Now. Changes to this should only be done in tests. + timeNow = time.Now + // A regular expression to match the error returned by net/http when the // configured number of redirects is exhausted. This error isn't typed // specifically so we resort to matching on the error string. @@ -72,9 +75,10 @@ var ( // specifically so we resort to matching on the error string. schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`) - // timeNow sets the function that returns the current time. - // This defaults to time.Now. Changes to this should only be done in tests. - timeNow = time.Now + // A regular expression to match the error returned by net/http when a + // request header or value is invalid. This error isn't typed + // specifically so we resort to matching on the error string. + invalidHeaderErrorRe = regexp.MustCompile(`invalid header`) // A regular expression to match the error returned by net/http when the // TLS certificate is not trusted. This error isn't typed @@ -501,11 +505,16 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) { return false, v } + // Don't retry if the error was due to an invalid header. + if invalidHeaderErrorRe.MatchString(v.Error()) { + return false, v + } + // Don't retry if the error was due to TLS cert verification failure. if notTrustedErrorRe.MatchString(v.Error()) { return false, v } - if _, ok := v.Err.(x509.UnknownAuthorityError); ok { + if isCertError(v.Err) { return false, v } } diff --git a/client_test.go b/client_test.go index 3cba8ec..3c37ef6 100644 --- a/client_test.go +++ b/client_test.go @@ -935,6 +935,60 @@ func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) { } } +func TestClient_DefaultRetryPolicy_invalidheadername(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + attempts := 0 + client := NewClient() + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.Header.Set("Header-Name-\033", "header value") + _, err = client.StandardClient().Do(req) + if err == nil { + t.Fatalf("expected header error, got nil") + } + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } +} + +func TestClient_DefaultRetryPolicy_invalidheadervalue(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + attempts := 0 + client := NewClient() + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.Header.Set("Header-Name", "bad header value \033") + _, err = client.StandardClient().Do(req) + if err == nil { + t.Fatalf("expected header value error, got nil") + } + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } +} + func TestClient_CheckRetryStop(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "test_500_body", http.StatusInternalServerError)