Skip to content

Commit

Permalink
Fix flaky TestStreamForServer test (#627)
Browse files Browse the repository at this point in the history
This test was flaky due to unexpected cases where calling `Send(nil)` to
send request headers could return a non-nil error, even though the server
did not actually send back an error. If the handler immediately returned
without accepting any messages, the end-of-stream could arrive quickly
enough that the call to `Send` would notice and return `io.EOF`. But since
the call to `Send` was intended to initiate the request and the server did
not actually return any error, this is confusing to observe on the client. So
now a call to `Send(nil)` that initiates the call will return a nil error. And
any error (including EOF) already received from the server will be returned
by the next call to `Send` or `Receive`.
  • Loading branch information
emcfarlane authored Nov 10, 2023
1 parent 67dceff commit 1e3c4e7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
23 changes: 13 additions & 10 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,8 @@ func TestBidiStreamServerSendsFirstMessage(t *testing.T) {

func TestStreamForServer(t *testing.T) {
t.Parallel()
newPingClient := func(pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient {
newPingClient := func(t *testing.T, pingServer pingv1connect.PingServiceHandler) pingv1connect.PingServiceClient {
t.Helper()
mux := http.NewServeMux()
mux.Handle(pingv1connect.NewPingServiceHandler(pingServer))
server := memhttptest.NewServer(t, mux)
Expand All @@ -1523,7 +1524,7 @@ func TestStreamForServer(t *testing.T) {
}
t.Run("not-proto-message", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error {
return stream.Conn().Send("foobar")
},
Expand All @@ -1537,7 +1538,7 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("nil-message", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error {
return stream.Send(nil)
},
Expand All @@ -1551,7 +1552,7 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("get-spec", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
cumSum: func(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error {
assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeBidi)
assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceCumSumProcedure)
Expand All @@ -1565,7 +1566,7 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("server-stream", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer)
assert.Equal(t, stream.Conn().Spec().Procedure, pingv1connect.PingServiceCountUpProcedure)
Expand All @@ -1581,7 +1582,7 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("server-stream-send", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1}))
return nil
Expand All @@ -1597,7 +1598,7 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("server-stream-send-nil", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error {
stream.ResponseHeader().Set("foo", "bar")
stream.ResponseTrailer().Set("bas", "blah")
Expand All @@ -1618,7 +1619,7 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("client-stream", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient)
assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure)
Expand All @@ -1639,8 +1640,9 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("client-stream-conn", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
assert.True(t, stream.Receive())
assert.NotNil(t, stream.Conn().Send("not-proto"))
return connect.NewResponse(&pingv1.SumResponse{}), nil
},
Expand All @@ -1653,8 +1655,9 @@ func TestStreamForServer(t *testing.T) {
})
t.Run("client-stream-send-msg", func(t *testing.T) {
t.Parallel()
client := newPingClient(&pluggablePingServer{
client := newPingClient(t, &pluggablePingServer{
sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) {
assert.True(t, stream.Receive())
assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2}))
return connect.NewResponse(&pingv1.SumResponse{}), nil
},
Expand Down
23 changes: 15 additions & 8 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"io"
"net/http"
"net/url"
"sync"
"sync/atomic"
)

// duplexHTTPCall is a full-duplex stream between the client and server. The
Expand All @@ -42,9 +42,9 @@ type duplexHTTPCall struct {
requestBodyReader *io.PipeReader
requestBodyWriter *io.PipeWriter

// sendRequestOnce ensures we only send the request once.
sendRequestOnce sync.Once
request *http.Request
// requestSent ensures we only send the request once.
requestSent atomic.Bool
request *http.Request

// responseReady is closed when the response is ready or when the request
// fails. Any error on request initialisation will be set on the
Expand Down Expand Up @@ -96,11 +96,16 @@ func newDuplexHTTPCall(

// Write to the request body.
func (d *duplexHTTPCall) Write(data []byte) (int, error) {
d.ensureRequestMade()
isFirst := d.ensureRequestMade()
// Before we send any data, check if the context has been canceled.
if err := d.ctx.Err(); err != nil {
return 0, wrapIfContextError(err)
}
if isFirst && data == nil {
// On first write a nil Send is used to send request headers. Avoid
// writing a zero-length payload to avoid superfluous errors with close.
return 0, nil
}
// It's safe to write to this side of the pipe while net/http concurrently
// reads from the other side.
bytesWritten, err := d.requestBodyWriter.Write(data)
Expand Down Expand Up @@ -229,10 +234,12 @@ func (d *duplexHTTPCall) BlockUntilResponseReady() error {
// ensureRequestMade sends the request headers and starts the response stream.
// It is not safe to call this concurrently. Write and CloseWrite call this but
// ensure that they're not called concurrently.
func (d *duplexHTTPCall) ensureRequestMade() {
d.sendRequestOnce.Do(func() {
func (d *duplexHTTPCall) ensureRequestMade() (isFirst bool) {
if d.requestSent.CompareAndSwap(false, true) {
go d.makeRequest()
})
return true
}
return false
}

func (d *duplexHTTPCall) makeRequest() {
Expand Down

0 comments on commit 1e3c4e7

Please sign in to comment.