diff --git a/close.go b/close.go index 6c385d8b..c83f57c6 100644 --- a/close.go +++ b/close.go @@ -98,82 +98,106 @@ func CloseStatus(err error) StatusCode { // // Close will unblock all goroutines interacting with the connection once // complete. -func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() - return c.closeHandshake(code, reason) +func (c *Conn) Close(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } + return net.ErrClosed + } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() + + err = c.closeHandshake(code, reason) + + err2 := c.close() + if err == nil && err2 != nil { + err = err2 + } + + err2 = c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 + } + + return err } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. func (c *Conn) CloseNow() (err error) { - defer c.wg.Wait() defer errd.Wrap(&err, "failed to close WebSocket") - if c.isClosed() { + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } return net.ErrClosed } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() - c.close(nil) - c.closeMu.Lock() - defer c.closeMu.Unlock() - return c.closeErr -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() + err = c.close() - if writeErr != nil { - return writeErr + err2 := c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 } + return err +} - if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { - return closeHandshakeErr +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.writeClose(code, reason) + if err != nil { + return err } + err = c.waitCloseHandshake() + if CloseStatus(err) != code { + return err + } return nil } func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return net.ErrClosed - } - ce := CloseError{ Code: code, Reason: reason, } var p []byte - var marshalErr error + var err error if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil + p, err = ce.bytes() + if err != nil { + return err + } } - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - if marshalErr != nil { - return marshalErr + err = c.writeControl(ctx, opClose, p) + // If the connection closed as we're writing we ignore the error as we might + // have written the close frame, the peer responded and then someone else read it + // and closed the connection. + if err != nil && !errors.Is(err, net.ErrClosed) { + return err } - return writeErr + return nil } func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -209,6 +233,36 @@ func (c *Conn) waitCloseHandshake() error { } } +func (c *Conn) waitGoroutines() error { + t := time.NewTimer(time.Second * 15) + defer t.Stop() + + select { + case <-c.timeoutLoopDone: + case <-t.C: + return errors.New("failed to wait for timeoutLoop goroutine to exit") + } + + c.closeReadMu.Lock() + ctx := c.closeReadCtx + c.closeReadMu.Unlock() + if ctx != nil { + select { + case <-ctx.Done(): + case <-t.C: + return errors.New("failed to wait for close read goroutine to exit") + } + } + + select { + case <-c.closed: + case <-t.C: + return errors.New("failed to wait for connection to be closed") + } + + return nil +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ @@ -279,16 +333,14 @@ func (ce CloseError) bytesErr() ([]byte, error) { return buf, nil } -func (c *Conn) setCloseErr(err error) { +func (c *Conn) casClosing() bool { c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil && err != nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + defer c.closeMu.Unlock() + if !c.closing { + c.closing = true + return true } + return false } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index ef4d62ad..8ba82962 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,6 @@ package websocket import ( "bufio" "context" - "errors" "fmt" "io" "net" @@ -53,8 +52,9 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context + readTimeout chan context.Context + writeTimeout chan context.Context + timeoutLoopDone chan struct{} // Read state. readMu *mu @@ -70,11 +70,12 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header - wg sync.WaitGroup - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool + closeReadMu sync.Mutex + closeReadCtx context.Context + + closed chan struct{} + closeMu sync.Mutex + closing bool pingCounter int32 activePingsMu sync.Mutex @@ -103,8 +104,9 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), @@ -128,14 +130,10 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) + c.close() }) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.timeoutLoop() - }() + go c.timeoutLoop() return c } @@ -146,35 +144,29 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) close(err error) { +func (c *Conn) close() error { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { - return - } - if err == nil { - err = c.rwc.Close() + return net.ErrClosed } - c.setCloseErrLocked(err) - - close(c.closed) runtime.SetFinalizer(c, nil) + close(c.closed) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. - c.rwc.Close() - - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.msgWriter.close() - c.msgReader.close() - }() + err := c.rwc.Close() + // With the close of rwc, these become safe to close. + c.msgWriter.close() + c.msgReader.close() + return err } func (c *Conn) timeoutLoop() { + defer close(c.timeoutLoopDone) + readCtx := context.Background() writeCtx := context.Background() @@ -187,14 +179,10 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.readTimeout: case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.writeError(StatusPolicyViolation, errors.New("read timed out")) - }() + c.close() + return case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.close() return } } @@ -243,9 +231,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { case <-c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err + return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) case <-pong: return nil } @@ -281,9 +267,7 @@ func (m *mu) lock(ctx context.Context) error { case <-m.c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err + return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected diff --git a/conn_test.go b/conn_test.go index 97b172dc..ff7279f5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -345,6 +345,9 @@ func TestConn(t *testing.T) { func TestWasm(t *testing.T) { t.Parallel() + if os.Getenv("CI") == "" { + t.Skip() + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ diff --git a/read.go b/read.go index 8742842e..9d00ae1b 100644 --- a/read.go +++ b/read.go @@ -60,14 +60,21 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. +// +// This function is idempotent. func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.closeReadMu.Lock() + if c.closeReadCtx != nil { + c.closeReadMu.Unlock() + return c.closeReadCtx + } ctx, cancel := context.WithCancel(ctx) + c.closeReadCtx = ctx + c.closeReadMu.Unlock() - c.wg.Add(1) go func() { - defer c.CloseNow() - defer c.wg.Done() defer cancel() + defer c.close() _, _, err := c.Reader(ctx) if err == nil { c.Close(StatusPolicyViolation, "unexpected data message") @@ -222,7 +229,6 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: - c.close(err) return header{}, err } } @@ -251,9 +257,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { case <-ctx.Done(): return n, ctx.Err() default: - err = fmt.Errorf("failed to read frame payload: %w", err) - c.close(err) - return n, err + return n, fmt.Errorf("failed to read frame payload: %w", err) } } @@ -320,9 +324,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) - c.close(err) return err } @@ -336,9 +338,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.unlock() if !c.msgReader.fin { - err = errors.New("previous message not read to completion") - c.close(fmt.Errorf("failed to get reader: %w", err)) - return 0, nil, err + return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -411,10 +411,9 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { return n, io.EOF } if err != nil { - err = fmt.Errorf("failed to read: %w", err) - mr.c.close(err) + return n, fmt.Errorf("failed to read: %w", err) } - return n, err + return n, nil } func (mr *msgReader) read(p []byte) (int, error) { diff --git a/write.go b/write.go index 7b1152ce..d29bcb67 100644 --- a/write.go +++ b/write.go @@ -159,7 +159,6 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) - mw.c.close(err) } }() @@ -242,30 +241,12 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } -// frame handles all writes to the connection. +// writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } - - // If the state says a close has already been written, we wait until - // the connection is closed and return that error. - // - // However, if the frame being written is a close, that means its the close from - // the state being set so we let it go through. - c.closeMu.Lock() - wroteClose := c.wroteClose - c.closeMu.Unlock() - if wroteClose && opcode != opClose { - c.writeFrameMu.unlock() - select { - case <-ctx.Done(): - return 0, ctx.Err() - case <-c.closed: - return 0, net.ErrClosed - } - } defer c.writeFrameMu.unlock() select { @@ -283,7 +264,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco err = ctx.Err() default: } - c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } }() @@ -392,7 +372,5 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { } func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) c.writeClose(code, err.Error()) - c.close(nil) } diff --git a/ws_js.go b/ws_js.go index 77d0d80f..2b8e3b3d 100644 --- a/ws_js.go +++ b/ws_js.go @@ -225,7 +225,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) @@ -239,7 +238,6 @@ func (c *Conn) Close(code StatusCode, reason string) error { // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake. func (c *Conn) CloseNow() error { - defer c.wg.Wait() return c.Close(StatusGoingAway, "") }