From 541dbe58b6bc869fc1c7de361846682a34365325 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 9 May 2024 10:24:28 -0700 Subject: [PATCH] http2: add Server.WriteByteTimeout Transports support a WriteByteTimeout option which sets the maximum amount of time we can go without being able to write any bytes to a connection. Add an equivalent option to Server for consistency. Fixes golang/go#61777 Change-Id: Iaa8a69dfc403906eb224829320f901e5a6a5c429 Reviewed-on: https://go-review.googlesource.com/c/net/+/601496 LUCI-TryBot-Result: Go LUCI Reviewed-by: Carlos Amedee Reviewed-by: Brad Fitzpatrick --- http2/connframes_test.go | 5 ++-- http2/http2.go | 53 ++++++++++++++++++++++++++++++++++------ http2/server.go | 12 ++++++++- http2/server_test.go | 32 ++++++++++++++++++++++++ http2/transport.go | 24 ++++++------------ 5 files changed, 98 insertions(+), 28 deletions(-) diff --git a/http2/connframes_test.go b/http2/connframes_test.go index 7db8b74e2..1f7834c01 100644 --- a/http2/connframes_test.go +++ b/http2/connframes_test.go @@ -6,7 +6,6 @@ package http2 import ( "bytes" - "context" "io" "net/http" "os" @@ -295,7 +294,7 @@ func (tf *testConnFramer) wantClosed() { if err == nil { tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr) } - if err == context.DeadlineExceeded { + if err == os.ErrDeadlineExceeded { tf.t.Fatalf("connection is not closed; want it to be") } } @@ -306,7 +305,7 @@ func (tf *testConnFramer) wantIdle() { if err == nil { tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr) } - if err != context.DeadlineExceeded { + if err != os.ErrDeadlineExceeded { tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err) } } diff --git a/http2/http2.go b/http2/http2.go index 003e649f3..7688c356b 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -19,8 +19,9 @@ import ( "bufio" "context" "crypto/tls" + "errors" "fmt" - "io" + "net" "net/http" "os" "sort" @@ -237,13 +238,19 @@ func (cw closeWaiter) Wait() { // Its buffered writer is lazily allocated as needed, to minimize // idle memory usage with many connections. type bufferedWriter struct { - _ incomparable - w io.Writer // immutable - bw *bufio.Writer // non-nil when data is buffered + _ incomparable + group synctestGroupInterface // immutable + conn net.Conn // immutable + bw *bufio.Writer // non-nil when data is buffered + byteTimeout time.Duration // immutable, WriteByteTimeout } -func newBufferedWriter(w io.Writer) *bufferedWriter { - return &bufferedWriter{w: w} +func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter { + return &bufferedWriter{ + group: group, + conn: conn, + byteTimeout: timeout, + } } // bufWriterPoolBufferSize is the size of bufio.Writer's @@ -270,7 +277,7 @@ func (w *bufferedWriter) Available() int { func (w *bufferedWriter) Write(p []byte) (n int, err error) { if w.bw == nil { bw := bufWriterPool.Get().(*bufio.Writer) - bw.Reset(w.w) + bw.Reset((*bufferedWriterTimeoutWriter)(w)) w.bw = bw } return w.bw.Write(p) @@ -288,6 +295,38 @@ func (w *bufferedWriter) Flush() error { return err } +type bufferedWriterTimeoutWriter bufferedWriter + +func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) { + return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p) +} + +// writeWithByteTimeout writes to conn. +// If more than timeout passes without any bytes being written to the connection, +// the write fails. +func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) { + if timeout <= 0 { + return conn.Write(p) + } + for { + var now time.Time + if group == nil { + now = time.Now() + } else { + now = group.Now() + } + conn.SetWriteDeadline(now.Add(timeout)) + nn, err := conn.Write(p[n:]) + n += nn + if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) { + // Either we finished the write, made no progress, or hit the deadline. + // Whichever it is, we're done now. + conn.SetWriteDeadline(time.Time{}) + return n, err + } + } +} + func mustUint31(v int32) uint32 { if v < 0 || v > 2147483647 { panic("out of range") diff --git a/http2/server.go b/http2/server.go index 6c349f3ec..b16173c57 100644 --- a/http2/server.go +++ b/http2/server.go @@ -127,6 +127,12 @@ type Server struct { // If zero or negative, there is no timeout. IdleTimeout time.Duration + // WriteByteTimeout is the timeout after which a connection will be + // closed if no data can be written to it. The timeout begins when data is + // available to write, and is extended whenever any bytes are written. + // If zero or negative, there is no timeout. + WriteByteTimeout time.Duration + // MaxUploadBufferPerConnection is the size of the initial flow // control window for each connections. The HTTP/2 spec does not // allow this to be smaller than 65535 or larger than 2^32-1. @@ -446,7 +452,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon conn: c, baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), - bw: newBufferedWriter(c), + bw: newBufferedWriter(s.group, c, s.WriteByteTimeout), handler: opts.handler(), streams: make(map[uint32]*stream), readFrameCh: make(chan readFrameResult), @@ -1320,6 +1326,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) { sc.writingFrame = false sc.writingFrameAsync = false + if res.err != nil { + sc.conn.Close() + } + wr := res.wr if writeEndsStream(wr.write) { diff --git a/http2/server_test.go b/http2/server_test.go index 47c3c619c..ab53c2694 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4674,3 +4674,35 @@ func TestServerSetReadWriteDeadlineRace(t *testing.T) { } resp.Body.Close() } + +func TestServerWriteByteTimeout(t *testing.T) { + const timeout = 1 * time.Second + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.Write(make([]byte, 100)) + }, func(s *Server) { + s.WriteByteTimeout = timeout + }) + st.greet() + + st.cc.(*synctestNetConn).SetReadBufferSize(1) // write one byte at a time + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + + // Read a few bytes, staying just under WriteByteTimeout. + for i := 0; i < 10; i++ { + st.advance(timeout - 1) + if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil { + t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err) + } + } + + // Wait for WriteByteTimeout. + // The connection should close. + st.advance(1 * time.Second) // timeout after writing one byte + st.advance(1 * time.Second) // timeout after failing to write any more bytes + st.wantClosed() +} diff --git a/http2/transport.go b/http2/transport.go index 61f511f97..49fc792ad 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -25,7 +25,6 @@ import ( "net/http" "net/http/httptrace" "net/textproto" - "os" "sort" "strconv" "strings" @@ -499,6 +498,7 @@ func (cs *clientStream) closeReqBodyLocked() { } type stickyErrWriter struct { + group synctestGroupInterface conn net.Conn timeout time.Duration err *error @@ -508,22 +508,9 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) { if *sew.err != nil { return 0, *sew.err } - for { - if sew.timeout != 0 { - sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout)) - } - nn, err := sew.conn.Write(p[n:]) - n += nn - if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) { - // Keep extending the deadline so long as we're making progress. - continue - } - if sew.timeout != 0 { - sew.conn.SetWriteDeadline(time.Time{}) - } - *sew.err = err - return n, err - } + n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p) + *sew.err = err + return n, err } // noCachedConnError is the concrete type of ErrNoCachedConn, which @@ -792,10 +779,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), } + var group synctestGroupInterface if t.transportTestHooks != nil { t.markNewGoroutine() t.transportTestHooks.newclientconn(cc) c = cc.tconn + group = t.group } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -807,6 +796,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. cc.bw = bufio.NewWriter(stickyErrWriter{ + group: group, conn: c, timeout: t.WriteByteTimeout, err: &cc.werr,