Skip to content

Commit

Permalink
Add support for Connect-Protocol-Version header (#416)
Browse files Browse the repository at this point in the history
Currently, it's difficult for proxies, net/http middleware, and other
"in-between" code to distinguish unary Connect RPC requests from other
HTTP
traffic (especially with JSON payloads, which use the common
`application/json`
Content-Type). To work around this, any in-between code today must check
whether the HTTP path matches a known RPC method, which requires
intimate
knowledge of the service schema. This isn't great.

To work around this limitation, we've added an optional header to the
specification for Connect RPC requests: `Connect-Protocol-Version`.
Generated
clients always send this header, so the vast majority of traffic should
include
it. Servers _may_ require that requests include this header, which helps
proxies and net/http middleware identify _every_ Connect request. This
lets
in-between code function more reliably: for example, it could help a
metrics-collecting reverse proxy produce higher-fidelity statistics,
since it
wouldn't miss even the handful of RPCs made with ad-hoc cURL commands.
However,
it makes ad-hoc debugging with cURL or fetch slightly more laborious.

This PR makes clients using the Connect protocol send the
`Connect-Protocol-Version` header, and it allows servers to opt into
strict
validation by using the `WithRequireConnectProtocolHeader` option. It
makes no
changes to the behavior of the gRPC or gRPC-Web protocols.

Note that servers exposing Connect RPCs to web browsers may need to
update
their CORS configuration to allow the `Connect-Protocol-Version` header.
  • Loading branch information
akshayjshah authored Dec 8, 2022
1 parent ac83b0a commit 3e30c2d
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 25 deletions.
56 changes: 56 additions & 0 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,62 @@ func TestGRPCErrorMetadataIsTrailersOnly(t *testing.T) {
assert.NotZero(t, res.Trailer.Get(handlerTrailer))
}

func TestConnectProtocolHeaderSentByDefault(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer{}, connect.WithRequireConnectProtocolHeader()))
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)

client := pingv1connect.NewPingServiceClient(server.Client(), server.URL)
_, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{}))
assert.Nil(t, err)

stream := client.CumSum(context.Background())
assert.Nil(t, stream.Send(&pingv1.CumSumRequest{}))
_, err = stream.Receive()
assert.Nil(t, err)
assert.Nil(t, stream.CloseRequest())
assert.Nil(t, stream.CloseResponse())
}

func TestConnectProtocolHeaderRequired(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(
pingServer{},
connect.WithRequireConnectProtocolHeader(),
))
server := httptest.NewServer(mux)
t.Cleanup(server.Close)

tests := []struct {
headers http.Header
}{
{http.Header{}},
{http.Header{"Connect-Protocol-Version": []string{"0"}}},
}
for _, tcase := range tests {
req, err := http.NewRequestWithContext(
context.Background(),
http.MethodPost,
server.URL+"/"+pingv1connect.PingServiceName+"/Ping",
strings.NewReader("{}"),
)
assert.Nil(t, err)
req.Header.Set("Content-Type", "application/json")
for k, v := range tcase.headers {
req.Header[k] = v
}
response, err := server.Client().Do(req)
assert.Nil(t, err)
assert.Nil(t, response.Body.Close())
assert.Equal(t, response.StatusCode, http.StatusBadRequest)
}
}

func TestBidiOverHTTP1(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
Expand Down
38 changes: 20 additions & 18 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,18 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
}

type handlerConfig struct {
CompressionPools map[string]*compressionPool
CompressionNames []string
Codecs map[string]Codec
CompressMinBytes int
Interceptor Interceptor
Procedure string
HandleGRPC bool
HandleGRPCWeb bool
BufferPool *bufferPool
ReadMaxBytes int
SendMaxBytes int
CompressionPools map[string]*compressionPool
CompressionNames []string
Codecs map[string]Codec
CompressMinBytes int
Interceptor Interceptor
Procedure string
HandleGRPC bool
HandleGRPCWeb bool
RequireConnectProtocolHeader bool
BufferPool *bufferPool
ReadMaxBytes int
SendMaxBytes int
}

func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig {
Expand Down Expand Up @@ -285,13 +286,14 @@ func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHan
)
for _, protocol := range protocols {
handlers = append(handlers, protocol.NewHandler(&protocolHandlerParams{
Spec: c.newSpec(streamType),
Codecs: codecs,
CompressionPools: compressors,
CompressMinBytes: c.CompressMinBytes,
BufferPool: c.BufferPool,
ReadMaxBytes: c.ReadMaxBytes,
SendMaxBytes: c.SendMaxBytes,
Spec: c.newSpec(streamType),
Codecs: codecs,
CompressionPools: compressors,
CompressMinBytes: c.CompressMinBytes,
BufferPool: c.BufferPool,
ReadMaxBytes: c.ReadMaxBytes,
SendMaxBytes: c.SendMaxBytes,
RequireConnectProtocolHeader: c.RequireConnectProtocolHeader,
}))
}
return handlers
Expand Down
18 changes: 18 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,18 @@ func WithRecover(handle func(context.Context, Spec, http.Header, any) error) Han
return WithInterceptors(&recoverHandlerInterceptor{handle: handle})
}

// WithRequireConnectProtocolHeader configures the Handler to require requests
// using the Connect RPC protocol to include the Connect-Protocol-Version
// header. This ensures that HTTP proxies and net/http middleware can easily
// identify valid Connect requests, even if they use a common Content-Type like
// application/json. However, it makes ad-hoc requests with tools like cURL
// more laborious.
//
// This option has no effect if the client uses the gRPC or gRPC-Web protocols.
func WithRequireConnectProtocolHeader() HandlerOption {
return &requireConnectProtocolHeaderOption{}
}

// Option implements both [ClientOption] and [HandlerOption], so it can be
// applied both client-side and server-side.
type Option interface {
Expand Down Expand Up @@ -381,6 +393,12 @@ func (o *handlerOptionsOption) applyToHandler(config *handlerConfig) {
}
}

type requireConnectProtocolHeaderOption struct{}

func (o *requireConnectProtocolHeaderOption) applyToHandler(config *handlerConfig) {
config.RequireConnectProtocolHeader = true
}

type grpcOption struct {
web bool
}
Expand Down
15 changes: 8 additions & 7 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@ type protocol interface {
// Spec rather than constructing their own, since new fields may have been
// added.
type protocolHandlerParams struct {
Spec Spec
Codecs readOnlyCodecs
CompressionPools readOnlyCompressionPools
CompressMinBytes int
BufferPool *bufferPool
ReadMaxBytes int
SendMaxBytes int
Spec Spec
Codecs readOnlyCodecs
CompressionPools readOnlyCompressionPools
CompressMinBytes int
BufferPool *bufferPool
ReadMaxBytes int
SendMaxBytes int
RequireConnectProtocolHeader bool
}

// Handler is the server side of a protocol. HTTP handlers typically support
Expand Down
11 changes: 11 additions & 0 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const (
connectStreamingHeaderCompression = "Connect-Content-Encoding"
connectStreamingHeaderAcceptCompression = "Connect-Accept-Encoding"
connectHeaderTimeout = "Connect-Timeout-Ms"
connectHeaderProtocolVersion = "Connect-Protocol-Version"
connectProtocolVersion = "1"

connectFlagEnvelopeEndStream = 0b00000010

Expand Down Expand Up @@ -124,6 +126,14 @@ func (h *connectHandler) NewConn(
if failed == nil {
failed = checkServerStreamsCanFlush(h.Spec, responseWriter)
}
if failed == nil {
version := request.Header.Get(connectHeaderProtocolVersion)
if version == "" && h.RequireConnectProtocolHeader {
failed = errorf(CodeInvalidArgument, "missing required header: set %s to %q", connectHeaderProtocolVersion, connectProtocolVersion)
} else if version != "" && version != connectProtocolVersion {
failed = errorf(CodeInvalidArgument, "%s must be %q: got %q", connectHeaderProtocolVersion, connectProtocolVersion, version)
}
}

// Write any remaining headers here:
// (1) any writes to the stream will implicitly send the headers, so we
Expand Down Expand Up @@ -233,6 +243,7 @@ func (c *connectClient) WriteRequestHeader(streamType StreamType, header http.He
// We know these header keys are in canonical form, so we can bypass all the
// checks in Header.Set.
header[headerUserAgent] = []string{connectUserAgent()}
header[connectHeaderProtocolVersion] = []string{connectProtocolVersion}
header[headerContentType] = []string{
connectContentTypeFromCodecName(streamType, c.Codec.Name()),
}
Expand Down

0 comments on commit 3e30c2d

Please sign in to comment.