Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid writing messages after close and improve handshake #476

Merged
merged 12 commits into from
Dec 4, 2024
12 changes: 3 additions & 9 deletions close.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
func (c *Conn) Close(code StatusCode, reason string) (err error) {
mafredri marked this conversation as resolved.
Show resolved Hide resolved
defer errd.Wrap(&err, "failed to close WebSocket")

if !c.casClosing() {
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
Expand Down Expand Up @@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")

if !c.casClosing() {
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
Expand Down Expand Up @@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
}

func (c *Conn) casClosing() bool {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if !c.closing {
c.closing = true
return true
}
return false
return c.closing.Swap(true)
}

func (c *Conn) isClosed() bool {
Expand Down
10 changes: 8 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,19 @@ type Conn struct {
writeHeaderBuf [8]byte
writeHeader header

// Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error

// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}

closing atomic.Bool
johnstcn marked this conversation as resolved.
Show resolved Hide resolved
closeMu sync.Mutex // Protects following.
closed chan struct{}
closeMu sync.Mutex
closing bool

pingCounter atomic.Int64
activePingsMu sync.Mutex
Expand Down
149 changes: 148 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
}

func BenchmarkConn(b *testing.B) {
var benchCases = []struct {
benchCases := []struct {
name string
mode websocket.CompressionMode
}{
Expand Down Expand Up @@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
}()
}
}

func TestConnClosePropagation(t *testing.T) {
t.Parallel()

want := []byte("hello")
keepWriting := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
err := c.Write(context.Background(), websocket.MessageText, want)
if err != nil {
return err
}
}
})
}
keepReading := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
_, got, err := c.Read(context.Background())
if err != nil {
return err
}
if !bytes.Equal(want, got) {
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
}
}
})
}
checkReadErr := func(t *testing.T, err error) {
// Check read error (output depends on when read is called in relation to connection closure).
var ce websocket.CloseError
if errors.As(err, &ce) {
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
} else {
assert.ErrorIs(t, net.ErrClosed, err)
}
}
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
for _, c := range conn {
// Check write error.
err := c.Write(context.Background(), websocket.MessageText, want)
assert.ErrorIs(t, net.ErrClosed, err)

_, _, err = c.Read(context.Background())
checkReadErr(t, err)
}
}

t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)

_, got, err := other.Read(tt.ctx)
assert.Success(t, err)
assert.Equal(t, "msg", want, got)

err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
otherReadErr := keepReading(other)

err := this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

_ = other.CloseRead(tt.ctx)
errs := keepReading(this)

err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)

err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-errs:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

thisReadErr := keepReading(this)
otherReadErr := keepReading(other)

err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)

err = this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-thisReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
}
85 changes: 50 additions & 35 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,57 +217,62 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
}
}

func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
mafredri marked this conversation as resolved.
Show resolved Hide resolved
select {
case <-c.closed:
return header{}, net.ErrClosed
return nil, net.ErrClosed
case c.readTimeout <- ctx:
}

h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
done := func() {
select {
case <-c.closed:
return header{}, net.ErrClosed
case <-ctx.Done():
return header{}, ctx.Err()
default:
return header{}, err
if *err != nil {
*err = net.ErrClosed
}
case c.readTimeout <- context.Background():
}
if *err != nil && ctx.Err() != nil {
*err = ctx.Err()
}
}

select {
case <-c.closed:
return header{}, net.ErrClosed
case c.readTimeout <- context.Background():
c.closeStateMu.Lock()
closeReceivedErr := c.closeReceivedErr
c.closeStateMu.Unlock()
if closeReceivedErr != nil {
defer done()
mafredri marked this conversation as resolved.
Show resolved Hide resolved
return nil, closeReceivedErr
}

return h, nil
return done, nil
}

func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
select {
case <-c.closed:
return 0, net.ErrClosed
case c.readTimeout <- ctx:
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return header{}, err
}
defer readDone()

n, err := io.ReadFull(c.br, p)
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
select {
case <-c.closed:
return n, net.ErrClosed
case <-ctx.Done():
return n, ctx.Err()
default:
return n, fmt.Errorf("failed to read frame payload: %w", err)
}
return header{}, err
}

select {
case <-c.closed:
return n, net.ErrClosed
case c.readTimeout <- context.Background():
return h, nil
}

func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return 0, err
}
defer readDone()

n, err := io.ReadFull(c.br, p)
if err != nil {
return n, fmt.Errorf("failed to read frame payload: %w", err)
}

return n, err
Expand Down Expand Up @@ -325,9 +330,19 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
}

err = fmt.Errorf("received close frame: %w", ce)
c.writeClose(ce.Code, ce.Reason)
c.readMu.unlock()
c.close()
c.closeStateMu.Lock()
c.closeReceivedErr = err
closeSent := c.closeSentErr != nil
c.closeStateMu.Unlock()

if !closeSent {
c.readMu.unlock()
_ = c.writeClose(ce.Code, ce.Reason)
}
if !c.casClosing() {
c.readMu.unlock()
_ = c.close()
}
return err
mafredri marked this conversation as resolved.
Show resolved Hide resolved
}

Expand Down
Loading