diff --git a/http/h2_bundle.go b/http/h2_bundle.go index f695f9ce0c..d0c616859a 100644 --- a/http/h2_bundle.go +++ b/http/h2_bundle.go @@ -824,7 +824,7 @@ func (c *http2dialCall) dial(addr string) { // This code decides which ones live or die. // The return value used is whether c was used. // c is never closed. -func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn, createStreamPlz bool) (used bool, err error) { +func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn, createStream bool) (used bool, err error) { p.mu.Lock() for _, cc := range p.conns[key] { if cc.CanTakeNewRequest() { @@ -842,7 +842,7 @@ func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c * done: make(chan struct{}), } p.addConnCalls[key] = call - go call.run(t, key, c, createStreamPlz) + go call.run(t, key, c, createStream) } p.mu.Unlock() @@ -859,11 +859,11 @@ type http2addConnCall struct { err error } -func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn, createStreamPlz bool) { +func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn, createStream bool) { var cc *http2ClientConn var err error - if createStreamPlz { + if createStream { cc, err = t.NewClientConnWithStream(tc) } else { cc, err = t.NewClientConn(tc) @@ -6680,13 +6680,13 @@ func http2configureTransport(t1 *Transport) (*http2Transport, error) { return t2 } + //TODO(gerg): should this live in the h2c package? fmt.Println("Registering proto callbacks!") h2cUpgradeFn := func(authority string, c *tls.Conn) UpgradableRoundTripper { addr := http2authorityAddr("https", authority) if used, err := connPool.addConnIfNeeded(addr, t2, c, true); err != nil { go c.Close() - //TODO(gerg): return a proper error here instead of nil. http2erringUpgradableRoundTripper? - return nil //http2erringRoundTripper{err} + return http2erringRoundTripper{err} } else if !used { // Turns out we don't need this c. // For example, two goroutines made requests to the same host @@ -7094,8 +7094,8 @@ func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2Clie if err != nil { return nil, err } - // TODO(gerg): Just added false here to make it compile. Is that right? - return t.newClientConn(tconn, singleUse, false) + doNotCreateStream := false + return t.newClientConn(tconn, singleUse, doNotCreateStream) } func (t *http2Transport) newTLSConfig(host string) *tls.Config { @@ -7164,7 +7164,7 @@ func (t *http2Transport) NewClientConnWithStream(c net.Conn) (*http2ClientConn, return t.newClientConn(c, false, true) } -func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, createStreamPlz bool) (*http2ClientConn, error) { +func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, createStream bool) (*http2ClientConn, error) { cc := &http2ClientConn{ t: t, tconn: c, @@ -7202,11 +7202,10 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, createStreamP // henc in response to SETTINGS frames? cc.henc = hpack.NewEncoder(&cc.hbuf) - if createStreamPlz { + if createStream { cc.newStream() } - //TODO(greg): Do we need to trigger this to get the stream IDs correct? The stream that is - // created as part of the h2c upgrade should be stream ID 1 + if t.AllowHTTP { cc.nextStreamID = 3 } @@ -7524,21 +7523,14 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } func (cc *http2ClientConn) completeUpgrade(req *Request) (res *Response, err error) { - var respHeaderTimer <-chan time.Time - - // TODO(gerg): Try and use a shared helper function with cc.roundTrip to prevent - // duplicate logic cs := cc.streams[1] cs.req = req cs.trace = httptrace.ContextClientTrace(req.Context()) - //TODO(gerg): Is it okay that we have any empty one of these? - bodyWriter := http2bodyWriterState{} - - //TODO(gerg): Is it okay that we hard-code this to false? + bodyWriter := cc.t.getBodyWriterState(cs, req.Body) + bodyWriter.written = true hasBody := false - bodyWritten := true - resp, _, err := cc.returnTrip(req, cs, bodyWriter, hasBody, respHeaderTimer, bodyWritten) + resp, _, err := cc.returnTrip(req, cs, bodyWriter, hasBody) return resp, err } @@ -7622,8 +7614,14 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe return nil, false, werr } + return cc.returnTrip(req, cs, bodyWriter, hasBody) +} + +// TODO(gerg): Method signature, use new struct? +func (cc *http2ClientConn) returnTrip(req *Request, cs *http2clientStream, bodyWriter http2bodyWriterState, hasBody bool) (res *Response, gotErrAfterReqBodyWrite bool, err error) { + var respHeaderTimer <-chan time.Time - if hasBody { + if hasBody && !bodyWriter.written { bodyWriter.scheduleBodyWrite() } else { http2traceWroteRequest(cs.trace, nil) @@ -7633,10 +7631,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe respHeaderTimer = timer.C } } - return cc.returnTrip(req, cs, bodyWriter, hasBody, respHeaderTimer, false) -} -func (cc *http2ClientConn) returnTrip(req *Request, cs *http2clientStream, bodyWriter http2bodyWriterState, hasBody bool, respHeaderTimer <-chan time.Time, bodyWritten bool) (res *Response, gotErrAfterReqBodyWrite bool, err error) { readLoopResCh := cs.resc ctx := req.Context() @@ -7670,7 +7665,7 @@ func (cc *http2ClientConn) returnTrip(req *Request, cs *http2clientStream, bodyW return handleReadLoopResponse(re) // TODO(gerg): Understand what publishes to this channel case <-respHeaderTimer: - if !hasBody || bodyWritten { + if !hasBody || bodyWriter.written { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { bodyWriter.cancel() @@ -7679,7 +7674,7 @@ func (cc *http2ClientConn) returnTrip(req *Request, cs *http2clientStream, bodyW cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), http2errTimeout case <-ctx.Done(): - if !hasBody || bodyWritten { + if !hasBody || bodyWriter.written { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { bodyWriter.cancel() @@ -7688,7 +7683,7 @@ func (cc *http2ClientConn) returnTrip(req *Request, cs *http2clientStream, bodyW cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), ctx.Err() case <-req.Cancel: - if !hasBody || bodyWritten { + if !hasBody || bodyWriter.written { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { bodyWriter.cancel() @@ -7712,7 +7707,7 @@ func (cc *http2ClientConn) returnTrip(req *Request, cs *http2clientStream, bodyW cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), err } - bodyWritten = true + bodyWriter.written = true if d := cc.responseHeaderTimeout(); d != 0 { timer := time.NewTimer(d) defer timer.Stop() @@ -8989,7 +8984,8 @@ func http2strSliceContains(ss []string, s string) bool { type http2erringRoundTripper struct{ err error } -func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err } +func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err } +func (rt http2erringRoundTripper) CompleteUpgrade(*Request) (*Response, error) { return nil, rt.err } // gzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read @@ -9025,12 +9021,13 @@ func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } // of the request body, particularly regarding doing delayed writes of the body // when the request contains "Expect: 100-continue". type http2bodyWriterState struct { - cs *http2clientStream - timer *time.Timer // if non-nil, we're doing a delayed write - fnonce *sync.Once // to call fn with - fn func() // the code to run in the goroutine, writing the body - resc chan error // result of fn's execution - delay time.Duration // how long we should delay a delayed write for + cs *http2clientStream + timer *time.Timer // if non-nil, we're doing a delayed write + fnonce *sync.Once // to call fn with + fn func() // the code to run in the goroutine, writing the body + resc chan error // result of fn's execution + delay time.Duration // how long we should delay a delayed write for + written bool // if the body has completed writing } func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) { @@ -9038,6 +9035,7 @@ func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reade if body == nil { return } + s.written = false resc := make(chan error, 1) s.resc = resc s.fn = func() {