diff --git a/http/sentryhttp.go b/http/sentryhttp.go index 64a7871a7..edbdb9f2f 100644 --- a/http/sentryhttp.go +++ b/http/sentryhttp.go @@ -63,22 +63,6 @@ func New(options Options) *Handler { } } -// responseWriter is a wrapper around http.ResponseWriter that captures the status code. -type responseWriter struct { - http.ResponseWriter - statusCode int -} - -// WriteHeader captures the status code and calls the original WriteHeader method. -func (rw *responseWriter) WriteHeader(code int) { - rw.statusCode = code - rw.ResponseWriter.WriteHeader(code) -} - -func newResponseWriter(w http.ResponseWriter) *responseWriter { - return &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} -} - // Handle works as a middleware that wraps an existing http.Handler. A wrapped // handler will recover from and report panics to Sentry, and provide access to // a request-specific hub to report messages and errors. @@ -122,10 +106,10 @@ func (h *Handler) handle(handler http.Handler) http.HandlerFunc { ) transaction.SetData("http.request.method", r.Method) - rw := newResponseWriter(w) + rw := NewWrapResponseWriter(w, r.ProtoMajor) defer func() { - status := rw.statusCode + status := rw.Status() transaction.Status = sentry.HTTPtoSpanStatus(status) transaction.SetData("http.response.status_code", status) transaction.Finish() diff --git a/http/wrap_writer.go b/http/wrap_writer.go new file mode 100644 index 000000000..053c33cbf --- /dev/null +++ b/http/wrap_writer.go @@ -0,0 +1,199 @@ +package sentryhttp + +import ( + "bufio" + "io" + "net" + "net/http" +) + +// This wrapper is derived from https://github.com/go-chi/chi/blob/master/middleware/wrap_writer.go +// Copyright (c) 2015-present Peter Kieltyka (https://github.com/pkieltyka), Google Inc. + +// MIT License + +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to +// hook into various parts of the response process. + +func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter { + _, fl := w.(http.Flusher) + + bw := basicWriter{ResponseWriter: w, code: http.StatusOK} + + if protoMajor == 2 { + _, ps := w.(http.Pusher) + if fl && ps { + return &http2FancyWriter{bw} + } + } else { + _, hj := w.(http.Hijacker) + _, rf := w.(io.ReaderFrom) + if fl && hj && rf { + return &httpFancyWriter{bw} + } + } + if fl { + return &flushWriter{bw} + } + + return &bw +} + +// WrapResponseWriter is a proxy around an http.ResponseWriter that allows you to hook +// into various parts of the response process. +type WrapResponseWriter interface { + http.ResponseWriter + // Status returns the HTTP status of the request, or 200 if one has not + // yet been sent. + Status() int + // BytesWritten returns the total number of bytes sent to the client. + BytesWritten() int + // Tee causes the response body to be written to the given io.Writer in + // addition to proxying the writes through. Only one io.Writer can be + // tee'd to at once: setting a second one will overwrite the first. + // Writes will be sent to the proxy before being written to this + // io.Writer. It is illegal for the tee'd writer to be modified + // concurrently with writes. + Tee(io.Writer) + // Unwrap returns the original proxied target. + Unwrap() http.ResponseWriter +} + +// basicWriter wraps a http.ResponseWriter that implements the minimal +// http.ResponseWriter interface. +type basicWriter struct { + http.ResponseWriter + tee io.Writer + code int + bytes int + wroteHeader bool +} + +func (b *basicWriter) WriteHeader(code int) { + if !b.wroteHeader { + b.code = code + b.wroteHeader = true + } + b.ResponseWriter.WriteHeader(code) +} + +func (b *basicWriter) Write(buf []byte) (int, error) { + b.maybeWriteHeader() + n, err := b.ResponseWriter.Write(buf) + if b.tee != nil { + _, err2 := b.tee.Write(buf[:n]) + // Prefer errors generated by the proxied writer. + if err == nil { + err = err2 + } + } + b.bytes += n + return n, err +} + +func (b *basicWriter) maybeWriteHeader() { + if !b.wroteHeader { + b.WriteHeader(http.StatusOK) + } +} + +func (b *basicWriter) Status() int { + return b.code +} + +func (b *basicWriter) BytesWritten() int { + return b.bytes +} + +func (b *basicWriter) Tee(w io.Writer) { + b.tee = w +} + +func (b *basicWriter) Unwrap() http.ResponseWriter { + return b.ResponseWriter +} + +type flushWriter struct { + basicWriter +} + +func (f *flushWriter) Flush() { + f.wroteHeader = true + + fl := f.basicWriter.ResponseWriter.(http.Flusher) + fl.Flush() +} + +var _ http.Flusher = &flushWriter{} + +// httpFancyWriter is a HTTP writer that additionally satisfies +// http.Flusher, http.Hijacker, and io.ReaderFrom. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object. +type httpFancyWriter struct { + basicWriter +} + +func (f *httpFancyWriter) Flush() { + f.wroteHeader = true + f.basicWriter.ResponseWriter.(http.Flusher).Flush() +} + +func (f *httpFancyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return f.basicWriter.ResponseWriter.(http.Hijacker).Hijack() +} + +func (f *http2FancyWriter) Push(target string, opts *http.PushOptions) error { + return f.basicWriter.ResponseWriter.(http.Pusher).Push(target, opts) +} + +func (f *httpFancyWriter) ReadFrom(r io.Reader) (int64, error) { + if f.basicWriter.tee != nil { + n, err := io.Copy(&f.basicWriter, r) + f.basicWriter.bytes += int(n) + return n, err + } + rf := f.basicWriter.ResponseWriter.(io.ReaderFrom) + f.basicWriter.maybeWriteHeader() + n, err := rf.ReadFrom(r) + f.basicWriter.bytes += int(n) + return n, err +} + +var _ http.Flusher = &httpFancyWriter{} +var _ http.Hijacker = &httpFancyWriter{} +var _ http.Pusher = &http2FancyWriter{} +var _ io.ReaderFrom = &httpFancyWriter{} + +// http2FancyWriter is a HTTP2 writer that additionally satisfies +// http.Flusher, and io.ReaderFrom. It exists for the common case +// of wrapping the http.ResponseWriter that package http gives you, in order to +// make the proxied object support the full method set of the proxied object. +type http2FancyWriter struct { + basicWriter +} + +func (f *http2FancyWriter) Flush() { + f.wroteHeader = true + + f.basicWriter.ResponseWriter.(http.Flusher).Flush() +} + +var _ http.Flusher = &http2FancyWriter{} diff --git a/http/wrap_writer_test.go b/http/wrap_writer_test.go new file mode 100644 index 000000000..057a22e3d --- /dev/null +++ b/http/wrap_writer_test.go @@ -0,0 +1,190 @@ +package sentryhttp + +import ( + "bufio" + "bytes" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +// CustomResponseWriter for testing http.Hijacker and http.Pusher. +type CustomResponseWriter struct { + *httptest.ResponseRecorder +} + +func (c *CustomResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("hijack not supported in tests") +} + +func (c *CustomResponseWriter) Push(string, *http.PushOptions) error { + return nil +} + +func (c *CustomResponseWriter) Flush() { + c.ResponseRecorder.Flush() +} + +func (c *CustomResponseWriter) ReadFrom(r io.Reader) (n int64, err error) { + buf := new(bytes.Buffer) + n, err = buf.ReadFrom(r) + if err == nil { + _, err = buf.WriteTo(c.ResponseRecorder) + } + return +} + +func TestHttpFancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) { + f := &httpFancyWriter{basicWriter: basicWriter{ResponseWriter: httptest.NewRecorder()}} + f.Flush() + + if !f.wroteHeader { + t.Fatal("want Flush to have set wroteHeader=true") + } +} + +func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) { + f := &http2FancyWriter{basicWriter{ResponseWriter: httptest.NewRecorder()}} + f.Flush() + + if !f.wroteHeader { + t.Fatal("want Flush to have set wroteHeader=true") + } +} + +func TestBytesWritten(t *testing.T) { + rec := httptest.NewRecorder() + bw := &basicWriter{ResponseWriter: rec} + + body := []byte("Hello, BytesWritten!") + _, err := bw.Write(body) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if bw.BytesWritten() != len(body) { + t.Fatalf("expected %v bytes written, got %v", len(body), bw.BytesWritten()) + } +} + +func TestUnwrap(t *testing.T) { + rec := httptest.NewRecorder() + bw := &basicWriter{ResponseWriter: rec} + + if bw.Unwrap() != rec { + t.Fatal("expected Unwrap to return the original ResponseWriter") + } +} + +func TestNewWrapResponseWriter(t *testing.T) { + rec := httptest.NewRecorder() + + // HTTP/1.1 request + w1 := NewWrapResponseWriter(rec, 1) + if _, ok := w1.(*flushWriter); !ok { + t.Fatalf("expected flushWriter, got %T", w1) + } + + // HTTP/2 request + customRec := &CustomResponseWriter{httptest.NewRecorder()} + w2 := NewWrapResponseWriter(customRec, 2) + if _, ok := w2.(*http2FancyWriter); !ok { + t.Fatalf("expected http2FancyWriter, got %T", w2) + } +} + +func TestBasicWriterWriteHeader(t *testing.T) { + rec := httptest.NewRecorder() + bw := &basicWriter{ResponseWriter: rec} + + bw.WriteHeader(http.StatusCreated) + if rec.Code != http.StatusCreated { + t.Fatalf("expected status code %v, got %v", http.StatusCreated, rec.Code) + } +} + +func TestBasicWriterWrite(t *testing.T) { + rec := httptest.NewRecorder() + bw := &basicWriter{ResponseWriter: rec} + + body := []byte("Hello, World!") + n, err := bw.Write(body) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if n != len(body) { + t.Fatalf("expected %v bytes written, got %v", len(body), n) + } + if rec.Body.String() != string(body) { + t.Fatalf("expected body %v, got %v", string(body), rec.Body.String()) + } + if bw.bytes != len(body) { + t.Fatalf("expected %v bytes written in struct, got %v", len(body), bw.bytes) + } +} + +func TestBasicWriterTee(t *testing.T) { + rec := httptest.NewRecorder() + var buf bytes.Buffer + bw := &basicWriter{ResponseWriter: rec} + + bw.Tee(&buf) + body := []byte("Hello, Tee!") + _, err := bw.Write(body) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if buf.String() != string(body) { + t.Fatalf("expected tee body %v, got %v", string(body), buf.String()) + } +} + +func TestFlushWriterFlush(t *testing.T) { + rec := httptest.NewRecorder() + fw := &flushWriter{basicWriter{ResponseWriter: rec}} + + fw.Flush() + if !fw.wroteHeader { + t.Fatal("want Flush to have set wroteHeader=true") + } +} + +func TestHttpFancyWriterHijack(t *testing.T) { + rec := &CustomResponseWriter{httptest.NewRecorder()} + + f := &httpFancyWriter{basicWriter: basicWriter{ResponseWriter: rec}} + _, _, err := f.Hijack() + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestHttpFancyWriterReadFrom(t *testing.T) { + rec := &CustomResponseWriter{httptest.NewRecorder()} + f := &httpFancyWriter{basicWriter: basicWriter{ResponseWriter: rec}} + + body := []byte("Hello, ReadFrom!") + r := bytes.NewReader(body) + n, err := f.ReadFrom(r) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if int(n) != len(body) { + t.Fatalf("expected %v bytes read, got %v", len(body), n) + } + if rec.Body.String() != string(body) { + t.Fatalf("expected body %v, got %v", string(body), rec.Body.String()) + } +} + +func TestHttp2FancyWriterPush(t *testing.T) { + rec := &CustomResponseWriter{httptest.NewRecorder()} + + f := &http2FancyWriter{basicWriter: basicWriter{ResponseWriter: rec}} + err := f.Push("/some-path", nil) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +}