diff --git a/http2/transport.go b/http2/transport.go
index 39a1a46af..e0dfe9f6a 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -165,6 +165,7 @@ type ClientConn struct {
 	goAwayDebug     string                        // goAway frame's debug data, retained as a string
 	streams         map[uint32]*clientStream      // client-initiated
 	nextStreamID    uint32
+	pendingRequests int                           // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
 	pings           map[[8]byte]chan struct{}     // in flight ping data to notification channel
 	bw              *bufio.Writer
 	br              *bufio.Reader
@@ -217,35 +218,45 @@ type clientStream struct {
 	resTrailer *http.Header // client's Response.Trailer
 }
 
-// awaitRequestCancel runs in its own goroutine and waits for the user
-// to cancel a RoundTrip request, its context to expire, or for the
-// request to be done (any way it might be removed from the cc.streams
-// map: peer reset, successful completion, TCP connection breakage,
-// etc)
-func (cs *clientStream) awaitRequestCancel(req *http.Request) {
+// awaitRequestCancel waits for the user to cancel a request or for the done
+// channel to be signaled. A non-nil error is returned only if the request was
+// canceled.
+func awaitRequestCancel(req *http.Request, done <-chan struct{}) error {
 	ctx := reqContext(req)
 	if req.Cancel == nil && ctx.Done() == nil {
-		return
+		return nil
 	}
 	select {
 	case <-req.Cancel:
-		cs.cancelStream()
-		cs.bufPipe.CloseWithError(errRequestCanceled)
+		return errRequestCanceled
 	case <-ctx.Done():
+		return ctx.Err()
+	case <-done:
+		return nil
+	}
+}
+
+// awaitRequestCancel waits for the user to cancel a request, its context to
+// expire, or for the request to be done (any way it might be removed from the
+// cc.streams map: peer reset, successful completion, TCP connection breakage,
+// etc). If the request is canceled, then cs will be canceled and closed.
+func (cs *clientStream) awaitRequestCancel(req *http.Request) {
+	if err := awaitRequestCancel(req, cs.done); err != nil {
 		cs.cancelStream()
-		cs.bufPipe.CloseWithError(ctx.Err())
-	case <-cs.done:
+		cs.bufPipe.CloseWithError(err)
 	}
 }
 
 func (cs *clientStream) cancelStream() {
-	cs.cc.mu.Lock()
+	cc := cs.cc
+	cc.mu.Lock()
 	didReset := cs.didReset
 	cs.didReset = true
-	cs.cc.mu.Unlock()
+	cc.mu.Unlock()
 
 	if !didReset {
-		cs.cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+		cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+		cc.forgetStreamID(cs.ID)
 	}
 }
 
@@ -594,6 +605,8 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	}
 }
 
+// CanTakeNewRequest reports whether the connection can take a new request,
+// meaning it has not been closed or received or sent a GOAWAY.
 func (cc *ClientConn) CanTakeNewRequest() bool {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -605,8 +618,7 @@ func (cc *ClientConn) canTakeNewRequestLocked() bool {
 		return false
 	}
 	return cc.goAway == nil && !cc.closed &&
-		int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
-		cc.nextStreamID < math.MaxInt32
+		int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
 }
 
 // onIdleTimeout is called from a time.AfterFunc goroutine. It will
@@ -752,10 +764,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	hasTrailers := trailers != ""
 
 	cc.mu.Lock()
-	cc.lastActive = time.Now()
-	if cc.closed || !cc.canTakeNewRequestLocked() {
+	if err := cc.awaitOpenSlotForRequest(req); err != nil {
 		cc.mu.Unlock()
-		return nil, errClientConnUnusable
+		return nil, err
 	}
 
 	body := req.Body
@@ -869,31 +880,31 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 		case re := <-readLoopResCh:
 			return handleReadLoopResponse(re)
 		case <-respHeaderTimer:
-			cc.forgetStreamID(cs.ID)
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
+			cc.forgetStreamID(cs.ID)
 			return nil, errTimeout
 		case <-ctx.Done():
-			cc.forgetStreamID(cs.ID)
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
+			cc.forgetStreamID(cs.ID)
 			return nil, ctx.Err()
 		case <-req.Cancel:
-			cc.forgetStreamID(cs.ID)
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
+			cc.forgetStreamID(cs.ID)
 			return nil, errRequestCanceled
 		case <-cs.peerReset:
 			// processResetStream already removed the
@@ -920,6 +931,45 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 }
 
+// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams.
+// Must hold cc.mu.
+func (cc *ClientConn) awaitOpenSlotForRequest(req *http.Request) error {
+	var waitingForConn chan struct{}
+	var waitingForConnErr error // guarded by cc.mu
+	for {
+		cc.lastActive = time.Now()
+		if cc.closed || !cc.canTakeNewRequestLocked() {
+			return errClientConnUnusable
+		}
+		if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) {
+			if waitingForConn != nil {
+				close(waitingForConn)
+			}
+			return nil
+		}
+		// Unfortunately, we cannot wait on a condition variable and channel at
+		// the same time, so instead, we spin up a goroutine to check if the
+		// request is canceled while we wait for a slot to open in the connection.
+		if waitingForConn == nil {
+			waitingForConn = make(chan struct{})
+			go func() {
+				if err := awaitRequestCancel(req, waitingForConn); err != nil {
+					cc.mu.Lock()
+					waitingForConnErr = err
+					cc.cond.Broadcast()
+					cc.mu.Unlock()
+				}
+			}()
+		}
+		cc.pendingRequests++
+		cc.cond.Wait()
+		cc.pendingRequests--
+		if waitingForConnErr != nil {
+			return waitingForConnErr
+		}
+	}
+}
+
 // requires cc.wmu be held
 func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error {
 	first := true // first frame written (HEADERS is first, then CONTINUATION)
@@ -1279,7 +1329,9 @@ func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream {
 			cc.idleTimer.Reset(cc.idleTimeout)
 		}
 		close(cs.done)
-		cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
+		// Wake up checkResetOrDone via clientStream.awaitFlowControl and
+		// wake up RoundTrip if there is a pending request.
+		cc.cond.Broadcast()
 	}
 	return cs
 }
@@ -1378,8 +1430,9 @@ func (rl *clientConnReadLoop) run() error {
 			cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
 		}
 		if se, ok := err.(StreamError); ok {
-			if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
+			if cs := cc.streamByID(se.StreamID, false); cs != nil {
 				cs.cc.writeStreamReset(cs.ID, se.Code, err)
+				cs.cc.forgetStreamID(cs.ID)
 				if se.Cause == nil {
 					se.Cause = cc.fr.errDetail
 				}
@@ -1701,6 +1754,7 @@ func (b transportResponseBody) Close() error {
 	}
 
 	cs.bufPipe.BreakWithError(errClosedResponseBody)
+	cc.forgetStreamID(cs.ID)
 	return nil
 }
 
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 0e7b80103..ac4661f48 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -685,7 +685,7 @@ func newLocalListener(t *testing.T) net.Listener {
 	return ln
 }
 
-func (ct *clientTester) greet() {
+func (ct *clientTester) greet(settings ...Setting) {
 	buf := make([]byte, len(ClientPreface))
 	_, err := io.ReadFull(ct.sc, buf)
 	if err != nil {
@@ -699,7 +699,7 @@ func (ct *clientTester) greet() {
 		ct.t.Fatalf("Wanted client settings frame; got %v", f)
 		_ = sf // stash it away?
 	}
-	if err := ct.fr.WriteSettings(); err != nil {
+	if err := ct.fr.WriteSettings(settings...); err != nil {
 		ct.t.Fatal(err)
 	}
 	if err := ct.fr.WriteSettingsAck(); err != nil {
@@ -3036,6 +3036,175 @@ func TestTransportRetryHasLimit(t *testing.T) {
 	ct.run()
 }
 
+func TestTransportRequestsStallAtServerLimit(t *testing.T) {
+	const maxConcurrent = 2
+
+	greet := make(chan struct{})      // server sends initial SETTINGS frame
+	gotRequest := make(chan struct{}) // server received a request
+	clientDone := make(chan struct{})
+
+	// Collect errors from goroutines.
+	var wg sync.WaitGroup
+	errs := make(chan error, 100)
+	defer func() {
+		wg.Wait()
+		close(errs)
+		for err := range errs {
+			t.Error(err)
+		}
+	}()
+
+	// We will send maxConcurrent+2 requests. This checker goroutine waits for the
+	// following stages:
+	//   1. The first maxConcurrent requests are received by the server.
+	//   2. The client will cancel the next request.
+	//   3. The server is unblocked so it can service the first maxConcurrent requests.
+	//   4. The client will send the final request.
+	wg.Add(1)
+	unblockClient := make(chan struct{})
+	clientRequestCancelled := make(chan struct{})
+	unblockServer := make(chan struct{})
+	go func() {
+		defer wg.Done()
+		// Stage 1.
+		for k := 0; k < maxConcurrent; k++ {
+			<-gotRequest
+		}
+		// Stage 2.
+		close(unblockClient)
+		<-clientRequestCancelled
+		// Stage 3: give some time for the final RoundTrip call to be scheduled and
+		// verify that the final request is not sent.
+		time.Sleep(50 * time.Millisecond)
+		select {
+		case <-gotRequest:
+			errs <- errors.New("last request did not stall")
+			close(unblockServer)
+			return
+		default:
+		}
+		close(unblockServer)
+		// Stage 4.
+		<-gotRequest
+	}()
+
+	ct := newClientTester(t)
+	ct.client = func() error {
+		var wg sync.WaitGroup
+		defer func() {
+			wg.Wait()
+			close(clientDone)
+			ct.cc.(*net.TCPConn).CloseWrite()
+		}()
+		for k := 0; k < maxConcurrent+2; k++ {
+			wg.Add(1)
+			go func(k int) {
+				defer wg.Done()
+				// Don't send the second request until after receiving SETTINGS from the server
+				// to avoid a race where we use the default SettingMaxConcurrentStreams, which
+				// is much larger than maxConcurrent. We have to send the first request before
+				// waiting because the first request triggers the dial and greet.
+				if k > 0 {
+					<-greet
+				}
+				// Block until maxConcurrent requests are sent before sending any more.
+				if k >= maxConcurrent {
+					<-unblockClient
+				}
+				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+				if k == maxConcurrent {
+					// This request will be canceled.
+					cancel := make(chan struct{})
+					req.Cancel = cancel
+					close(cancel)
+					_, err := ct.tr.RoundTrip(req)
+					close(clientRequestCancelled)
+					if err == nil {
+						errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
+						return
+					}
+				} else {
+					resp, err := ct.tr.RoundTrip(req)
+					if err != nil {
+						errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
+						return
+					}
+					ioutil.ReadAll(resp.Body)
+					resp.Body.Close()
+					if resp.StatusCode != 204 {
+						errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+						return
+					}
+				}
+			}(k)
+		}
+		return nil
+	}
+
+	ct.server = func() error {
+		var wg sync.WaitGroup
+		defer wg.Wait()
+
+		ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
+		// Server write loop.
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		writeResp := make(chan uint32, maxConcurrent+1)
+
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			<-unblockServer
+			for id := range writeResp {
+				buf.Reset()
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      id,
+					EndHeaders:    true,
+					EndStream:     true,
+					BlockFragment: buf.Bytes(),
+				})
+			}
+		}()
+
+		// Server read loop.
+		var nreq int
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-clientDone:
+					// If the client's done, it will have reported any errors on its side.
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame:
+			case *SettingsFrame:
+				// Wait for the client SETTINGS ack until ending the greet.
+				close(greet)
+			case *HeadersFrame:
+				if !f.HeadersEnded() {
+					return fmt.Errorf("headers frame should have END_HEADERS set: %v", f)
+				}
+				gotRequest <- struct{}{}
+				nreq++
+				writeResp <- f.StreamID
+				if nreq == maxConcurrent+1 {
+					close(writeResp)
+				}
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+	}
+
+	ct.run()
+}
+
 func TestAuthorityAddr(t *testing.T) {
 	tests := []struct {
 		scheme, authority string
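
Note on the waiting pattern in awaitOpenSlotForRequest above: Go's sync.Cond cannot be combined with select, so the patch pairs cc.cond.Wait() with a lazily started helper goroutine that turns request cancellation into a cond.Broadcast(). The standalone sketch below isolates that shape outside the Transport. It is illustrative only: slotLimiter, acquire, and release are invented names for this note and are not part of this change or of x/net.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// slotLimiter is a hypothetical stand-in for ClientConn's stream accounting:
// acquire blocks until one of limit slots is free, or until ctx is canceled.
type slotLimiter struct {
	mu    sync.Mutex
	cond  *sync.Cond // signaled when a slot frees or a waiter must give up
	used  int
	limit int
}

func newSlotLimiter(limit int) *slotLimiter {
	l := &slotLimiter{limit: limit}
	l.cond = sync.NewCond(&l.mu)
	return l
}

func (l *slotLimiter) acquire(ctx context.Context) error {
	l.mu.Lock()
	defer l.mu.Unlock()
	var watch chan struct{} // closed to dismiss the watcher goroutine
	var watchErr error      // guarded by l.mu; set by the watcher
	for {
		if l.used < l.limit {
			if watch != nil {
				close(watch) // success: dismiss the watcher silently
			}
			l.used++
			return nil
		}
		// We cannot wait on a condition variable and a channel at the same
		// time, so (as in the patch) a goroutine converts ctx.Done() into a
		// Broadcast that wakes the cond.Wait below.
		if watch == nil {
			watch = make(chan struct{})
			go func() {
				select {
				case <-ctx.Done():
					l.mu.Lock()
					watchErr = ctx.Err()
					l.cond.Broadcast()
					l.mu.Unlock()
				case <-watch:
					// acquire succeeded; nothing to report
				}
			}()
		}
		l.cond.Wait()
		if watchErr != nil {
			return watchErr
		}
	}
}

func (l *slotLimiter) release() {
	l.mu.Lock()
	l.used--
	l.cond.Broadcast() // wake any acquirers blocked on a full limiter
	l.mu.Unlock()
}

func main() {
	l := newSlotLimiter(2)
	l.acquire(context.Background())
	l.acquire(context.Background()) // both slots are now in use

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	if err := l.acquire(ctx); err != nil {
		fmt.Println("third acquire blocked, then gave up:", err) // context deadline exceeded
	}
	l.release()
}

The same trade-off shows up in the patch itself: the watcher goroutine is started only once a request actually has to wait, and closing the channel on the success path dismisses it without reporting an error.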