Skip to content

Commit

Permalink
Feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane committed Mar 6, 2024
1 parent 51e2a91 commit f2216ad
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 31 deletions.
4 changes: 1 addition & 3 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2423,9 +2423,7 @@ func TestClientDisconnect(t *testing.T) {
assert.NotNil(t, err)
<-gotResponse
assert.NotNil(t, handlerReceiveErr)
if !assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled) {
t.Logf("handlerReceiveErr: %v", handlerReceiveErr)
}
assert.Equal(t, connect.CodeOf(handlerReceiveErr), connect.CodeCanceled, assert.Sprintf("got %v", handlerReceiveErr))
assert.ErrorIs(t, handlerContextErr, context.Canceled)
})
t.Run("handler_writes", func(t *testing.T) {
Expand Down
45 changes: 17 additions & 28 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ type ErrorWriter struct {
requireConnectProtocolHeader bool
}

// NewErrorWriter constructs an ErrorWriter. To properly recognize supported
// RPC Content-Types in net/http middleware, you must pass the same
// HandlerOptions to NewErrorWriter and any wrapped Connect handlers.
// NewErrorWriter constructs an ErrorWriter. Handler options may be passed to
// configure the error writer behaviour to match the handlers.
// [WithRequiredConnectProtocolHeader] will assert that Connect protocol
// requests include the version header allowing the error writer to correctly
// classify the request.
// Options supplied via [WithConditionalHandlerOptions] are ignored.
func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
config := newHandlerConfig("", StreamTypeUnary, opts)
Expand All @@ -64,46 +66,33 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {

func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
method := request.Method
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case w.handleGRPC && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
if method != http.MethodPost {
break
}
case w.handleGRPC && isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case w.handleGRPCWeb && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
if method != http.MethodPost {
break
}
case w.handleGRPCWeb && isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
if method != http.MethodPost {
break
}
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
break
return unknownProtocol
}
return connectStreamProtocol
case strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if method != http.MethodPost {
break
}
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
break
return unknownProtocol
}
return connectUnaryProtocol
case ctype == "":
if method != http.MethodGet {
break
}
case isGet && ctype == "":
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
break
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
return unknownProtocol
}

// IsSupported checks whether a request is using one of the ErrorWriter's
Expand Down

0 comments on commit f2216ad

Please sign in to comment.