diff --git a/accept_test.go b/accept_test.go index 7cb85d0f..18233b1e 100644 --- a/accept_test.go +++ b/accept_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "nhooyr.io/websocket/internal/test/assert" @@ -142,6 +143,42 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + t.Run("closeRace", func(t *testing.T) { + t.Parallel() + + server, _ := net.Pipe() + + rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server)) + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return server, rw, nil + }, + } + } + w := newResponseWriter() + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + c, err := Accept(w, r, nil) + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + c.Close(StatusInternalError, "the sky is falling") + wg.Done() + }() + go func() { + c.CloseNow() + wg.Done() + }() + wg.Wait() + assert.Success(t, err) + }) } func Test_verifyClientHandshake(t *testing.T) { diff --git a/close.go b/close.go index e925d043..625ed121 100644 --- a/close.go +++ b/close.go @@ -97,80 +97,106 @@ func CloseStatus(err error) StatusCode { // // Close will unblock all goroutines interacting with the connection once // complete. -func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() - return c.closeHandshake(code, reason) +func (c *Conn) Close(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } + return net.ErrClosed + } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() + + err = c.closeHandshake(code, reason) + + err2 := c.close() + if err == nil && err2 != nil { + err = err2 + } + + err2 = c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 + } + + return err } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. func (c *Conn) CloseNow() (err error) { - defer c.wg.Wait() defer errd.Wrap(&err, "failed to close WebSocket") - if c.isClosed() { + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } return net.ErrClosed } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() - c.close(nil) - return c.closeErr -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() + err = c.close() - if writeErr != nil { - return writeErr + err2 := c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 } + return err +} - if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { - return closeHandshakeErr +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.writeClose(code, reason) + if err != nil { + return err } + err = c.waitCloseHandshake() + if CloseStatus(err) != code { + return err + } return nil } func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return net.ErrClosed - } - ce := CloseError{ Code: code, Reason: reason, } var p []byte - var marshalErr error + var err error if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil + p, err = ce.bytes() + if err != nil { + return err + } } - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - if marshalErr != nil { - return marshalErr + err = c.writeControl(ctx, opClose, p) + // If the connection closed as we're writing we ignore the error as we might + // have written the close frame, the peer responded and then someone else read it + // and closed the connection. + if err != nil && !errors.Is(err, net.ErrClosed) { + return err } - return writeErr + return nil } func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -180,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error { } defer c.readMu.unlock() - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { @@ -206,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error { } } +func (c *Conn) waitGoroutines() error { + t := time.NewTimer(time.Second * 15) + defer t.Stop() + + select { + case <-c.timeoutLoopDone: + case <-t.C: + return errors.New("failed to wait for timeoutLoop goroutine to exit") + } + + c.closeReadMu.Lock() + closeRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closeRead { + select { + case <-c.closeReadDone: + case <-t.C: + return errors.New("failed to wait for close read goroutine to exit") + } + } + + select { + case <-c.closed: + case <-t.C: + return errors.New("failed to wait for connection to be closed") + } + + return nil +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ @@ -276,16 +328,14 @@ func (ce CloseError) bytesErr() ([]byte, error) { return buf, nil } -func (c *Conn) setCloseErr(err error) { +func (c *Conn) casClosing() bool { c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil && err != nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + defer c.closeMu.Unlock() + if !c.closing { + c.closing = true + return true } + return false } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index ef4d62ad..8690fb3b 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,6 @@ package websocket import ( "bufio" "context" - "errors" "fmt" "io" "net" @@ -53,15 +52,15 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context + readTimeout chan context.Context + writeTimeout chan context.Context + timeoutLoopDone chan struct{} // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader // Write state. msgWriter *msgWriter @@ -70,11 +69,13 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header - wg sync.WaitGroup - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool + closeReadMu sync.Mutex + closeReadCtx context.Context + closeReadDone chan struct{} + + closed chan struct{} + closeMu sync.Mutex + closing bool pingCounter int32 activePingsMu sync.Mutex @@ -103,8 +104,9 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), @@ -128,14 +130,10 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) + c.close() }) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.timeoutLoop() - }() + go c.timeoutLoop() return c } @@ -146,35 +144,29 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) close(err error) { +func (c *Conn) close() error { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { - return - } - if err == nil { - err = c.rwc.Close() + return net.ErrClosed } - c.setCloseErrLocked(err) - - close(c.closed) runtime.SetFinalizer(c, nil) + close(c.closed) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. - c.rwc.Close() - - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.msgWriter.close() - c.msgReader.close() - }() + err := c.rwc.Close() + // With the close of rwc, these become safe to close. + c.msgWriter.close() + c.msgReader.close() + return err } func (c *Conn) timeoutLoop() { + defer close(c.timeoutLoopDone) + readCtx := context.Background() writeCtx := context.Background() @@ -187,14 +179,10 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.readTimeout: case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.writeError(StatusPolicyViolation, errors.New("read timed out")) - }() + c.close() + return case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.close() return } } @@ -243,9 +231,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { case <-c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err + return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) case <-pong: return nil } @@ -281,9 +267,7 @@ func (m *mu) lock(ctx context.Context) error { case <-m.c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err + return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected diff --git a/conn_test.go b/conn_test.go index 97b172dc..9fbe961d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -345,6 +345,9 @@ func TestConn(t *testing.T) { func TestWasm(t *testing.T) { t.Parallel() + if os.Getenv("CI") == "" { + t.Skip() + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ @@ -360,7 +363,7 @@ func TestWasm(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".") + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() diff --git a/mask.go b/mask.go index 5f0746dc..7bc0c8d5 100644 --- a/mask.go +++ b/mask.go @@ -16,8 +16,6 @@ import ( // to be in little endian. // // See https://github.com/golang/go/issues/31586 -// -//lint:ignore U1000 mask.go func maskGo(b []byte, key uint32) uint32 { if len(b) >= 8 { key64 := uint64(key)<<32 | uint64(key) diff --git a/mask_asm.go b/mask_asm.go index 259eec03..f9484b5b 100644 --- a/mask_asm.go +++ b/mask_asm.go @@ -3,10 +3,14 @@ package websocket func mask(b []byte, key uint32) uint32 { - if len(b) > 0 { - return maskAsm(&b[0], len(b), key) - } - return key + // TODO: Will enable in v1.9.0. + return maskGo(b, key) + /* + if len(b) > 0 { + return maskAsm(&b[0], len(b), key) + } + return key + */ } // @nhooyr: I am not confident that the amd64 or the arm64 implementations of this @@ -18,4 +22,5 @@ func mask(b []byte, key uint32) uint32 { // See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049 // //go:noescape +//lint:ignore U1000 disabled till v1.9.0 func maskAsm(b *byte, len int, key uint32) uint32 diff --git a/netconn.go b/netconn.go index 3324014d..86f7dadb 100644 --- a/netconn.go +++ b/netconn.go @@ -94,7 +94,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { - // These must be first to be aligned on 32 bit platforms. + // These must be first to be aligned on 32 bit platforms. // https://github.com/nhooyr/websocket/pull/438 readExpired int64 writeExpired int64 diff --git a/read.go b/read.go index 81b89831..a59e71d9 100644 --- a/read.go +++ b/read.go @@ -60,14 +60,24 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. +// +// This function is idempotent. func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) + c.closeReadCtx = ctx + c.closeReadDone = make(chan struct{}) + c.closeReadMu.Unlock() - c.wg.Add(1) go func() { - defer c.CloseNow() - defer c.wg.Done() + defer close(c.closeReadDone) defer cancel() + defer c.close() _, _, err := c.Reader(ctx) if err == nil { c.Close(StatusPolicyViolation, "unexpected data message") @@ -222,7 +232,6 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: - c.close(err) return header{}, err } } @@ -251,9 +260,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { case <-ctx.Done(): return n, ctx.Err() default: - err = fmt.Errorf("failed to read frame payload: %w", err) - c.close(err) - return n, err + return n, fmt.Errorf("failed to read frame payload: %w", err) } } @@ -308,9 +315,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return nil } - defer func() { - c.readCloseFrameErr = err - }() + // opClose ce, err := parseClosePayload(b) if err != nil { @@ -320,9 +325,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) - c.close(err) + c.readMu.unlock() + c.close() return err } @@ -336,9 +341,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.unlock() if !c.msgReader.fin { - err = errors.New("previous message not read to completion") - c.close(fmt.Errorf("failed to get reader: %w", err)) - return 0, nil, err + return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -411,10 +414,9 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { return n, io.EOF } if err != nil { - err = fmt.Errorf("failed to read: %w", err) - mr.c.close(err) + return n, fmt.Errorf("failed to read: %w", err) } - return n, err + return n, nil } func (mr *msgReader) read(p []byte) (int, error) { diff --git a/write.go b/write.go index 7ac7ce63..d7222f2d 100644 --- a/write.go +++ b/write.go @@ -159,7 +159,6 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) - mw.c.close(err) } }() @@ -242,30 +241,12 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } -// frame handles all writes to the connection. +// writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } - - // If the state says a close has already been written, we wait until - // the connection is closed and return that error. - // - // However, if the frame being written is a close, that means its the close from - // the state being set so we let it go through. - c.closeMu.Lock() - wroteClose := c.wroteClose - c.closeMu.Unlock() - if wroteClose && opcode != opClose { - c.writeFrameMu.unlock() - select { - case <-ctx.Done(): - return 0, ctx.Err() - case <-c.closed: - return 0, net.ErrClosed - } - } defer c.writeFrameMu.unlock() select { @@ -283,7 +264,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco err = ctx.Err() default: } - c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } }() @@ -392,7 +372,5 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { } func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) c.writeClose(code, err.Error()) - c.close(nil) } diff --git a/ws_js.go b/ws_js.go index 77d0d80f..02d61f28 100644 --- a/ws_js.go +++ b/ws_js.go @@ -47,9 +47,10 @@ type Conn struct { // read limit for a message in bytes. msgReadLimit xsync.Int64 - wg sync.WaitGroup + closeReadMu sync.Mutex + closeReadCtx context.Context + closingMu sync.Mutex - isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -130,7 +131,10 @@ func (c *Conn) closeWithInternal() { // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - if c.isReadClosed.Load() == 1 { + c.closeReadMu.Lock() + closedRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closedRead { return 0, nil, errors.New("WebSocket connection read closed") } @@ -225,7 +229,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) @@ -239,7 +242,6 @@ func (c *Conn) Close(code StatusCode, reason string) error { // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake. func (c *Conn) CloseNow() error { - defer c.wg.Wait() return c.Close(StatusGoingAway, "") } @@ -389,14 +391,19 @@ func (w *writer) Close() error { // CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) - c.wg.Add(1) + c.closeReadCtx = ctx + c.closeReadMu.Unlock() + go func() { - defer c.CloseNow() - defer c.wg.Done() defer cancel() + defer c.CloseNow() _, _, err := c.read(ctx) if err != nil { c.Close(StatusPolicyViolation, "unexpected data message")