Skip to content

Commit

Permalink
Reconcile trailers-only and misc error behavior with grpc-go (#690)
Browse files Browse the repository at this point in the history
This largely undoes a recent change to do more validation of
trailers-only responses (#685), which disallows a body or trailers in
what appeared to be a trailers-only response. In that change, a
trailers-only response was identified by the presence of a "grpc-status"
key in the headers.

In this PR, a trailers-only response is instead defined by the lack of
body and trailers (not the presence of a "grpc-status" header). This PR
also tweaks some other error scenarios:

* If trailers (or an end-stream message) is completely missing from a
response, it's considered an `internal` error. But if trailers are
present, but the "grpc-status" key is missing, it's considered an issue
determining the status, which is an `unknown` error.
* Similarly, if a response content-type doesn't appear to be the right
protocol (like it may have come from a non-RPC server), the error code
is now `unknown`. But if it looks like the right protocol but uses the
wrong sub-format/codec, it's an `internal` error.
* Note that in grpc-go, this behavior is also seen in the client, but
this PR doesn't attempt to address that in the connect-go client.
Instead, that change can be made when #689 is addressed.

This PR also now makes connect-go more strict about the "compressed"
flag in a streaming protocol when there was no compression algorithm
negotiated. Previously, this library was lenient and did not consider it
an error if the message in question was empty (zero bytes). But to
correctly adhere to gRPC semantics, it must report this case as an
`internal `error.
  • Loading branch information
jhump authored Feb 16, 2024
1 parent 468ef0d commit 2f76b54
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 206 deletions.
140 changes: 45 additions & 95 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ func TestGRPCMissingTrailersError(t *testing.T) {
var connectErr *connect.Error
ok := errors.As(err, &connectErr)
assert.True(t, ok)
assert.Equal(t, connectErr.Code(), connect.CodeInternal)
assert.Equal(t, connectErr.Code(), connect.CodeUnknown)
assert.True(
t,
strings.HasSuffix(connectErr.Message(), "protocol error: no Grpc-Status trailer: unexpected EOF"),
Expand Down Expand Up @@ -1838,7 +1838,9 @@ func TestUnflushableResponseWriter(t *testing.T) {
t.Parallel()
assertIsFlusherErr := func(t *testing.T, err error) {
t.Helper()
assert.NotNil(t, err)
if !assert.NotNil(t, err) {
return
}
assert.Equal(t, connect.CodeOf(err), connect.CodeInternal, assert.Sprintf("got %v", err))
assert.True(
t,
Expand Down Expand Up @@ -1875,8 +1877,9 @@ func TestUnflushableResponseWriter(t *testing.T) {
assertIsFlusherErr(t, err)
return
}
assert.False(t, stream.Receive())
assertIsFlusherErr(t, stream.Err())
if assert.False(t, stream.Receive()) {
assertIsFlusherErr(t, stream.Err())
}
})
}
}
Expand Down Expand Up @@ -2146,6 +2149,21 @@ func TestStreamUnexpectedEOF(t *testing.T) {
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc_missing_status",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Trailers exist, just no status. So error will be unknown instead of internal.
responseWriter.Header().Set(http.TrailerPrefix+"grpc-message", "foo")
},
expectCode: connect.CodeUnknown,
expectMsg: "unknown: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc-web_missing_end",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
Expand All @@ -2159,6 +2177,29 @@ func TestStreamUnexpectedEOF(t *testing.T) {
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc-web_missing_status",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Trailers exist, just no status. So error will be unknown instead of internal.
_, err = responseWriter.Write([]byte{128}) // end-stream flag
assert.Nil(t, err)
endStream := "grpc-message: foo\r\n"
var length [4]byte
binary.BigEndian.PutUint32(length[:], uint32(len(endStream)))
_, err = responseWriter.Write(length[:])
assert.Nil(t, err)
_, err = responseWriter.Write([]byte(endStream))
assert.Nil(t, err)
},
expectCode: connect.CodeUnknown,
expectMsg: "unknown: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "connect_partial_payload",
options: []connect.ClientOption{connect.WithProtoJSON()},
Expand Down Expand Up @@ -2442,97 +2483,6 @@ func TestClientDisconnect(t *testing.T) {
})
}

func TestTrailersOnlyErrors(t *testing.T) {
t.Parallel()

head := [3]byte{}
testcases := []struct {
name string
handler http.HandlerFunc
options []connect.ClientOption
expectCode connect.Code
expectMsg string
}{{
name: "grpc_body_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc")
header.Set("Grpc-Status", "3")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after trailers-only response", len(head)),
}, {
name: "grpc-web_body_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web")
header.Set("Grpc-Status", "3")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after trailers-only response", len(head)),
}, {
name: "grpc_trailers_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc")
header.Set("Grpc-Status", "3")
responseWriter.WriteHeader(http.StatusOK)
responseWriter.(http.Flusher).Flush() //nolint:forcetypeassert
header.Set(http.TrailerPrefix+"Foo", "abc")
},
expectCode: connect.CodeInternal,
expectMsg: "internal: corrupt response from server: gRPC trailers-only response may not contain HTTP trailers",
}, {
name: "grpc-web_trailers_after_trailers-only",
options: []connect.ClientOption{connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web")
header.Set("Grpc-Status", "3")
responseWriter.WriteHeader(http.StatusOK)
responseWriter.(http.Flusher).Flush() //nolint:forcetypeassert
header.Set(http.TrailerPrefix+"Foo", "abc")
},
expectCode: connect.CodeInternal,
expectMsg: "internal: corrupt response from server: gRPC trailers-only response may not contain HTTP trailers",
}}
for _, testcase := range testcases {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/", func(responseWriter http.ResponseWriter, request *http.Request) {
_, _ = io.Copy(io.Discard, request.Body)
testcase.handler(responseWriter, request)
})
server := memhttptest.NewServer(t, mux)
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL(),
testcase.options...,
)
const upTo = 2
request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo})
request.Header().Set("Test-Case", t.Name())
stream, err := client.CountUp(context.Background(), request)
assert.Nil(t, err)
for i := 0; stream.Receive() && i < upTo; i++ {
assert.Equal(t, stream.Msg().GetNumber(), 42)
}
assert.NotNil(t, stream.Err())
assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode)
assert.Equal(t, stream.Err().Error(), testcase.expectMsg)
})
}
}

// TestBlankImportCodeGeneration tests that services.connect.go is generated with
// blank import statements to services.pb.go so that the service's Descriptor is
// available in the global proto registry.
Expand Down
27 changes: 17 additions & 10 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (w *envelopeWriter) write(env *envelope) *Error {
type envelopeReader struct {
ctx context.Context //nolint:containedctx
reader io.Reader
bytesRead int64 // detect trailers-only gRPC responses
codec Codec
last envelope
compressionPool *compressionPool
Expand All @@ -241,6 +242,11 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
env := &envelope{Data: buffer}
err := r.Read(env)
switch {
case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil:
return errorf(
CodeInternal,
"protocol error: sent compressed message without compression support",
)
case err == nil &&
(env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
env.Data.Len() == 0:
Expand All @@ -257,12 +263,6 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

data := env.Data
if data.Len() > 0 && env.IsSet(flagEnvelopeCompressed) {
if r.compressionPool == nil {
return errorf(
CodeInvalidArgument,
"protocol error: sent compressed message without compression support",
)
}
decompressed := r.bufferPool.Get()
defer func() {
if decompressed != dontRelease {
Expand All @@ -277,7 +277,9 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

if env.Flags != 0 && env.Flags != flagEnvelopeCompressed {
// Drain the rest of the stream to ensure there is no extra data.
if numBytes, err := discard(r.reader); err != nil {
numBytes, err := discard(r.reader)
r.bytesRead += numBytes
if err != nil {
err = wrapIfContextError(err)
if connErr, ok := asError(err); ok {
return connErr
Expand Down Expand Up @@ -308,7 +310,9 @@ func (r *envelopeReader) Read(env *envelope) *Error {
prefixes := [5]byte{}
// io.ReadFull reads the number of bytes requested, or returns an error.
// io.EOF will only be returned if no bytes were read.
if _, err := io.ReadFull(r.reader, prefixes[:]); err != nil {
n, err := io.ReadFull(r.reader, prefixes[:])
r.bytesRead += int64(n)
if err != nil {
if errors.Is(err, io.EOF) {
// The stream ended cleanly. That's expected, but we need to propagate an EOF
// to the user so that they know that the stream has ended. We shouldn't
Expand All @@ -328,7 +332,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
}
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
_, err := io.CopyN(io.Discard, r.reader, size)
n, err := io.CopyN(io.Discard, r.reader, size)
r.bytesRead += n
if err != nil && !errors.Is(err, io.EOF) {
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err)
}
Expand All @@ -337,7 +342,9 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// We've read the prefix, so we know how many bytes to expect.
// CopyN will return an error if it doesn't read the requested
// number of bytes.
if readN, err := io.CopyN(env.Data, r.reader, size); err != nil {
readN, err := io.CopyN(env.Data, r.reader, size)
r.bytesRead += readN
if err != nil {
if errors.Is(err, io.EOF) {
// We've gotten fewer bytes than we expected, so the stream has ended
// unexpectedly.
Expand Down
10 changes: 8 additions & 2 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ func DecodeBinaryHeader(data string) ([]byte, error) {
}

func mergeHeaders(into, from http.Header) {
for k, vals := range from {
into[k] = append(into[k], vals...)
for key, vals := range from {
if len(vals) == 0 {
// For response trailers, net/http will pre-populate entries
// with nil values based on the "Trailer" header. But if there
// are no actual values for those keys, we skip them.
continue
}
into[key] = append(into[key], vals...)
}
}

Expand Down
1 change: 0 additions & 1 deletion header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ func TestHeaderMerge(t *testing.T) {
expect := http.Header{
"Foo": []string{"one", "two"},
"Bar": []string{"one"},
"Baz": nil,
}
assert.Equal(t, header, expect)
}
18 changes: 18 additions & 0 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,15 @@ func connectValidateUnaryResponseContentType(
)
}
// Normal responses must have valid content-type that indicates same codec as the request.
if !strings.HasPrefix(responseContentType, connectUnaryContentTypePrefix) {
// Doesn't even look like a Connect response? Use code "unknown".
return errorf(
CodeUnknown,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}
responseCodecName := connectCodecFromContentType(
StreamTypeUnary,
responseContentType,
Expand All @@ -1410,6 +1419,15 @@ func connectValidateUnaryResponseContentType(

func connectValidateStreamResponseContentType(requestCodecName string, streamType StreamType, responseContentType string) *Error {
// Responses must have valid content-type that indicates same codec as the request.
if !strings.HasPrefix(responseContentType, connectStreamingContentTypePrefix) {
// Doesn't even look like a Connect response? Use code "unknown".
return errorf(
CodeUnknown,
"invalid content-type: %q; expecting %q",
responseContentType,
connectUnaryContentTypePrefix+requestCodecName,
)
}
responseCodecName := connectCodecFromContentType(
streamType,
responseContentType,
Expand Down
37 changes: 24 additions & 13 deletions protocol_connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func TestConnectValidateUnaryResponseContentType(t *testing.T) {
codecName: codecNameJSON,
statusCode: http.StatusOK,
responseContentType: "some/garbage",
expectCode: CodeInternal,
expectCode: CodeUnknown, // doesn't even look like it could be connect protocol
expectBadContentType: true,
},
// Error status, invalid content-type, returns code based on HTTP status code
Expand Down Expand Up @@ -296,7 +296,7 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
testCases := []struct {
codecName string
responseContentType string
expectErr bool
expectCode Code
}{
// Allowed content-types
{
Expand All @@ -307,31 +307,42 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
codecName: codecNameJSON,
responseContentType: "application/connect+json",
},
// Mismatched response codec
{
codecName: codecNameProto,
responseContentType: "application/connect+json",
expectCode: CodeInternal,
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+proto",
expectCode: CodeInternal,
},
// Disallowed content-types
{
codecName: codecNameJSON,
responseContentType: "application/connect+json; charset=utf-8",
expectCode: CodeInternal, // *almost* looks right
},
{
codecName: codecNameProto,
responseContentType: "application/proto",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameJSON,
responseContentType: "application/json",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameJSON,
responseContentType: "application/json; charset=utf-8",
expectErr: true,
},
{
codecName: codecNameJSON,
responseContentType: "application/connect+json; charset=utf-8",
expectErr: true,
expectCode: CodeUnknown,
},
{
codecName: codecNameProto,
responseContentType: "some/garbage",
expectErr: true,
expectCode: CodeUnknown,
},
}
for _, testCase := range testCases {
Expand All @@ -344,10 +355,10 @@ func TestConnectValidateStreamResponseContentType(t *testing.T) {
StreamTypeServer,
testCase.responseContentType,
)
if !testCase.expectErr {
if testCase.expectCode == 0 {
assert.Nil(t, err)
} else if assert.NotNil(t, err) {
assert.Equal(t, CodeOf(err), CodeInternal)
assert.Equal(t, CodeOf(err), testCase.expectCode)
assert.True(t, strings.Contains(err.Message(), fmt.Sprintf("invalid content-type: %q; expecting", testCase.responseContentType)))
}
})
Expand Down
Loading

0 comments on commit 2f76b54

Please sign in to comment.