From 0daf07fec93be7da18a697ce5fdd7253f0d14c2f Mon Sep 17 00:00:00 2001 From: Moritz Richter Date: Thu, 5 Sep 2024 11:24:19 +0200 Subject: [PATCH 01/11] adds a lock on writeClose. closes #448 --- close.go | 7 +++++++ conn.go | 7 ++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/close.go b/close.go index ff2e878a..129bc824 100644 --- a/close.go +++ b/close.go @@ -169,6 +169,12 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error { } func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + defer c.closeMu.Unlock() + if c.closeWritten { + return nil + } + ce := CloseError{ Code: code, Reason: reason, @@ -193,6 +199,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { if err != nil && !errors.Is(err, net.ErrClosed) { return err } + c.closeWritten = true return nil } diff --git a/conn.go b/conn.go index d7434a9d..2aa8290e 100644 --- a/conn.go +++ b/conn.go @@ -73,9 +73,10 @@ type Conn struct { closeReadCtx context.Context closeReadDone chan struct{} - closed chan struct{} - closeMu sync.Mutex - closing bool + closed chan struct{} + closeMu sync.Mutex + closing bool + closeWritten bool pingCounter atomic.Int64 activePingsMu sync.Mutex From 87caf5869808995d6135b995a8085da22f5483f3 Mon Sep 17 00:00:00 2001 From: Moritz Richter Date: Fri, 6 Sep 2024 08:43:06 +0200 Subject: [PATCH 02/11] moves the writeLock to writeFrame to avoid deadlock --- close.go | 7 ------- conn.go | 8 ++++---- write.go | 11 +++++++++++ 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/close.go b/close.go index 129bc824..ff2e878a 100644 --- a/close.go +++ b/close.go @@ -169,12 +169,6 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error { } func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - defer c.closeMu.Unlock() - if c.closeWritten { - return nil - } - ce := CloseError{ Code: code, Reason: reason, @@ -199,7 +193,6 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { if err != nil && !errors.Is(err, net.ErrClosed) { return err } - c.closeWritten = true return nil } diff --git a/conn.go b/conn.go index 2aa8290e..2f7aae2c 100644 --- a/conn.go +++ b/conn.go @@ -73,10 +73,10 @@ type Conn struct { closeReadCtx context.Context closeReadDone chan struct{} - closed chan struct{} - closeMu sync.Mutex - closing bool - closeWritten bool + closed chan struct{} + closeMu sync.Mutex + closing bool + closeSent bool pingCounter atomic.Int64 activePingsMu sync.Mutex diff --git a/write.go b/write.go index e294a680..c4c96f66 100644 --- a/write.go +++ b/write.go @@ -20,6 +20,9 @@ import ( "github.com/coder/websocket/internal/util" ) +// ErrAlreadyClosed is returned when a close frame has already been sent and another is attempted. +var ErrAlreadyClosed = errors.New("close frame already sent") + // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // @@ -268,6 +271,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } }() + if opcode == opClose && c.closeSent { + return 0, ErrAlreadyClosed + } + c.writeHeader.fin = fin c.writeHeader.opcode = opcode c.writeHeader.payloadLength = int64(len(p)) @@ -303,6 +310,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } } + if opcode == opClose { + c.closeSent = true + } + select { case <-c.closed: if opcode == opClose { From e70e06008b007314fb8637099110c7eb88fd343d Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 6 Sep 2024 09:41:54 +0000 Subject: [PATCH 03/11] group closeSent with writeFrameMu --- conn.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 2f7aae2c..f68bbaa0 100644 --- a/conn.go +++ b/conn.go @@ -68,15 +68,15 @@ type Conn struct { writeBuf []byte writeHeaderBuf [8]byte writeHeader header + closeSent bool closeReadMu sync.Mutex closeReadCtx context.Context closeReadDone chan struct{} - closed chan struct{} - closeMu sync.Mutex - closing bool - closeSent bool + closed chan struct{} + closeMu sync.Mutex + closing bool pingCounter atomic.Int64 activePingsMu sync.Mutex From 2ea151ac783f11047c4ef06b9718823f6d126971 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 6 Sep 2024 09:43:19 +0000 Subject: [PATCH 04/11] return close sent for all writes, remove sentinel error --- write.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/write.go b/write.go index c4c96f66..1a0995f2 100644 --- a/write.go +++ b/write.go @@ -20,9 +20,6 @@ import ( "github.com/coder/websocket/internal/util" ) -// ErrAlreadyClosed is returned when a close frame has already been sent and another is attempted. -var ErrAlreadyClosed = errors.New("close frame already sent") - // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // @@ -252,6 +249,15 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } defer c.writeFrameMu.unlock() + if c.closeSent { + select { + case <-c.closed: + return 0, net.ErrClosed + default: + } + return 0, errors.New("close sent") + } + select { case <-c.closed: return 0, net.ErrClosed @@ -271,10 +277,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } }() - if opcode == opClose && c.closeSent { - return 0, ErrAlreadyClosed - } - c.writeHeader.fin = fin c.writeHeader.opcode = opcode c.writeHeader.payloadLength = int64(len(p)) From c3613fc66341cffa20fc9eacb3017bec47a50f59 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Tue, 17 Sep 2024 13:30:14 +0000 Subject: [PATCH 05/11] fix: echo read error after close received --- close.go | 4 ++++ conn.go | 2 ++ read.go | 19 ++++++++++++++++++- write.go | 4 +++- 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/close.go b/close.go index ff2e878a..925e316d 100644 --- a/close.go +++ b/close.go @@ -206,6 +206,10 @@ func (c *Conn) waitCloseHandshake() error { } defer c.readMu.unlock() + if c.readCloseErr != nil { + return c.readCloseErr + } + for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { diff --git a/conn.go b/conn.go index f68bbaa0..95ee33bf 100644 --- a/conn.go +++ b/conn.go @@ -61,6 +61,7 @@ type Conn struct { readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader + readCloseErr error // Write state. msgWriter *msgWriter @@ -70,6 +71,7 @@ type Conn struct { writeHeader header closeSent bool + // CloseRead state. closeReadMu sync.Mutex closeReadCtx context.Context closeReadDone chan struct{} diff --git a/read.go b/read.go index e2699da5..b2b7731c 100644 --- a/read.go +++ b/read.go @@ -181,6 +181,15 @@ func (c *Conn) readRSV1Illegal(h header) bool { } func (c *Conn) readLoop(ctx context.Context) (header, error) { + if c.readCloseErr != nil { + select { + case <-c.closed: + return header{}, net.ErrClosed + default: + } + return header{}, c.readCloseErr + } + for { h, err := c.readFrameHeader(ctx) if err != nil { @@ -324,8 +333,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return err } + if c.readCloseErr == nil { + c.readCloseErr = ce + } + err = fmt.Errorf("received close frame: %w", ce) - c.writeClose(ce.Code, ce.Reason) + if err2 := c.writeClose(ce.Code, ce.Reason); errors.Is(err2, errCloseSent) { + // The close handshake has already been initiated, connection + // close should be handled elsewhere. + return err + } c.readMu.unlock() c.close() return err diff --git a/write.go b/write.go index 1a0995f2..e19628d4 100644 --- a/write.go +++ b/write.go @@ -241,6 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } +var errCloseSent = errors.New("close sent") + // writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) @@ -255,7 +257,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return 0, net.ErrClosed default: } - return 0, errors.New("close sent") + return 0, errCloseSent } select { From cc1c15a162e1b50b3d0af047d86fb8117b857761 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 29 Nov 2024 12:58:59 +0000 Subject: [PATCH 06/11] fix: rewrite close handshake flow when initiated from other side --- close.go | 16 ++---- conn.go | 11 ++-- conn_test.go | 147 +++++++++++++++++++++++++++++++++++++++++++++++++++ read.go | 97 ++++++++++++++++----------------- write.go | 53 +++++++++++-------- 5 files changed, 234 insertions(+), 90 deletions(-) diff --git a/close.go b/close.go index 925e316d..f94951dc 100644 --- a/close.go +++ b/close.go @@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode { func (c *Conn) Close(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) { func (c *Conn) CloseNow() (err error) { defer errd.Wrap(&err, "failed to immediately close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -206,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error { } defer c.readMu.unlock() - if c.readCloseErr != nil { - return c.readCloseErr - } - for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { @@ -333,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) { } func (c *Conn) casClosing() bool { - c.closeMu.Lock() - defer c.closeMu.Unlock() - if !c.closing { - c.closing = true - return true - } - return false + return c.closing.Swap(true) } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index 95ee33bf..4fd4ee0e 100644 --- a/conn.go +++ b/conn.go @@ -61,7 +61,6 @@ type Conn struct { readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader - readCloseErr error // Write state. msgWriter *msgWriter @@ -69,16 +68,20 @@ type Conn struct { writeBuf []byte writeHeaderBuf [8]byte writeHeader header - closeSent bool + + // Close handshake state. + closeStateMu sync.RWMutex + closeReceivedErr error + closeSentErr error // CloseRead state. closeReadMu sync.Mutex closeReadCtx context.Context closeReadDone chan struct{} - closed chan struct{} + closing atomic.Bool closeMu sync.Mutex - closing bool + closed chan struct{} pingCounter atomic.Int64 activePingsMu sync.Mutex diff --git a/conn_test.go b/conn_test.go index b4d57f21..2a4b266d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "os" @@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) { }() } } + +func TestConnClosePropagation(t *testing.T) { + t.Parallel() + + want := []byte("hello") + keepWriting := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + err := c.Write(context.Background(), websocket.MessageText, want) + if err != nil { + return err + } + } + }) + } + keepReading := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + _, got, err := c.Read(context.Background()) + if err != nil { + return err + } + if !bytes.Equal(want, got) { + return fmt.Errorf("unexpected message: want %q, got %q", want, got) + } + } + }) + } + checkReadErr := func(t *testing.T, err error) { + var ce websocket.CloseError + if errors.As(err, &ce) { + assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code) + } else { + assert.ErrorIs(t, net.ErrClosed, err) + } + } + checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) { + for _, c := range conn { + // Check write error. + err := c.Write(context.Background(), websocket.MessageText, want) + assert.ErrorIs(t, net.ErrClosed, err) + + // Check read error (output depends on when read is called in relation to connection closure). + _, _, err = c.Read(context.Background()) + checkReadErr(t, err) + } + } + + t.Run("CloseOtherSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + + _, got, err := other.Read(tt.ctx) + assert.Success(t, err) + assert.Equal(t, "msg", want, got) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + otherReadErr := keepReading(other) + + err := this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseOtherSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = other.CloseRead(tt.ctx) + errs := keepReading(this) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-errs: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + thisReadErr := keepReading(this) + otherReadErr := keepReading(other) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) +} diff --git a/read.go b/read.go index b2b7731c..b86ab139 100644 --- a/read.go +++ b/read.go @@ -181,15 +181,6 @@ func (c *Conn) readRSV1Illegal(h header) bool { } func (c *Conn) readLoop(ctx context.Context) (header, error) { - if c.readCloseErr != nil { - select { - case <-c.closed: - return header{}, net.ErrClosed - default: - } - return header{}, c.readCloseErr - } - for { h, err := c.readFrameHeader(ctx) if err != nil { @@ -226,57 +217,59 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } } -func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { +func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: - return header{}, net.ErrClosed + return nil, net.ErrClosed case c.readTimeout <- ctx: } - h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) - if err != nil { + c.closeStateMu.Lock() + closeReceivedErr := c.closeReceivedErr + c.closeStateMu.Unlock() + if closeReceivedErr != nil { + return nil, closeReceivedErr + } + + return func() { select { case <-c.closed: - return header{}, net.ErrClosed - case <-ctx.Done(): - return header{}, ctx.Err() - default: - return header{}, err + if *err != nil { + *err = net.ErrClosed + } + case c.writeTimeout <- context.Background(): } + if *err != nil && ctx.Err() != nil { + *err = ctx.Err() + } + }, nil +} + +func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return header{}, err } + defer readDone() - select { - case <-c.closed: - return header{}, net.ErrClosed - case c.readTimeout <- context.Background(): + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) + if err != nil { + return header{}, err } return h, nil } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - select { - case <-c.closed: - return 0, net.ErrClosed - case c.readTimeout <- ctx: +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return 0, err } + defer readDone() n, err := io.ReadFull(c.br, p) if err != nil { - select { - case <-c.closed: - return n, net.ErrClosed - case <-ctx.Done(): - return n, ctx.Err() - default: - return n, fmt.Errorf("failed to read frame payload: %w", err) - } - } - - select { - case <-c.closed: - return n, net.ErrClosed - case c.readTimeout <- context.Background(): + return n, fmt.Errorf("failed to read frame payload: %w", err) } return n, err @@ -333,18 +326,20 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return err } - if c.readCloseErr == nil { - c.readCloseErr = ce - } - err = fmt.Errorf("received close frame: %w", ce) - if err2 := c.writeClose(ce.Code, ce.Reason); errors.Is(err2, errCloseSent) { - // The close handshake has already been initiated, connection - // close should be handled elsewhere. - return err + c.closeStateMu.Lock() + c.closeReceivedErr = err + closeSent := c.closeSentErr != nil + c.closeStateMu.Unlock() + + if !closeSent { + c.readMu.unlock() + _ = c.writeClose(ce.Code, ce.Reason) + } + if !c.casClosing() { + c.readMu.unlock() + _ = c.close() } - c.readMu.unlock() - c.close() return err } diff --git a/write.go b/write.go index e19628d4..7a59b5c0 100644 --- a/write.go +++ b/write.go @@ -241,8 +241,6 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } -var errCloseSent = errors.New("close sent") - // writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) @@ -251,13 +249,22 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } defer c.writeFrameMu.unlock() - if c.closeSent { - select { - case <-c.closed: - return 0, net.ErrClosed - default: + defer func() { + if err != nil { + if ctx.Err() != nil { + err = ctx.Err() + } else if c.isClosed() { + err = net.ErrClosed + } + err = fmt.Errorf("failed to write frame: %w", err) } - return 0, errCloseSent + }() + + c.closeStateMu.Lock() + closeSentErr := c.closeSentErr + c.closeStateMu.Unlock() + if closeSentErr != nil { + return 0, net.ErrClosed } select { @@ -265,17 +272,11 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return 0, net.ErrClosed case c.writeTimeout <- ctx: } - defer func() { - if err != nil { - select { - case <-c.closed: - err = net.ErrClosed - case <-ctx.Done(): - err = ctx.Err() - default: - } - err = fmt.Errorf("failed to write frame: %w", err) + select { + case <-c.closed: + err = net.ErrClosed + case c.writeTimeout <- context.Background(): } }() @@ -314,10 +315,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } } - if opcode == opClose { - c.closeSent = true - } - select { case <-c.closed: if opcode == opClose { @@ -327,6 +324,18 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco case c.writeTimeout <- context.Background(): } + if opcode == opClose { + c.closeStateMu.Lock() + c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed) + closeReceived := c.closeReceivedErr != nil + c.closeStateMu.Unlock() + + if closeReceived && !c.casClosing() { + c.writeFrameMu.unlock() + _ = c.close() + } + } + return n, nil } From 038d6d0e3ba27953a009fc2e626fc5850fa248fa Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 29 Nov 2024 13:50:43 +0000 Subject: [PATCH 07/11] fix timeout --- read.go | 23 +++++++++++++---------- write.go | 16 ++++------------ 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/read.go b/read.go index b86ab139..ecd5a47b 100644 --- a/read.go +++ b/read.go @@ -224,25 +224,28 @@ func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { case c.readTimeout <- ctx: } - c.closeStateMu.Lock() - closeReceivedErr := c.closeReceivedErr - c.closeStateMu.Unlock() - if closeReceivedErr != nil { - return nil, closeReceivedErr - } - - return func() { + done := func() { select { case <-c.closed: if *err != nil { *err = net.ErrClosed } - case c.writeTimeout <- context.Background(): + case c.readTimeout <- context.Background(): } if *err != nil && ctx.Err() != nil { *err = ctx.Err() } - }, nil + } + + c.closeStateMu.Lock() + closeReceivedErr := c.closeReceivedErr + c.closeStateMu.Unlock() + if closeReceivedErr != nil { + defer done() + return nil, closeReceivedErr + } + + return done, nil } func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { diff --git a/write.go b/write.go index 7a59b5c0..53afe23e 100644 --- a/write.go +++ b/write.go @@ -5,6 +5,7 @@ package websocket import ( "bufio" + "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -14,8 +15,6 @@ import ( "net" "time" - "compress/flate" - "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" ) @@ -250,6 +249,9 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco defer c.writeFrameMu.unlock() defer func() { + if errors.Is(err, net.ErrClosed) && opcode == opClose { + err = nil + } if err != nil { if ctx.Err() != nil { err = ctx.Err() @@ -275,7 +277,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco defer func() { select { case <-c.closed: - err = net.ErrClosed case c.writeTimeout <- context.Background(): } }() @@ -315,15 +316,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } } - select { - case <-c.closed: - if opcode == opClose { - return n, nil - } - return n, net.ErrClosed - case c.writeTimeout <- context.Background(): - } - if opcode == opClose { c.closeStateMu.Lock() c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed) From 9cd2adf8f2c9c69081cc5e7ea2d6a5068fdc18b1 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 29 Nov 2024 13:53:23 +0000 Subject: [PATCH 08/11] adjust opclose err on closed --- write.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/write.go b/write.go index 53afe23e..7324de74 100644 --- a/write.go +++ b/write.go @@ -249,7 +249,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco defer c.writeFrameMu.unlock() defer func() { - if errors.Is(err, net.ErrClosed) && opcode == opClose { + if c.isClosed() && opcode == opClose { err = nil } if err != nil { From e8f5b511297d15bc12eb3252b83bce1d7963c9cf Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 29 Nov 2024 13:54:50 +0000 Subject: [PATCH 09/11] move disconnected comment --- conn_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conn_test.go b/conn_test.go index 2a4b266d..9ed8c7ea 100644 --- a/conn_test.go +++ b/conn_test.go @@ -461,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } func BenchmarkConn(b *testing.B) { - var benchCases = []struct { + benchCases := []struct { name string mode websocket.CompressionMode }{ @@ -655,6 +655,7 @@ func TestConnClosePropagation(t *testing.T) { }) } checkReadErr := func(t *testing.T, err error) { + // Check read error (output depends on when read is called in relation to connection closure). var ce websocket.CloseError if errors.As(err, &ce) { assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code) @@ -668,7 +669,6 @@ func TestConnClosePropagation(t *testing.T) { err := c.Write(context.Background(), websocket.MessageText, want) assert.ErrorIs(t, net.ErrClosed, err) - // Check read error (output depends on when read is called in relation to connection closure). _, _, err = c.Read(context.Background()) checkReadErr(t, err) } From e410eabffc94076069898797b535f78d9c821abc Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Fri, 29 Nov 2024 23:25:58 +0200 Subject: [PATCH 10/11] add mutex comment --- conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 4fd4ee0e..76b057dd 100644 --- a/conn.go +++ b/conn.go @@ -80,7 +80,7 @@ type Conn struct { closeReadDone chan struct{} closing atomic.Bool - closeMu sync.Mutex + closeMu sync.Mutex // Protects following. closed chan struct{} pingCounter atomic.Int64 From 930432faaae0475c7f1a9f2a410f6070a28fe8e7 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Mon, 2 Dec 2024 14:23:41 +0000 Subject: [PATCH 11/11] comment prepareRead and readMu unlocking --- read.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/read.go b/read.go index ecd5a47b..1267b5b9 100644 --- a/read.go +++ b/read.go @@ -217,6 +217,12 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } } +// prepareRead sets the readTimeout context and returns a done function +// to be called after the read is done. It also returns an error if the +// connection is closed. The reference to the error is used to assign +// an error depending on if the connection closed or the context timed +// out during use. Typically the referenced error is a named return +// variable of the function calling this method. func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: @@ -335,6 +341,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { closeSent := c.closeSentErr != nil c.closeStateMu.Unlock() + // Only unlock readMu if this connection is being closed becaue + // c.close will try to acquire the readMu lock. We unlock for + // writeClose as well because it may also call c.close. if !closeSent { c.readMu.unlock() _ = c.writeClose(ce.Code, ce.Reason)