Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Lixia (Sylvia) Lei <lixlei@microsoft.com>
  • Loading branch information
Wwwsylvia committed Dec 5, 2024
1 parent 8b4f572 commit 3869935
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 17 deletions.
21 changes: 19 additions & 2 deletions internal/trace/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var (
)

// payloadSizeLimit limits the maximum size of the response body to be printed.
const payloadSizeLimit int64 = 4 * 1024 * 1024 // 4 MiB
const payloadSizeLimit int64 = 4 * 1024 // 4 KiB

// Transport is an http.RoundTripper that keeps track of the in-flight
// request and add hooks to report HTTP tracing events.
Expand Down Expand Up @@ -113,8 +113,13 @@ func logResponseBody(resp *http.Response) string {
if err != nil {
return fmt.Sprintf(" Error reading response body: %v", err)
}

// restore the body by concatenating the read body with the remaining body
resp.Body = io.NopCloser(io.MultiReader(bytes.NewReader(readBody), resp.Body))
closeFunc := resp.Body.Close
resp.Body = &readCloser{
Reader: io.MultiReader(bytes.NewReader(readBody), resp.Body),
closeFunc: closeFunc,
}

if len(readBody) == 0 {
return " Response body is empty"
Expand Down Expand Up @@ -142,3 +147,15 @@ func isPrintableContentType(contentType string) bool {
}
return false
}

// readCloser returns an io.ReadCloser that wraps an io.Reader and a
// close function.
type readCloser struct {
io.Reader
closeFunc func() error
}

// Close closes the readCloser.
func (rc *readCloser) Close() error {
return rc.closeFunc()
}
141 changes: 126 additions & 15 deletions internal/trace/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@ package trace

import (
"bytes"
"fmt"
"errors"
"io"
"net/http"
"testing"
)

var (
mockReadErr = errors.New("mock read error")
mockCloseErr = errors.New("mock close error")
)

func Test_isPrintableContentType(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -117,9 +122,10 @@ func Test_isPrintableContentType(t *testing.T) {

func Test_logResponseBody(t *testing.T) {
tests := []struct {
name string
resp *http.Response
want string
name string
resp *http.Response
want string
wantData []byte
}{
{
name: "Nil body",
Expand All @@ -130,7 +136,8 @@ func Test_logResponseBody(t *testing.T) {
want: " No response body to print",
},
{
name: "No body",
name: "No body",
wantData: nil,
resp: &http.Response{
Body: http.NoBody,
ContentLength: 100, // in case of HEAD response, the content length is set but the body is empty
Expand All @@ -139,7 +146,8 @@ func Test_logResponseBody(t *testing.T) {
want: " No response body to print",
},
{
name: "Empty body",
name: "Empty body",
wantData: []byte(""),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte(""))),
ContentLength: 0,
Expand All @@ -148,7 +156,8 @@ func Test_logResponseBody(t *testing.T) {
want: " Response body is empty",
},
{
name: "Unknown content length",
name: "Unknown content length",
wantData: []byte("whatever"),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("whatever"))),
ContentLength: -1,
Expand All @@ -157,7 +166,8 @@ func Test_logResponseBody(t *testing.T) {
want: "whatever",
},
{
name: "Non-printable content type",
name: "Non-printable content type",
wantData: []byte("binary data"),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("binary data"))),
ContentLength: 11,
Expand All @@ -166,7 +176,8 @@ func Test_logResponseBody(t *testing.T) {
want: " Response body of content type \"application/octet-stream\" is not printed",
},
{
name: "Body at the limit",
name: "Body at the limit",
wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit)),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit)))),
ContentLength: payloadSizeLimit,
Expand All @@ -175,7 +186,8 @@ func Test_logResponseBody(t *testing.T) {
want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))),
},
{
name: "Body larger than limit",
name: "Body larger than limit",
wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)))), // 1 byte larger than limit
ContentLength: payloadSizeLimit + 1,
Expand All @@ -184,7 +196,8 @@ func Test_logResponseBody(t *testing.T) {
want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))) + "\n...(truncated)",
},
{
name: "Printable content type within limit",
name: "Printable content type within limit",
wantData: []byte("data"),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("data"))),
ContentLength: 4,
Expand All @@ -193,7 +206,8 @@ func Test_logResponseBody(t *testing.T) {
want: "data",
},
{
name: "Actual body size is larger than content length",
name: "Actual body size is larger than content length",
wantData: []byte("data"),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("data"))),
ContentLength: 3, // mismatched content length
Expand All @@ -202,7 +216,8 @@ func Test_logResponseBody(t *testing.T) {
want: "data",
},
{
name: "Actual body size is larger than content length and exceeds limit",
name: "Actual body size is larger than content length and exceeds limit",
wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)))), // 1 byte larger than limit
ContentLength: 1, // mismatched content length
Expand All @@ -211,14 +226,45 @@ func Test_logResponseBody(t *testing.T) {
want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))) + "\n...(truncated)",
},
{
name: "Actual body size is smaller than content length",
name: "Actual body size is smaller than content length",
wantData: []byte("data"),
resp: &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("data"))),
ContentLength: 5, // mismatched content length
Header: http.Header{"Content-Type": []string{"text/plain"}},
},
want: "data",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := logResponseBody(tt.resp); got != tt.want {
t.Errorf("logResponseBody() = %v, want %v", got, tt.want)
}
// validate the response body
if tt.resp.Body != nil {
readBytes, err := io.ReadAll(tt.resp.Body)
if err != nil {
t.Errorf("failed to read body after logResponseBody(), err= %v", err)
}
if !bytes.Equal(readBytes, tt.wantData) {
t.Errorf("resp.Body after logResponseBody() = %v, want %v", readBytes, tt.wantData)
}
if closeErr := tt.resp.Body.Close(); closeErr != nil {
t.Errorf("failed to close body after logResponseBody(), err= %v", closeErr)
}
}
})
}
}

func Test_logResponseBody_error(t *testing.T) {
tests := []struct {
name string
resp *http.Response
want string
}{
{
name: "Error reading body",
resp: &http.Response{
Expand All @@ -235,12 +281,77 @@ func Test_logResponseBody(t *testing.T) {
if got := logResponseBody(tt.resp); got != tt.want {
t.Errorf("logResponseBody() = %v, want %v", got, tt.want)
}
if closeErr := tt.resp.Body.Close(); closeErr != nil {
t.Errorf("failed to close body after logResponseBody(), err= %v", closeErr)
}
})
}
}

func Test_readCloser_Close(t *testing.T) {

tests := []struct {
name string
reader io.Reader
closeFunc func() error
wantData []byte
wantReadErr error
wantCloseErr error
}{
{
name: "successfully read and close",
wantData: []byte("data"),
reader: bytes.NewReader([]byte("data")),
closeFunc: func() error {
return nil
},
wantReadErr: nil,
wantCloseErr: nil,
},
{
name: "error reading",
wantData: nil,
reader: &errorReader{},
closeFunc: func() error {
return nil
},
wantReadErr: mockReadErr,
wantCloseErr: nil,
},
{
name: "error closing",
wantData: []byte("data"),
reader: bytes.NewReader([]byte("data")),
closeFunc: func() error {
return mockCloseErr
},
wantReadErr: nil,
wantCloseErr: mockCloseErr,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rc := &readCloser{
Reader: tt.reader,
closeFunc: tt.closeFunc,
}
got, err := io.ReadAll(rc)
if err != tt.wantReadErr {
t.Errorf("readCloser.ReadAll() error = %v, wantErr %v", err, tt.wantReadErr)
}
if !bytes.Equal(got, tt.wantData) {
t.Errorf("readCloser.ReadAll() = %v, want %v", got, tt.wantData)
}
if err := rc.Close(); err != tt.wantCloseErr {
t.Errorf("readCloser.Close() error = %v, wantErr %v", err, tt.wantCloseErr)
}
})
}
}

type errorReader struct{}

func (e *errorReader) Read(p []byte) (n int, err error) {
return 0, fmt.Errorf("mock error")
return 0, mockReadErr
}

0 comments on commit 3869935

Please sign in to comment.