diff --git a/gzhttp/compress.go b/gzhttp/compress.go index 289ae3e2ee..a9be8e9793 100644 --- a/gzhttp/compress.go +++ b/gzhttp/compress.go @@ -306,7 +306,7 @@ func (w *GzipResponseWriter) startPlain() error { func (w *GzipResponseWriter) WriteHeader(code int) { // Handle informational headers // This is gated to not forward 1xx responses on builds prior to go1.20. - if shouldWrite1xxResponses() && code >= 100 && code <= 199 { + if code >= 100 && code <= 199 { w.ResponseWriter.WriteHeader(code) return } diff --git a/gzhttp/compress_go119.go b/gzhttp/compress_go119.go deleted file mode 100644 index 97fc25acbc..0000000000 --- a/gzhttp/compress_go119.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !go1.20 -// +build !go1.20 - -package gzhttp - -// shouldWrite1xxResponses indicates whether the current build supports writes of 1xx status codes. -func shouldWrite1xxResponses() bool { - return false -} diff --git a/gzhttp/compress_go120.go b/gzhttp/compress_go120.go deleted file mode 100644 index 2b65f67c79..0000000000 --- a/gzhttp/compress_go120.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build go1.20 -// +build go1.20 - -package gzhttp - -// shouldWrite1xxResponses indicates whether the current build supports writes of 1xx status codes. -func shouldWrite1xxResponses() bool { - return true -} diff --git a/gzhttp/compress_test.go b/gzhttp/compress_test.go index dde980b5a7..c91de81b49 100644 --- a/gzhttp/compress_test.go +++ b/gzhttp/compress_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/klauspost/compress/gzip" + "github.com/klauspost/compress/zstd" ) var ( @@ -1756,14 +1757,32 @@ func runBenchmark(b *testing.B, req *http.Request, handler http.Handler) { } func newTestHandler(body []byte) http.Handler { + var gzBuf bytes.Buffer + var zstdBuf bytes.Buffer + gz := gzip.NewWriter(&gzBuf) + gz.Write(body) + gz.Close() + zs, _ := zstd.NewWriter(&zstdBuf) + zs.Write(body) + zs.Close() return GzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/gzipped": + // Add header. Write body as is. w.Header().Set("Content-Encoding", "gzip") w.Write(body) case "/zstd": + // Add header. Write body as is. w.Header().Set("Content-Encoding", "zstd") w.Write(body) + case "/gzipped/do": + // Add header. Write gzipped body. + w.Header().Set("Content-Encoding", "gzip") + w.Write(gzBuf.Bytes()) + case "/zstd/do": + // Add header. Write zstd body. + w.Header().Set("Content-Encoding", "zstd") + w.Write(zstdBuf.Bytes()) default: w.Write(body) } @@ -1803,11 +1822,6 @@ func TestGzipHandlerNilContentType(t *testing.T) { // This test is an adapted version of net/http/httputil.Test1xxResponses test. func Test1xxResponses(t *testing.T) { - // do not test 1xx responses on builds prior to go1.20. - if !shouldWrite1xxResponses() { - return - } - wrapper, _ := NewWrapper() handler := wrapper(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { diff --git a/gzhttp/transport.go b/gzhttp/transport.go index 623aea2ed8..3914a06e01 100644 --- a/gzhttp/transport.go +++ b/gzhttp/transport.go @@ -61,10 +61,21 @@ func TransportCustomEval(fn func(header http.Header) bool) transportOption { } } +// TransportAlwaysDecompress will always decompress the response, +// regardless of whether we requested it or not. +// Default is false, which will pass compressed data through +// if we did not request compression. +func TransportAlwaysDecompress(enabled bool) transportOption { + return func(c *gzRoundtripper) { + c.alwaysDecomp = enabled + } +} + type gzRoundtripper struct { parent http.RoundTripper acceptEncoding string withZstd, withGzip bool + alwaysDecomp bool customEval func(header http.Header) bool } @@ -90,15 +101,19 @@ func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { } resp, err := g.parent.RoundTrip(req) - if err != nil || !requestedComp { + if err != nil { return resp, err } - decompress := false + decompress := g.alwaysDecomp if g.customEval != nil { if !g.customEval(resp.Header) { return resp, nil } decompress = true + } else { + if !requestedComp && !g.alwaysDecomp { + return resp, nil + } } // Decompress if (decompress || g.withGzip) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") { diff --git a/gzhttp/transport_test.go b/gzhttp/transport_test.go index aff7edb4cf..72cb80c10b 100644 --- a/gzhttp/transport_test.go +++ b/gzhttp/transport_test.go @@ -251,6 +251,71 @@ func TestTransportCustomEval(t *testing.T) { } } +func TestTransportTransportAlwaysDecompress(t *testing.T) { + bin, err := os.ReadFile("testdata/benchmark.json") + if err != nil { + t.Fatal(err) + } + + // We will serve the data as zstd+gzip, but the client will not request it. + server := httptest.NewServer(newTestHandler(bin)) + c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false), TransportEnableGzip(false), TransportAlwaysDecompress(true))} + resp, err := c.Get(server.URL + "/zstd/do") + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, bin) { + t.Errorf("data mismatch") + } + resp.Body.Close() + + resp, err = c.Get(server.URL + "/gzip/do") + if err != nil { + t.Fatal(err) + } + got, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, bin) { + t.Errorf("data mismatch") + } + resp.Body.Close() + + // We will serve the data as zstd+gzip, but the client will not request it. + // With TransportAlwaysDecompress(false) it should not be decompressed. + c = http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false), TransportEnableGzip(false), TransportAlwaysDecompress(false))} + resp, err = c.Get(server.URL + "/zstd/do") + if err != nil { + t.Fatal(err) + } + got, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if bytes.Equal(got, bin) { + t.Errorf("data matches") + } + resp.Body.Close() + + resp, err = c.Get(server.URL + "/gzip/do") + if err != nil { + t.Fatal(err) + } + got, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, bin) { + t.Errorf("data matches") + } + resp.Body.Close() +} + func BenchmarkTransport(b *testing.B) { raw, err := os.ReadFile("testdata/benchmark.json") if err != nil {