diff --git a/test/end2end_test.go b/test/end2end_test.go
index 88c3626c6a8d..28710b3bba56 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -3188,6 +3188,117 @@ func TestServerCredsDispatch(t *testing.T) {
 	}
 }
 
+func TestFlowControlLogicalRace(t *testing.T) {
+	// Test for a regression of https://github.com/grpc/grpc-go/issues/632,
+	// and other flow control bugs.
+
+	defer leakCheck(t)()
+
+	const (
+		itemCount   = 100
+		itemSize    = 1 << 10
+		recvCount   = 2
+		maxFailures = 3
+
+		requestTimeout = time.Second
+	)
+
+	requestCount := 10000
+	if raceMode {
+		requestCount = 1000
+	}
+
+	lis, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatalf("Failed to listen: %v", err)
+	}
+	defer lis.Close()
+
+	s := grpc.NewServer()
+	testpb.RegisterTestServiceServer(s, &issue632server{
+		itemCount: itemCount,
+		itemSize:  itemSize,
+	})
+	defer s.Stop()
+
+	go s.Serve(lis)
+
+	ctx := context.Background()
+
+	cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock())
+	if err != nil {
+		t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
+	}
+	defer cc.Close()
+	cl := testpb.NewTestServiceClient(cc)
+
+	failures := 0
+	for i := 0; i < requestCount; i++ {
+		ctx, cancel := context.WithTimeout(ctx, requestTimeout)
+		output, err := cl.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{})
+		if err != nil {
+			t.Fatalf("StreamingOutputCall; err = %q", err)
+		}
+
+		j := 0
+	loop:
+		for ; j < recvCount; j++ {
+			_, err := output.Recv()
+			if err != nil {
+				if err == io.EOF {
+					break loop
+				}
+				switch grpc.Code(err) {
+				case codes.DeadlineExceeded:
+					break loop
+				default:
+					t.Fatalf("Recv; err = %q", err)
+				}
+			}
+		}
+		cancel()
+		<-ctx.Done()
+
+		if j < recvCount {
+			t.Errorf("got %d responses to request %d", j, i)
+			failures++
+			if failures >= maxFailures {
+				// Continue past the first failure to see if the connection is
+				// entirely broken, or if only a single RPC was affected.
+				break
+			}
+		}
+	}
+}
+
+type issue632server struct {
+	testpb.TestServiceServer
+
+	itemSize  int
+	itemCount int
+}
+
+func (s *issue632server) StreamingOutputCall(req *testpb.StreamingOutputCallRequest, srv testpb.TestService_StreamingOutputCallServer) error {
+	for i := 0; i < s.itemCount; i++ {
+		err := srv.Send(&testpb.StreamingOutputCallResponse{
+			Payload: &testpb.Payload{
+				// Sending a large stream of data that the client rejects
+				// helps to trigger some types of flow control bugs.
+				//
+				// Reallocating memory here is inefficient, but the stress it
+				// puts on the GC leads to more frequent flow control
+				// failures. The GC likely causes more variety in the
+				// goroutine scheduling orders.
+				Body: bytes.Repeat([]byte("a"), s.itemSize),
+			},
+		})
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // interestingGoroutines returns all goroutines we care about for the purpose
 // of leak checking. It excludes testing or runtime ones.
 func interestingGoroutines() (gs []string) {
@@ -3208,6 +3319,7 @@ func interestingGoroutines() (gs []string) {
 			strings.Contains(stack, "testing.tRunner(") ||
 			strings.Contains(stack, "runtime.goexit") ||
 			strings.Contains(stack, "created by runtime.gc") ||
+			strings.Contains(stack, "created by runtime/trace.Start") ||
 			strings.Contains(stack, "created by google3/base/go/log.init") ||
 			strings.Contains(stack, "interestingGoroutines") ||
 			strings.Contains(stack, "runtime.MHeap_Scavenger") ||
diff --git a/transport/control.go b/transport/control.go
index 4ef0830b56ca..2586cba469c6 100644
--- a/transport/control.go
+++ b/transport/control.go
@@ -111,35 +111,9 @@ func newQuotaPool(q int) *quotaPool {
 	return qb
 }
 
-// add adds n to the available quota and tries to send it on acquire.
-func (qb *quotaPool) add(n int) {
-	qb.mu.Lock()
-	defer qb.mu.Unlock()
-	qb.quota += n
-	if qb.quota <= 0 {
-		return
-	}
-	select {
-	case qb.c <- qb.quota:
-		qb.quota = 0
-	default:
-	}
-}
-
-// cancel cancels the pending quota sent on acquire, if any.
-func (qb *quotaPool) cancel() {
-	qb.mu.Lock()
-	defer qb.mu.Unlock()
-	select {
-	case n := <-qb.c:
-		qb.quota += n
-	default:
-	}
-}
-
-// reset cancels the pending quota sent on acquired, incremented by v and sends
+// add cancels the pending quota sent on acquire, if any, adds v, and sends
 // it back on acquire.
-func (qb *quotaPool) reset(v int) {
+func (qb *quotaPool) add(v int) {
 	qb.mu.Lock()
 	defer qb.mu.Unlock()
 	select {
@@ -151,6 +125,10 @@ func (qb *quotaPool) reset(v int) {
 	if qb.quota <= 0 {
 		return
 	}
+	// After the pool has been created, this is the only place that sends on
+	// the channel. Since mu is held at this point and any quota that was sent
+	// on the channel has been retrieved, we know that this code will always
+	// place any positive quota value on the channel.
 	select {
 	case qb.c <- qb.quota:
 		qb.quota = 0
diff --git a/transport/http2_client.go b/transport/http2_client.go
index cbd9f3260263..91f5bd892f0f 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -367,7 +367,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 	}
 	t.mu.Unlock()
 	if reset {
-		t.streamsQuota.reset(-1)
+		t.streamsQuota.add(-1)
 	}
 
 	// HPACK encodes various headers. Note that once WriteField(...) is
@@ -604,19 +604,14 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
 		var p []byte
 		if r.Len() > 0 {
 			size := http2MaxFrameLen
-			s.sendQuotaPool.add(0)
 			// Wait until the stream has some quota to send the data.
 			sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
 			if err != nil {
 				return err
 			}
-			t.sendQuotaPool.add(0)
 			// Wait until the transport has some quota to send the data.
 			tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
 			if err != nil {
-				if _, ok := err.(StreamError); ok || err == io.EOF {
-					t.sendQuotaPool.cancel()
-				}
 				return err
 			}
 			if sq < size {
@@ -1035,13 +1030,13 @@ func (t *http2Client) applySettings(ss []http2.Setting) {
 			t.maxStreams = int(s.Val)
 			t.mu.Unlock()
 			if reset {
-				t.streamsQuota.reset(int(s.Val) - ms)
+				t.streamsQuota.add(int(s.Val) - ms)
 			}
 		case http2.SettingInitialWindowSize:
 			t.mu.Lock()
 			for _, stream := range t.activeStreams {
 				// Adjust the sending quota for each stream.
-				stream.sendQuotaPool.reset(int(s.Val - t.streamSendQuota))
+				stream.sendQuotaPool.add(int(s.Val - t.streamSendQuota))
 			}
 			t.streamSendQuota = s.Val
 			t.mu.Unlock()
diff --git a/transport/http2_server.go b/transport/http2_server.go
index db9beb90a658..7badff84f87f 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -626,19 +626,14 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
 			return nil
 		}
 		size := http2MaxFrameLen
-		s.sendQuotaPool.add(0)
 		// Wait until the stream has some quota to send the data.
 		sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
 		if err != nil {
 			return err
 		}
-		t.sendQuotaPool.add(0)
 		// Wait until the transport has some quota to send the data.
 		tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
 		if err != nil {
-			if _, ok := err.(StreamError); ok {
-				t.sendQuotaPool.cancel()
-			}
 			return err
 		}
 		if sq < size {
@@ -706,7 +701,7 @@ func (t *http2Server) applySettings(ss []http2.Setting) {
 			t.mu.Lock()
 			defer t.mu.Unlock()
 			for _, stream := range t.activeStreams {
-				stream.sendQuotaPool.reset(int(s.Val - t.streamSendQuota))
+				stream.sendQuotaPool.add(int(s.Val - t.streamSendQuota))
 			}
 			t.streamSendQuota = s.Val
 		}
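
For context on the transport/control.go hunks above: after this patch, quotaPool exposes a single mutating method, add, which atomically performs what the old add, cancel, and reset did separately. The following is a minimal, self-contained sketch of the post-patch type. The add and acquire methods mirror the diff; newQuotaPool and the demo main are reconstructions from the surrounding context, not the library's code verbatim.

// quota_pool_sketch.go: a standalone illustration of the post-patch
// quotaPool. add and acquire follow the hunks above; newQuotaPool and
// main are assumptions included only to make the example runnable.
package main

import (
	"fmt"
	"sync"
)

type quotaPool struct {
	c chan int // carries the currently available quota; capacity 1

	mu    sync.Mutex
	quota int // quota not yet published on c (may be negative)
}

// newQuotaPool is assumed from the patch's surrounding context: positive
// initial quota is published on the channel immediately.
func newQuotaPool(q int) *quotaPool {
	qb := &quotaPool{c: make(chan int, 1)}
	if q > 0 {
		qb.c <- q
	} else {
		qb.quota = q
	}
	return qb
}

// add folds the old add/cancel/reset trio into one atomic operation:
// drain any quota pending on the channel, fold it and v into qb.quota,
// and re-publish the total if it is positive. Since mu is held and the
// channel was just drained, the send can never be silently dropped.
func (qb *quotaPool) add(v int) {
	qb.mu.Lock()
	defer qb.mu.Unlock()
	select {
	case n := <-qb.c:
		qb.quota += n
	default:
	}
	qb.quota += v
	if qb.quota <= 0 {
		return
	}
	select {
	case qb.c <- qb.quota:
		qb.quota = 0
	default:
	}
}

// acquire returns the channel a writer selects on to take all quota.
func (qb *quotaPool) acquire() <-chan int {
	return qb.c
}

func main() {
	qp := newQuotaPool(10)

	q := <-qp.acquire() // take everything that is available
	fmt.Println("acquired:", q)

	sent := 4        // pretend we only wrote 4 bytes
	qp.add(q - sent) // hand the unused 6 back in one atomic step
	fmt.Println("available again:", <-qp.acquire())
}

A writer that acquires quota but uses only part of it returns the surplus with add(acquired - used). Because the drain and the re-publish happen under one mutex, quota can no longer be lost or double-counted between a cancel and a later add, which is also why the add(0) priming calls and the cancel-on-error paths in both Write implementations could be deleted.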