diff --git a/connect_ext_test.go b/connect_ext_test.go index 22958553..4a07d046 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -27,7 +27,6 @@ import ( "math/rand" "net/http" "net/http/httptest" - "path" "strings" "sync" "testing" @@ -2082,48 +2081,20 @@ func TestHandlerReturnsNilResponse(t *testing.T) { func TestStreamUnexpectedEOF(t *testing.T) { t.Parallel() - testcases := map[string]http.HandlerFunc{ - "stream_unexpected_eof": func(responseWriter http.ResponseWriter, request *http.Request) { - _, _ = io.Copy(io.Discard, request.Body) - header := responseWriter.Header() - header.Set("Content-Type", "application/connect+json") - responseWriter.WriteHeader(http.StatusOK) - head := [5]byte{} - payload := []byte(`{"number": 42}`) - binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload) - }, - "stream_partial_payload": func(responseWriter http.ResponseWriter, request *http.Request) { - _, _ = io.Copy(io.Discard, request.Body) - header := responseWriter.Header() - header.Set("Content-Type", "application/connect+json") - responseWriter.WriteHeader(http.StatusOK) - head := [5]byte{} - payload := []byte(`{"number": 42}`) - binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload[:len(payload)-1]) - }, - "stream_partial_frame": func(responseWriter http.ResponseWriter, request *http.Request) { - _, _ = io.Copy(io.Discard, request.Body) - header := responseWriter.Header() - header.Set("Content-Type", "application/connect+json") - responseWriter.WriteHeader(http.StatusOK) - head := [5]byte{} - payload := []byte(`{"number": 42}`) - binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) - _, _ = responseWriter.Write(head[:4]) - }, - } + + // Initialized by the test case. + testcaseMux := make(map[string]http.HandlerFunc) mux := http.NewServeMux() mux.HandleFunc("/", func(responseWriter http.ResponseWriter, request *http.Request) { - testcase, ok := testcases[path.Base(request.Header.Get("Test-Case"))] + testcase, ok := testcaseMux[request.Header.Get("Test-Case")] if !ok { responseWriter.WriteHeader(http.StatusNotFound) return } + _, _ = io.Copy(io.Discard, request.Body) + header := responseWriter.Header() + header.Set("Content-Type", "application/connect+json") testcase(responseWriter, request) }) server := httptest.NewUnstartedServer(mux) @@ -2136,50 +2107,60 @@ func TestStreamUnexpectedEOF(t *testing.T) { server.URL, connect.WithProtoJSON(), ) - t.Run("stream_unexpected_eof", func(t *testing.T) { - t.Parallel() - const upTo = 2 - request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) - request.Header().Set("Test-Case", t.Name()) - - stream, err := client.CountUp(context.Background(), request) - assert.Nil(t, err) - for stream.Receive() { - assert.Equal(t, stream.Msg().Number, 42) - } - assert.NotNil(t, stream.Err()) - assert.Equal(t, connect.CodeOf(stream.Err()), connect.CodeUnknown) - assert.True(t, errors.Is(stream.Err(), io.ErrUnexpectedEOF)) - }) - t.Run("stream_partial_payload", func(t *testing.T) { - t.Parallel() - const upTo = 2 - request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) - request.Header().Set("Test-Case", t.Name()) - stream, err := client.CountUp(context.Background(), request) - assert.Nil(t, err) - for stream.Receive() { - assert.Equal(t, stream.Msg().Number, 42) - } - assert.NotNil(t, stream.Err()) - assert.Equal(t, connect.CodeOf(stream.Err()), connect.CodeInvalidArgument) - assert.Equal(t, stream.Err().Error(), "invalid_argument: protocol error: promised 14 bytes in enveloped message, got 13 bytes") - }) - t.Run("stream_partial_frame", func(t *testing.T) { - t.Parallel() - const upTo = 2 - request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) - request.Header().Set("Test-Case", t.Name()) - stream, err := client.CountUp(context.Background(), request) - assert.Nil(t, err) - for stream.Receive() { - assert.Equal(t, stream.Msg().Number, 42) - } - assert.NotNil(t, stream.Err()) - assert.Equal(t, connect.CodeOf(stream.Err()), connect.CodeInvalidArgument) - t.Log(stream.Err()) - assert.Equal(t, stream.Err().Error(), "invalid_argument: protocol error: incomplete envelope: unexpected EOF") - }) + head := [5]byte{} + payload := []byte(`{"number": 42}`) + binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) + testcases := []struct { + name string + handler http.HandlerFunc + expectCode connect.Code + expectMsg string + }{{ + name: "stream_unexpected_eof", + handler: func(responseWriter http.ResponseWriter, request *http.Request) { + _, _ = responseWriter.Write(head[:]) + _, _ = responseWriter.Write(payload) + }, + expectCode: connect.CodeUnknown, + expectMsg: "unknown: unexpected EOF", + }, { + name: "stream_partial_payload", + handler: func(responseWriter http.ResponseWriter, request *http.Request) { + _, _ = responseWriter.Write(head[:]) + _, _ = responseWriter.Write(payload[:len(payload)-1]) + }, + expectCode: connect.CodeInvalidArgument, + expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), + }, { + name: "stream_partial_frame", + handler: func(responseWriter http.ResponseWriter, request *http.Request) { + binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) + _, _ = responseWriter.Write(head[:4]) + }, + expectCode: connect.CodeInvalidArgument, + expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + }} + for _, testcase := range testcases { + testcaseMux[t.Name()+"/"+testcase.name] = testcase.handler + } + for _, testcase := range testcases { + testcase := testcase + t.Run(testcase.name, func(t *testing.T) { + t.Parallel() + const upTo = 2 + request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) + request.Header().Set("Test-Case", t.Name()) + stream, err := client.CountUp(context.Background(), request) + assert.Nil(t, err) + for stream.Receive() { + assert.Equal(t, stream.Msg().Number, 42) + } + assert.NotNil(t, stream.Err()) + assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode) + t.Log(stream.Err()) + assert.Equal(t, stream.Err().Error(), testcase.expectMsg) + }) + } } // TestBlankImportCodeGeneration tests that services.connect.go is generated with