diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c650580..e9b4b5f6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/fmt.sh @@ -19,7 +19,7 @@ jobs: steps: - uses: actions/checkout@v4 - run: go version - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/lint.sh @@ -28,7 +28,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/test.sh @@ -41,7 +41,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: ./ci/bench.sh diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml index b1e64fbc..2ba9ce34 100644 --- a/.github/workflows/daily.yml +++ b/.github/workflows/daily.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/bench.sh @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh @@ -34,7 +34,7 @@ jobs: - uses: actions/checkout@v4 with: ref: dev - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/bench.sh @@ -44,11 +44,11 @@ jobs: - uses: actions/checkout@v4 with: ref: dev - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - run: AUTOBAHN=1 ./ci/test.sh - uses: actions/upload-artifact@v3 with: - name: coverage.html + name: coverage-dev.html path: ./ci/out/coverage.html diff --git a/accept_test.go b/accept_test.go index 7cb85d0f..18233b1e 100644 --- a/accept_test.go +++ b/accept_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "nhooyr.io/websocket/internal/test/assert" @@ -142,6 +143,42 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + t.Run("closeRace", func(t *testing.T) { + t.Parallel() + + server, _ := net.Pipe() + + rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server)) + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return server, rw, nil + }, + } + } + w := newResponseWriter() + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + c, err := Accept(w, r, nil) + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + c.Close(StatusInternalError, "the sky is falling") + wg.Done() + }() + go func() { + c.CloseNow() + wg.Done() + }() + wg.Wait() + assert.Success(t, err) + }) } func Test_verifyClientHandshake(t *testing.T) { diff --git a/ci/bench.sh b/ci/bench.sh index a553b93a..30c06986 100755 --- a/ci/bench.sh +++ b/ci/bench.sh @@ -2,8 +2,19 @@ set -eu cd -- "$(dirname "$0")/.." -go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" . +go test --run=^$ --bench=. --benchmem "$@" ./... +# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test ( cd ./internal/thirdparty - go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" . + go test --run=^$ --bench=. --benchmem "$@" . + + GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" . + if [ "$#" -eq 0 ]; then + if [ "${CI-}" ]; then + sudo apt-get update + sudo apt-get install -y qemu-user-static + ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 + fi + qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem + fi ) diff --git a/ci/fmt.sh b/ci/fmt.sh index 6e5a68e4..31d0c15d 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -18,3 +18,7 @@ npx prettier@3.0.3 \ $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go + +if [ "${CI-}" ]; then + git diff --exit-code +fi diff --git a/ci/test.sh b/ci/test.sh index 83bb9832..a3007614 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -11,6 +11,19 @@ cd -- "$(dirname "$0")/.." go test "$@" ./... ) +( + GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" . + if [ "$#" -eq 0 ]; then + if [ "${CI-}" ]; then + sudo apt-get update + sudo apt-get install -y qemu-user-static + ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 + fi + qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask + fi +) + + go install github.com/agnivade/wasmbrowsertest@latest go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... sed -i.bak '/stringer\.go/d' ci/out/coverage.prof diff --git a/close.go b/close.go index c3dee7e0..31504b0e 100644 --- a/close.go +++ b/close.go @@ -93,85 +93,110 @@ func CloseStatus(err error) StatusCode { // The connection can only be closed once. Additional calls to Close // are no-ops. // -// The maximum length of reason must be 125 bytes. Avoid -// sending a dynamic reason. +// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once // complete. -func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() - return c.closeHandshake(code, reason) +func (c *Conn) Close(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } + return net.ErrClosed + } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() + + err = c.closeHandshake(code, reason) + + err2 := c.close() + if err == nil && err2 != nil { + err = err2 + } + + err2 = c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 + } + + return err } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. func (c *Conn) CloseNow() (err error) { - defer c.wg.Wait() - defer errd.Wrap(&err, "failed to close WebSocket") + defer errd.Wrap(&err, "failed to immediately close WebSocket") - if c.isClosed() { + if !c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } return net.ErrClosed } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() - c.close(nil) - return c.closeErr -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() + err = c.close() - if writeErr != nil { - return writeErr + err2 := c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 } + return err +} - if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { - return closeHandshakeErr +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.writeClose(code, reason) + if err != nil { + return err } + err = c.waitCloseHandshake() + if CloseStatus(err) != code { + return err + } return nil } func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return net.ErrClosed - } - ce := CloseError{ Code: code, Reason: reason, } var p []byte - var marshalErr error + var err error if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil + p, err = ce.bytes() + if err != nil { + return err + } } - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - if marshalErr != nil { - return marshalErr + err = c.writeControl(ctx, opClose, p) + // If the connection closed as we're writing we ignore the error as we might + // have written the close frame, the peer responded and then someone else read it + // and closed the connection. + if err != nil && !errors.Is(err, net.ErrClosed) { + return err } - return writeErr + return nil } func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -181,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error { } defer c.readMu.unlock() - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { @@ -207,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error { } } +func (c *Conn) waitGoroutines() error { + t := time.NewTimer(time.Second * 15) + defer t.Stop() + + select { + case <-c.timeoutLoopDone: + case <-t.C: + return errors.New("failed to wait for timeoutLoop goroutine to exit") + } + + c.closeReadMu.Lock() + closeRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closeRead { + select { + case <-c.closeReadDone: + case <-t.C: + return errors.New("failed to wait for close read goroutine to exit") + } + } + + select { + case <-c.closed: + case <-t.C: + return errors.New("failed to wait for connection to be closed") + } + + return nil +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ @@ -277,16 +328,14 @@ func (ce CloseError) bytesErr() ([]byte, error) { return buf, nil } -func (c *Conn) setCloseErr(err error) { +func (c *Conn) casClosing() bool { c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil && err != nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + defer c.closeMu.Unlock() + if !c.closing { + c.closing = true + return true } + return false } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index ef4d62ad..8690fb3b 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,6 @@ package websocket import ( "bufio" "context" - "errors" "fmt" "io" "net" @@ -53,15 +52,15 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context + readTimeout chan context.Context + writeTimeout chan context.Context + timeoutLoopDone chan struct{} // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader // Write state. msgWriter *msgWriter @@ -70,11 +69,13 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header - wg sync.WaitGroup - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool + closeReadMu sync.Mutex + closeReadCtx context.Context + closeReadDone chan struct{} + + closed chan struct{} + closeMu sync.Mutex + closing bool pingCounter int32 activePingsMu sync.Mutex @@ -103,8 +104,9 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), @@ -128,14 +130,10 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) + c.close() }) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.timeoutLoop() - }() + go c.timeoutLoop() return c } @@ -146,35 +144,29 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) close(err error) { +func (c *Conn) close() error { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { - return - } - if err == nil { - err = c.rwc.Close() + return net.ErrClosed } - c.setCloseErrLocked(err) - - close(c.closed) runtime.SetFinalizer(c, nil) + close(c.closed) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. - c.rwc.Close() - - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.msgWriter.close() - c.msgReader.close() - }() + err := c.rwc.Close() + // With the close of rwc, these become safe to close. + c.msgWriter.close() + c.msgReader.close() + return err } func (c *Conn) timeoutLoop() { + defer close(c.timeoutLoopDone) + readCtx := context.Background() writeCtx := context.Background() @@ -187,14 +179,10 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.readTimeout: case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.writeError(StatusPolicyViolation, errors.New("read timed out")) - }() + c.close() + return case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.close() return } } @@ -243,9 +231,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { case <-c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err + return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) case <-pong: return nil } @@ -281,9 +267,7 @@ func (m *mu) lock(ctx context.Context) error { case <-m.c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err + return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected diff --git a/conn_test.go b/conn_test.go index 97b172dc..2b44ad22 100644 --- a/conn_test.go +++ b/conn_test.go @@ -345,6 +345,9 @@ func TestConn(t *testing.T) { func TestWasm(t *testing.T) { t.Parallel() + if os.Getenv("CI") == "" { + t.SkipNow() + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ @@ -360,7 +363,7 @@ func TestWasm(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".") + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() diff --git a/frame.go b/frame.go index 351632fd..d5631863 100644 --- a/frame.go +++ b/frame.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "math" - "math/bits" "nhooyr.io/websocket/internal/errd" ) @@ -172,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { return nil } - -// mask applies the WebSocket masking algorithm to p -// with the given key. -// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the correctly rotated key to -// to continue to mask/unmask the message. -// -// It is optimized for LittleEndian and expects the key -// to be in little endian. -// -// See https://github.com/golang/go/issues/31586 -func mask(key uint32, b []byte) uint32 { - if len(b) >= 8 { - key64 := uint64(key)<<32 | uint64(key) - - // At some point in the future we can clean these unrolled loops up. - // See https://github.com/golang/go/issues/31586#issuecomment-487436401 - - // Then we xor until b is less than 128 bytes. - for len(b) >= 128 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - v = binary.LittleEndian.Uint64(b[64:72]) - binary.LittleEndian.PutUint64(b[64:72], v^key64) - v = binary.LittleEndian.Uint64(b[72:80]) - binary.LittleEndian.PutUint64(b[72:80], v^key64) - v = binary.LittleEndian.Uint64(b[80:88]) - binary.LittleEndian.PutUint64(b[80:88], v^key64) - v = binary.LittleEndian.Uint64(b[88:96]) - binary.LittleEndian.PutUint64(b[88:96], v^key64) - v = binary.LittleEndian.Uint64(b[96:104]) - binary.LittleEndian.PutUint64(b[96:104], v^key64) - v = binary.LittleEndian.Uint64(b[104:112]) - binary.LittleEndian.PutUint64(b[104:112], v^key64) - v = binary.LittleEndian.Uint64(b[112:120]) - binary.LittleEndian.PutUint64(b[112:120], v^key64) - v = binary.LittleEndian.Uint64(b[120:128]) - binary.LittleEndian.PutUint64(b[120:128], v^key64) - b = b[128:] - } - - // Then we xor until b is less than 64 bytes. - for len(b) >= 64 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - b = b[64:] - } - - // Then we xor until b is less than 32 bytes. - for len(b) >= 32 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - b = b[32:] - } - - // Then we xor until b is less than 16 bytes. - for len(b) >= 16 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - b = b[16:] - } - - // Then we xor until b is less than 8 bytes. - for len(b) >= 8 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - b = b[8:] - } - } - - // Then we xor until b is less than 4 bytes. - for len(b) >= 4 { - v := binary.LittleEndian.Uint32(b) - binary.LittleEndian.PutUint32(b, v^key) - b = b[4:] - } - - // xor remaining bytes. - for i := range b { - b[i] ^= byte(key) - key = bits.RotateLeft32(key, -8) - } - - return key -} diff --git a/frame_test.go b/frame_test.go index e697e198..bd626358 100644 --- a/frame_test.go +++ b/frame_test.go @@ -97,7 +97,7 @@ func Test_mask(t *testing.T) { key := []byte{0xa, 0xb, 0xc, 0xff} key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := mask(key32, p) + gotKey32 := mask(p, key32) expP := []byte{0, 0, 0, 0x0d, 0x6} assert.Equal(t, "p", expP, p) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..e69de29b diff --git a/internal/examples/chat/README.md b/internal/examples/chat/README.md index ca1024a0..574c6994 100644 --- a/internal/examples/chat/README.md +++ b/internal/examples/chat/README.md @@ -5,7 +5,7 @@ This directory contains a full stack example of a simple chat webapp using nhooy ```bash $ cd examples/chat $ go run . localhost:0 -listening on http://127.0.0.1:51055 +listening on ws://127.0.0.1:51055 ``` Visit the printed URL to submit and view broadcasted messages in a browser. diff --git a/internal/examples/chat/main.go b/internal/examples/chat/main.go index 3fcec6be..e3432984 100644 --- a/internal/examples/chat/main.go +++ b/internal/examples/chat/main.go @@ -31,7 +31,7 @@ func run() error { if err != nil { return err } - log.Printf("listening on http://%v", l.Addr()) + log.Printf("listening on ws://%v", l.Addr()) cs := newChatServer() s := &http.Server{ diff --git a/internal/examples/echo/README.md b/internal/examples/echo/README.md index 7f42c3c5..ac03f640 100644 --- a/internal/examples/echo/README.md +++ b/internal/examples/echo/README.md @@ -5,7 +5,7 @@ This directory contains a echo server example using nhooyr.io/websocket. ```bash $ cd examples/echo $ go run . localhost:0 -listening on http://127.0.0.1:51055 +listening on ws://127.0.0.1:51055 ``` You can use a WebSocket client like https://github.com/hashrocket/ws to connect. All messages diff --git a/internal/examples/echo/main.go b/internal/examples/echo/main.go index 16d78a79..47e30d05 100644 --- a/internal/examples/echo/main.go +++ b/internal/examples/echo/main.go @@ -31,7 +31,7 @@ func run() error { if err != nil { return err } - log.Printf("listening on http://%v", l.Addr()) + log.Printf("listening on ws://%v", l.Addr()) s := &http.Server{ Handler: echoServer{ diff --git a/internal/thirdparty/frame_test.go b/internal/thirdparty/frame_test.go index 1a0ed125..89042e53 100644 --- a/internal/thirdparty/frame_test.go +++ b/internal/thirdparty/frame_test.go @@ -2,17 +2,19 @@ package thirdparty import ( "encoding/binary" + "runtime" "strconv" "testing" _ "unsafe" "github.com/gobwas/ws" _ "github.com/gorilla/websocket" + _ "github.com/lesismal/nbio/nbhttp/websocket" _ "nhooyr.io/websocket" ) -func basicMask(maskKey [4]byte, pos int, b []byte) int { +func basicMask(b []byte, maskKey [4]byte, pos int) int { for i := range b { b[i] ^= maskKey[pos&3] pos++ @@ -20,23 +22,34 @@ func basicMask(maskKey [4]byte, pos int, b []byte) int { return pos & 3 } +//go:linkname maskGo nhooyr.io/websocket.maskGo +func maskGo(b []byte, key32 uint32) int + +//go:linkname maskAsm nhooyr.io/websocket.maskAsm +func maskAsm(b *byte, len int, key32 uint32) uint32 + +//go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR +func nbioMaskBytes(b, key []byte) int + //go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes func gorillaMaskBytes(key [4]byte, pos int, b []byte) int -//go:linkname mask nhooyr.io/websocket.mask -func mask(key32 uint32, b []byte) int - func Benchmark_mask(b *testing.B) { + b.Run(runtime.GOARCH, benchmark_mask) +} + +func benchmark_mask(b *testing.B) { sizes := []int{ - 2, - 3, - 4, 8, 16, 32, 128, + 256, 512, + 1024, + 2048, 4096, + 8192, 16384, } @@ -48,22 +61,34 @@ func Benchmark_mask(b *testing.B) { name: "basic", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { - basicMask(key, 0, p) + basicMask(p, key, 0) } }, }, { - name: "nhooyr", + name: "nhooyr-go", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + maskGo(p, key32) + } + }, + }, + { + name: "wdvxdr1123-asm", fn: func(b *testing.B, key [4]byte, p []byte) { key32 := binary.LittleEndian.Uint32(key[:]) b.ResetTimer() for i := 0; i < b.N; i++ { - mask(key32, p) + maskAsm(&p[0], len(p), key32) } }, }, + { name: "gorilla", fn: func(b *testing.B, key [4]byte, p []byte) { @@ -80,16 +105,25 @@ func Benchmark_mask(b *testing.B) { } }, }, + { + name: "nbio", + fn: func(b *testing.B, key [4]byte, p []byte) { + keyb := key[:] + for i := 0; i < b.N; i++ { + nbioMaskBytes(p, keyb) + } + }, + }, } key := [4]byte{1, 2, 3, 4} - for _, size := range sizes { - p := make([]byte, size) + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + for _, size := range sizes { + p := make([]byte, size) - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { + b.Run(strconv.Itoa(size), func(b *testing.B) { b.SetBytes(int64(size)) fn.fn(b, key, p) diff --git a/internal/thirdparty/go.mod b/internal/thirdparty/go.mod index 10eb45c1..d991dd64 100644 --- a/internal/thirdparty/go.mod +++ b/internal/thirdparty/go.mod @@ -8,6 +8,7 @@ require ( github.com/gin-gonic/gin v1.9.1 github.com/gobwas/ws v1.3.0 github.com/gorilla/websocket v1.5.0 + github.com/lesismal/nbio v1.3.18 nhooyr.io/websocket v0.0.0-00010101000000-000000000000 ) @@ -25,6 +26,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect + github.com/lesismal/llib v1.1.12 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect @@ -34,7 +36,7 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect + golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.9.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/internal/thirdparty/go.sum b/internal/thirdparty/go.sum index a9424b8d..1f542103 100644 --- a/internal/thirdparty/go.sum +++ b/internal/thirdparty/go.sum @@ -41,6 +41,10 @@ github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZX github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/lesismal/llib v1.1.12 h1:KJFB8bL02V+QGIvILEw/w7s6bKj9Ps9Px97MZP2EOk0= +github.com/lesismal/llib v1.1.12/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= +github.com/lesismal/nbio v1.3.18 h1:kmJZlxjQpVfuCPYcXdv0Biv9LHVViJZet5K99Xs3RAs= +github.com/lesismal/nbio v1.3.18/go.mod h1:KWlouFT5cgDdW5sMX8RsHASUMGniea9X0XIellZ0B38= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -67,19 +71,51 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= diff --git a/mask.go b/mask.go new file mode 100644 index 00000000..7bc0c8d5 --- /dev/null +++ b/mask.go @@ -0,0 +1,128 @@ +package websocket + +import ( + "encoding/binary" + "math/bits" +) + +// maskGo applies the WebSocket masking algorithm to p +// with the given key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the correctly rotated key to +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +// +// See https://github.com/golang/go/issues/31586 +func maskGo(b []byte, key uint32) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes. + for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + v = binary.LittleEndian.Uint64(b[64:72]) + binary.LittleEndian.PutUint64(b[64:72], v^key64) + v = binary.LittleEndian.Uint64(b[72:80]) + binary.LittleEndian.PutUint64(b[72:80], v^key64) + v = binary.LittleEndian.Uint64(b[80:88]) + binary.LittleEndian.PutUint64(b[80:88], v^key64) + v = binary.LittleEndian.Uint64(b[88:96]) + binary.LittleEndian.PutUint64(b[88:96], v^key64) + v = binary.LittleEndian.Uint64(b[96:104]) + binary.LittleEndian.PutUint64(b[96:104], v^key64) + v = binary.LittleEndian.Uint64(b[104:112]) + binary.LittleEndian.PutUint64(b[104:112], v^key64) + v = binary.LittleEndian.Uint64(b[112:120]) + binary.LittleEndian.PutUint64(b[112:120], v^key64) + v = binary.LittleEndian.Uint64(b[120:128]) + binary.LittleEndian.PutUint64(b[120:128], v^key64) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/mask_amd64.s b/mask_amd64.s new file mode 100644 index 00000000..bd42be31 --- /dev/null +++ b/mask_amd64.s @@ -0,0 +1,127 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // AX = b + // CX = len (left length) + // SI = key (uint32) + // DI = uint64(SI) | uint64(SI)<<32 + MOVQ b+0(FP), AX + MOVQ len+8(FP), CX + MOVL key+16(FP), SI + + // calculate the DI + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + CMPQ CX, $15 + JLE less_than_16 + CMPQ CX, $63 + JLE less_than_64 + CMPQ CX, $128 + JLE sse + TESTQ $31, AX + JNZ unaligned + +unaligned_loop_1byte: + XORB SI, (AX) + INCQ AX + DECQ CX + ROLL $24, SI + TESTQ $7, AX + JNZ unaligned_loop_1byte + + // calculate DI again since SI was modified + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + TESTQ $31, AX + JZ sse + +unaligned: + TESTQ $7, AX // AND $7 & len, if not zero jump to loop_1b. + JNZ unaligned_loop_1byte + +unaligned_loop: + // we don't need to check the CX since we know it's above 128 + XORQ DI, (AX) + ADDQ $8, AX + SUBQ $8, CX + TESTQ $31, AX + JNZ unaligned_loop + JMP sse + +sse: + CMPQ CX, $0x40 + JL less_than_64 + MOVQ DI, X0 + PUNPCKLQDQ X0, X0 + +sse_loop: + MOVOU 0*16(AX), X1 + MOVOU 1*16(AX), X2 + MOVOU 2*16(AX), X3 + MOVOU 3*16(AX), X4 + PXOR X0, X1 + PXOR X0, X2 + PXOR X0, X3 + PXOR X0, X4 + MOVOU X1, 0*16(AX) + MOVOU X2, 1*16(AX) + MOVOU X3, 2*16(AX) + MOVOU X4, 3*16(AX) + ADDQ $0x40, AX + SUBQ $0x40, CX + CMPQ CX, $0x40 + JAE sse_loop + +less_than_64: + TESTQ $32, CX + JZ less_than_32 + XORQ DI, (AX) + XORQ DI, 8(AX) + XORQ DI, 16(AX) + XORQ DI, 24(AX) + ADDQ $32, AX + +less_than_32: + TESTQ $16, CX + JZ less_than_16 + XORQ DI, (AX) + XORQ DI, 8(AX) + ADDQ $16, AX + +less_than_16: + TESTQ $8, CX + JZ less_than_8 + XORQ DI, (AX) + ADDQ $8, AX + +less_than_8: + TESTQ $4, CX + JZ less_than_4 + XORL SI, (AX) + ADDQ $4, AX + +less_than_4: + TESTQ $2, CX + JZ less_than_2 + XORW SI, (AX) + ROLL $16, SI + ADDQ $2, AX + +less_than_2: + TESTQ $1, CX + JZ done + XORB SI, (AX) + ROLL $24, SI + +done: + MOVL SI, ret+24(FP) + RET diff --git a/mask_arm64.s b/mask_arm64.s new file mode 100644 index 00000000..e494b43a --- /dev/null +++ b/mask_arm64.s @@ -0,0 +1,72 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // R0 = b + // R1 = len + // R3 = key (uint32) + // R2 = uint64(key)<<32 | uint64(key) + MOVD b_ptr+0(FP), R0 + MOVD b_len+8(FP), R1 + MOVWU key+16(FP), R3 + MOVD R3, R2 + ORR R2<<32, R2, R2 + VDUP R2, V0.D2 + CMP $64, R1 + BLT less_than_64 + +loop_64: + VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VEOR V3.B16, V0.B16, V3.B16 + VEOR V4.B16, V0.B16, V4.B16 + VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0) + SUBS $64, R1 + CMP $64, R1 + BGE loop_64 + +less_than_64: + CBZ R1, end + TBZ $5, R1, less_than_32 + VLD1 (R0), [V1.B16, V2.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VST1.P [V1.B16, V2.B16], 32(R0) + +less_than_32: + TBZ $4, R1, less_than_16 + LDP (R0), (R11, R12) + EOR R11, R2, R11 + EOR R12, R2, R12 + STP.P (R11, R12), 16(R0) + +less_than_16: + TBZ $3, R1, less_than_8 + MOVD (R0), R11 + EOR R2, R11, R11 + MOVD.P R11, 8(R0) + +less_than_8: + TBZ $2, R1, less_than_4 + MOVWU (R0), R11 + EORW R2, R11, R11 + MOVWU.P R11, 4(R0) + +less_than_4: + TBZ $1, R1, less_than_2 + MOVHU (R0), R11 + EORW R3, R11, R11 + MOVHU.P R11, 2(R0) + RORW $16, R3 + +less_than_2: + TBZ $0, R1, end + MOVBU (R0), R11 + EORW R3, R11, R11 + MOVBU.P R11, 1(R0) + RORW $8, R3 + +end: + MOVWU R3, ret+24(FP) + RET diff --git a/mask_asm.go b/mask_asm.go new file mode 100644 index 00000000..f9484b5b --- /dev/null +++ b/mask_asm.go @@ -0,0 +1,26 @@ +//go:build amd64 || arm64 + +package websocket + +func mask(b []byte, key uint32) uint32 { + // TODO: Will enable in v1.9.0. + return maskGo(b, key) + /* + if len(b) > 0 { + return maskAsm(&b[0], len(b), key) + } + return key + */ +} + +// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this +// function are perfect. There are almost certainly missing optimizations or +// opportunities for simplification. I'm confident there are no bugs though. +// For example, the arm64 implementation doesn't align memory like the amd64. +// Or the amd64 implementation could use AVX512 instead of just AVX2. +// The AVX2 code I had to disable anyway as it wasn't performing as expected. +// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049 +// +//go:noescape +//lint:ignore U1000 disabled till v1.9.0 +func maskAsm(b *byte, len int, key uint32) uint32 diff --git a/mask_asm_test.go b/mask_asm_test.go new file mode 100644 index 00000000..416cbc43 --- /dev/null +++ b/mask_asm_test.go @@ -0,0 +1,11 @@ +//go:build amd64 || arm64 + +package websocket + +import "testing" + +func TestMaskASM(t *testing.T) { + t.Parallel() + + testMask(t, "maskASM", mask) +} diff --git a/mask_go.go b/mask_go.go new file mode 100644 index 00000000..b29435e9 --- /dev/null +++ b/mask_go.go @@ -0,0 +1,7 @@ +//go:build !amd64 && !arm64 && !js + +package websocket + +func mask(b []byte, key uint32) uint32 { + return maskGo(b, key) +} diff --git a/mask_test.go b/mask_test.go new file mode 100644 index 00000000..54f55e43 --- /dev/null +++ b/mask_test.go @@ -0,0 +1,73 @@ +package websocket + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "math/big" + "math/bits" + "testing" + + "nhooyr.io/websocket/internal/test/assert" +) + +func basicMask(b []byte, key uint32) uint32 { + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + return key +} + +func basicMask2(b []byte, key uint32) uint32 { + keyb := binary.LittleEndian.AppendUint32(nil, key) + pos := 0 + for i := range b { + b[i] ^= keyb[pos&3] + pos++ + } + return bits.RotateLeft32(key, (pos&3)*-8) +} + +func TestMask(t *testing.T) { + t.Parallel() + + testMask(t, "basicMask", basicMask) + testMask(t, "maskGo", maskGo) + testMask(t, "basicMask2", basicMask2) +} + +func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) { + t.Run(name, func(t *testing.T) { + t.Parallel() + for i := 0; i < 9999; i++ { + keyb := make([]byte, 4) + _, err := rand.Read(keyb) + assert.Success(t, err) + key := binary.LittleEndian.Uint32(keyb) + + n, err := rand.Int(rand.Reader, big.NewInt(1<<16)) + assert.Success(t, err) + + b := make([]byte, 1+n.Int64()) + _, err = rand.Read(b) + assert.Success(t, err) + + b2 := make([]byte, len(b)) + copy(b2, b) + b3 := make([]byte, len(b)) + copy(b3, b) + + key2 := basicMask(b2, key) + key3 := fn(b3, key) + + if key2 != key3 { + t.Errorf("expected key %X but got %X", key2, key3) + } + if !bytes.Equal(b2, b3) { + t.Error("bad bytes") + return + } + } + }) +} diff --git a/netconn.go b/netconn.go index 1667f45c..86f7dadb 100644 --- a/netconn.go +++ b/netconn.go @@ -94,22 +94,25 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { + // These must be first to be aligned on 32 bit platforms. + // https://github.com/nhooyr/websocket/pull/438 + readExpired int64 + writeExpired int64 + c *Conn msgType MessageType - writeTimer *time.Timer - writeMu *mu - writeExpired int64 - writeCtx context.Context - writeCancel context.CancelFunc - - readTimer *time.Timer - readMu *mu - readExpired int64 - readCtx context.Context - readCancel context.CancelFunc - readEOFed bool - reader io.Reader + writeTimer *time.Timer + writeMu *mu + writeCtx context.Context + writeCancel context.CancelFunc + + readTimer *time.Timer + readMu *mu + readCtx context.Context + readCancel context.CancelFunc + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} diff --git a/read.go b/read.go index 8742842e..a59e71d9 100644 --- a/read.go +++ b/read.go @@ -60,14 +60,24 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. +// +// This function is idempotent. func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) + c.closeReadCtx = ctx + c.closeReadDone = make(chan struct{}) + c.closeReadMu.Unlock() - c.wg.Add(1) go func() { - defer c.CloseNow() - defer c.wg.Done() + defer close(c.closeReadDone) defer cancel() + defer c.close() _, _, err := c.Reader(ctx) if err == nil { c.Close(StatusPolicyViolation, "unexpected data message") @@ -222,7 +232,6 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: - c.close(err) return header{}, err } } @@ -251,9 +260,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { case <-ctx.Done(): return n, ctx.Err() default: - err = fmt.Errorf("failed to read frame payload: %w", err) - c.close(err) - return n, err + return n, fmt.Errorf("failed to read frame payload: %w", err) } } @@ -289,7 +296,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } if h.masked { - mask(h.maskKey, b) + mask(b, h.maskKey) } switch h.opcode { @@ -308,9 +315,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return nil } - defer func() { - c.readCloseFrameErr = err - }() + // opClose ce, err := parseClosePayload(b) if err != nil { @@ -320,9 +325,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) - c.close(err) + c.readMu.unlock() + c.close() return err } @@ -336,9 +341,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.unlock() if !c.msgReader.fin { - err = errors.New("previous message not read to completion") - c.close(fmt.Errorf("failed to get reader: %w", err)) - return 0, nil, err + return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -411,10 +414,9 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { return n, io.EOF } if err != nil { - err = fmt.Errorf("failed to read: %w", err) - mr.c.close(err) + return n, fmt.Errorf("failed to read: %w", err) } - return n, err + return n, nil } func (mr *msgReader) read(p []byte) (int, error) { @@ -453,7 +455,7 @@ func (mr *msgReader) read(p []byte) (int, error) { mr.payloadLength -= int64(n) if !mr.c.client { - mr.maskKey = mask(mr.maskKey, p) + mr.maskKey = mask(p, mr.maskKey) } return n, nil diff --git a/write.go b/write.go index 7b1152ce..d7222f2d 100644 --- a/write.go +++ b/write.go @@ -159,7 +159,6 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) - mw.c.close(err) } }() @@ -242,30 +241,12 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } -// frame handles all writes to the connection. +// writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } - - // If the state says a close has already been written, we wait until - // the connection is closed and return that error. - // - // However, if the frame being written is a close, that means its the close from - // the state being set so we let it go through. - c.closeMu.Lock() - wroteClose := c.wroteClose - c.closeMu.Unlock() - if wroteClose && opcode != opClose { - c.writeFrameMu.unlock() - select { - case <-ctx.Done(): - return 0, ctx.Err() - case <-c.closed: - return 0, net.ErrClosed - } - } defer c.writeFrameMu.unlock() select { @@ -283,7 +264,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco err = ctx.Err() default: } - c.close(err) err = fmt.Errorf("failed to write frame: %w", err) } }() @@ -365,7 +345,7 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { return n, err } - maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) + maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey) p = p[j:] n += j @@ -392,7 +372,5 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { } func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) c.writeClose(code, err.Error()) - c.close(nil) } diff --git a/ws_js.go b/ws_js.go index 77d0d80f..02d61f28 100644 --- a/ws_js.go +++ b/ws_js.go @@ -47,9 +47,10 @@ type Conn struct { // read limit for a message in bytes. msgReadLimit xsync.Int64 - wg sync.WaitGroup + closeReadMu sync.Mutex + closeReadCtx context.Context + closingMu sync.Mutex - isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -130,7 +131,10 @@ func (c *Conn) closeWithInternal() { // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - if c.isReadClosed.Load() == 1 { + c.closeReadMu.Lock() + closedRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closedRead { return 0, nil, errors.New("WebSocket connection read closed") } @@ -225,7 +229,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) @@ -239,7 +242,6 @@ func (c *Conn) Close(code StatusCode, reason string) error { // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake. func (c *Conn) CloseNow() error { - defer c.wg.Wait() return c.Close(StatusGoingAway, "") } @@ -389,14 +391,19 @@ func (w *writer) Close() error { // CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) - c.wg.Add(1) + c.closeReadCtx = ctx + c.closeReadMu.Unlock() + go func() { - defer c.CloseNow() - defer c.wg.Done() defer cancel() + defer c.CloseNow() _, _, err := c.read(ctx) if err != nil { c.Close(StatusPolicyViolation, "unexpected data message") diff --git a/wsjson/wsjson_test.go b/wsjson/wsjson_test.go new file mode 100644 index 00000000..080ab38d --- /dev/null +++ b/wsjson/wsjson_test.go @@ -0,0 +1,53 @@ +package wsjson_test + +import ( + "encoding/json" + "io" + "strconv" + "testing" + + "nhooyr.io/websocket/internal/test/xrand" +) + +func BenchmarkJSON(b *testing.B) { + sizes := []int{ + 8, + 16, + 32, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + } + + b.Run("json.Encoder", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + msg := xrand.String(size) + b.SetBytes(int64(size)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + json.NewEncoder(io.Discard).Encode(msg) + } + }) + } + }) + b.Run("json.Marshal", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + msg := xrand.String(size) + b.SetBytes(int64(size)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + json.Marshal(msg) + } + }) + } + }) +}