diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 1a9ce5ab40b71f..3a601d304b99ff 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -14,6 +14,7 @@ import ( "io" "io/ioutil" "log" + "net" . "net/http" "net/http/httptest" "net/url" @@ -22,7 +23,9 @@ import ( "sort" "strings" "sync" + "sync/atomic" "testing" + "time" ) type clientServerTest struct { @@ -861,3 +864,93 @@ func testStarRequest(t *testing.T, method string, h2 bool) { t.Errorf("RequestURI = %q; want *", req.RequestURI) } } + +// Issue 13957 +func TestTransportDiscardsUnneededConns(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) + })) + defer cst.close() + + var numOpen, numClose int32 // atomic + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + tr := &Transport{ + TLSClientConfig: tlsConfig, + DialTLS: func(_, addr string) (net.Conn, error) { + time.Sleep(10 * time.Millisecond) + rc, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + atomic.AddInt32(&numOpen, 1) + c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} + return tls.Client(c, tlsConfig), nil + }, + } + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatal(err) + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + + const N = 10 + gotBody := make(chan string, N) + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := c.Get(cst.ts.URL) + if err != nil { + t.Errorf("Get: %v", err) + return + } + defer resp.Body.Close() + slurp, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + gotBody <- string(slurp) + }() + } + wg.Wait() + close(gotBody) + + var last string + for got := range gotBody { + if last == "" { + last = got + continue + } + if got != last { + t.Errorf("Response body changed: %q -> %q", last, got) + } + } + + var open, close int32 + for i := 0; i < 150; i++ { + open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) + if open < 1 { + t.Fatalf("open = %d; want at least", open) + } + if close == open-1 { + // Success + return + } + time.Sleep(10 * time.Millisecond) + } + t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) +} + +type noteCloseConn struct { + net.Conn + closeFunc func() +} + +func (x noteCloseConn) Close() error { + x.closeFunc() + return x.Conn.Close() +} diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 7aae52eabcb6df..e743737f54caef 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -24,6 +24,7 @@ import ( "encoding/binary" "errors" "fmt" + "golang.org/x/net/http2/hpack" "io" "io/ioutil" "log" @@ -37,8 +38,6 @@ import ( "strings" "sync" "time" - - "golang.org/x/net/http2/hpack" ) // ClientConnPool manages a pool of HTTP/2 client connections. @@ -47,21 +46,29 @@ type http2ClientConnPool interface { MarkDead(*http2ClientConn) } +// TODO: use singleflight for dialing and addConnCalls? type http2clientConnPool struct { - t *http2Transport + t *http2Transport + mu sync.Mutex // TODO: maybe switch to RWMutex // TODO: add support for sharing conns based on cert names // (e.g. share conn for googleapis.com and appspot.com) - conns map[string][]*http2ClientConn // key is host:port - dialing map[string]*http2dialCall // currently in-flight dials - keys map[*http2ClientConn][]string + conns map[string][]*http2ClientConn // key is host:port + dialing map[string]*http2dialCall // currently in-flight dials + keys map[*http2ClientConn][]string + addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeede calls } func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { - return p.getClientConn(req, addr, true) + return p.getClientConn(req, addr, http2dialOnMiss) } -func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { +const ( + http2dialOnMiss = true + http2noDialOnMiss = false +) + +func (p *http2clientConnPool) getClientConn(_ *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { p.mu.Lock() for _, cc := range p.conns[addr] { if cc.CanTakeNewRequest() { @@ -115,6 +122,64 @@ func (c *http2dialCall) dial(addr string) { c.p.mu.Unlock() } +// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't +// already exist. It coalesces concurrent calls with the same key. +// This is used by the http1 Transport code when it creates a new connection. Because +// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know +// the protocol), it can get into a situation where it has multiple TLS connections. +// This code decides which ones live or die. +// The return value used is whether c was used. +// c is never closed. +func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) { + p.mu.Lock() + for _, cc := range p.conns[key] { + if cc.CanTakeNewRequest() { + p.mu.Unlock() + return false, nil + } + } + call, dup := p.addConnCalls[key] + if !dup { + if p.addConnCalls == nil { + p.addConnCalls = make(map[string]*http2addConnCall) + } + call = &http2addConnCall{ + p: p, + done: make(chan struct{}), + } + p.addConnCalls[key] = call + go call.run(t, key, c) + } + p.mu.Unlock() + + <-call.done + if call.err != nil { + return false, call.err + } + return !dup, nil +} + +type http2addConnCall struct { + p *http2clientConnPool + done chan struct{} // closed when done + err error +} + +func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { + cc, err := t.NewClientConn(tc) + + p := c.p + p.mu.Lock() + if err != nil { + c.err = err + } else { + p.addConnLocked(key, cc) + } + delete(p.addConnCalls, key) + p.mu.Unlock() + close(c.done) +} + func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) { p.mu.Lock() p.addConnLocked(key, cc) @@ -197,12 +262,14 @@ func http2configureTransport(t1 *Transport) error { t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") } upgradeFn := func(authority string, c *tls.Conn) RoundTripper { - cc, err := t2.NewClientConn(c) - if err != nil { - c.Close() + addr := http2authorityAddr(authority) + if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { + go c.Close() return http2erringRoundTripper{err} + } else if !used { + + go c.Close() } - connPool.addConn(http2authorityAddr(authority), cc) return t2 } if m := t1.TLSNextProto; len(m) == 0 { @@ -233,8 +300,7 @@ func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) { type http2noDialClientConnPool struct{ *http2clientConnPool } func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { - const doDial = false - return p.getClientConn(req, addr, doDial) + return p.getClientConn(req, addr, http2noDialOnMiss) } // noDialH2RoundTripper is a RoundTripper which only tries to complete the request diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 9378b8385e9d46..9bef9026d4c6aa 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -713,6 +713,12 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)") } if tc, ok := pconn.conn.(*tls.Conn); ok { + // Handshake here, in case DialTLS didn't. TLSNextProto below + // depends on it for knowing the connection state. + if err := tc.Handshake(); err != nil { + go pconn.conn.Close() + return nil, err + } cs := tc.ConnectionState() pconn.tlsState = &cs }