Skip to content

Commit

Permalink
allow exchanging error codes on session termination
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 22, 2024
1 parent 53ef582 commit 273d2b4
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 94 deletions.
68 changes: 3 additions & 65 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,71 +3,7 @@ package yamux
import (
"encoding/binary"
"fmt"
)

type Error struct {
msg string
timeout, temporary bool
}

func (ye *Error) Error() string {
return ye.msg
}

func (ye *Error) Timeout() bool {
return ye.timeout
}

func (ye *Error) Temporary() bool {
return ye.temporary
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
ErrInvalidVersion = &Error{msg: "invalid protocol version"}

// ErrInvalidMsgType means we received a frame with an
// invalid message type
ErrInvalidMsgType = &Error{msg: "invalid msg type"}

// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = &Error{msg: "session shutdown"}

// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
ErrStreamsExhausted = &Error{msg: "streams exhausted"}

// ErrDuplicateStream is used if a duplicate stream is
// opened inbound
ErrDuplicateStream = &Error{msg: "duplicate stream initiated"}

// ErrReceiveWindowExceeded indicates the window was exceeded
ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"}

// ErrTimeout is used when we reach an IO deadline
ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true}

// ErrStreamClosed is returned when using a closed stream
ErrStreamClosed = &Error{msg: "stream closed"}

// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
ErrStreamReset = &Error{msg: "stream reset"}

// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
// timeout writing to the underlying stream connection.
ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true}

// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true}
"time"
)

const (
Expand Down Expand Up @@ -117,6 +53,8 @@ const (
// It's not an implementation choice, the value defined in the specification.
initialStreamWindow = 256 * 1024
maxStreamWindow = 16 * 1024 * 1024
// goAwayWaitTime is the time we'll wait to send a goaway frame on close
goAwayWaitTime = 5 * time.Second
)

const (
Expand Down
97 changes: 97 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package yamux

import "fmt"

type Error struct {
msg string
timeout, temporary bool
}

func (ye *Error) Error() string {
return ye.msg
}

func (ye *Error) Timeout() bool {
return ye.timeout
}

func (ye *Error) Temporary() bool {
return ye.temporary
}

type ErrorGoAway struct {
Remote bool
ErrorCode uint32
}

func (e *ErrorGoAway) Error() string {
if e.Remote {
return fmt.Sprintf("remote sent go away, code: %d", e.ErrorCode)
}
return fmt.Sprintf("sent go away, code: %d", e.ErrorCode)
}

func (e *ErrorGoAway) Timeout() bool {
return false
}

func (e *ErrorGoAway) Temporary() bool {
return false
}

func (e *ErrorGoAway) Is(target error) bool {
// to maintain compatibility with errors returned by previous versions
if e.Remote {
return target == ErrRemoteGoAway
} else {
return target == ErrSessionShutdown
}
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
ErrInvalidVersion = &Error{msg: "invalid protocol version"}

// ErrInvalidMsgType means we received a frame with an
// invalid message type
ErrInvalidMsgType = &Error{msg: "invalid msg type"}

// ErrSessionShutdown is used if there is a shutdown during
// an operation
ErrSessionShutdown = &Error{msg: "session shutdown"}

// ErrStreamsExhausted is returned if we have no more
// stream ids to issue
ErrStreamsExhausted = &Error{msg: "streams exhausted"}

// ErrDuplicateStream is used if a duplicate stream is
// opened inbound
ErrDuplicateStream = &Error{msg: "duplicate stream initiated"}

// ErrReceiveWindowExceeded indicates the window was exceeded
ErrRecvWindowExceeded = &Error{msg: "recv window exceeded"}

// ErrTimeout is used when we reach an IO deadline
ErrTimeout = &Error{msg: "i/o deadline reached", timeout: true, temporary: true}

// ErrStreamClosed is returned when using a closed stream
ErrStreamClosed = &Error{msg: "stream closed"}

// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = &Error{msg: "unexpected flag"}

// ErrRemoteGoAway is used when we get a go away from the other side
ErrRemoteGoAway = &Error{msg: "remote end is not accepting connections"}

// ErrStreamReset is sent if a stream is reset. This can happen
// if the backlog is exceeded, or if there was a remote GoAway.
ErrStreamReset = &Error{msg: "stream reset"}

// ErrConnectionWriteTimeout indicates that we hit the "safety valve"
// timeout writing to the underlying stream connection.
ErrConnectionWriteTimeout = &Error{msg: "connection write timeout", timeout: true}

// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true}
)
62 changes: 41 additions & 21 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package yamux
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -46,12 +47,8 @@ var nullMemoryManager = &nullMemoryManagerImpl{}
type Session struct {
rtt int64 // to be accessed atomically, in nanoseconds

// remoteGoAway indicates the remote side does
// not want futher connections. Must be first for alignment.
remoteGoAway int32

// localGoAway indicates that we should stop
// accepting futher connections. Must be first for alignment.
// accepting futher streams. Must be first for alignment.
localGoAway int32

// nextStreamID is the next stream we should
Expand Down Expand Up @@ -203,9 +200,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
if s.IsClosed() {
return nil, s.shutdownErr
}
if atomic.LoadInt32(&s.remoteGoAway) == 1 {
return nil, ErrRemoteGoAway
}

// Block if we have too many inflight SYNs
select {
Expand Down Expand Up @@ -284,24 +278,47 @@ func (s *Session) AcceptStream() (*Stream, error) {
}

// Close is used to close the session and all streams.
// Attempts to send a GoAway before closing the connection.
// Sends a GoAway before closing the connection.
func (s *Session) Close() error {
return s.CloseWithError(goAwayNormal)
}

// CloseWithError closes the session sending errCode in a goaway frame
func (s *Session) CloseWithError(errCode uint32) error {
return s.closeWithError(errCode, true)
}

func (s *Session) closeWithError(errCode uint32, sendGoAway bool) error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

if s.shutdown {
return nil
}
s.shutdown = true
if s.shutdownErr == nil {
s.shutdownErr = ErrSessionShutdown
s.shutdownErr = &ErrorGoAway{Remote: !sendGoAway, ErrorCode: errCode}
}
close(s.shutdownCh)
s.conn.Close()
s.stopKeepalive()
<-s.recvDoneCh

// wait for write loop
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
<-s.sendDoneCh

// send the goaway frame
if sendGoAway {
buf := pool.Get(headerSize)
hdr := s.goAway(errCode)
copy(buf, hdr[:])
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(buf) // Ignore the error. We are going to close the connection anyway
}
}
s.conn.Close()

// wait for read loop
<-s.recvDoneCh

s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
Expand All @@ -320,11 +337,11 @@ func (s *Session) exitErr(err error) {
s.shutdownErr = err
}
s.shutdownLock.Unlock()
s.Close()
s.closeWithError(0, false)
}

// GoAway can be used to prevent accepting further
// connections. It does not close the underlying conn.
// streams. It does not close the underlying conn.
func (s *Session) GoAway() error {
return s.sendMsg(s.goAway(goAwayNormal), nil, nil)
}
Expand Down Expand Up @@ -516,6 +533,12 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
// send is a long running goroutine that sends data
func (s *Session) send() {
if err := s.sendLoop(); err != nil {
if !s.IsClosed() && (errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe) || strings.Contains(err.Error(), "reset") || strings.Contains(err.Error(), "broken pipe")) {
// if remote has closed the connection, wait for recv loop to exit
// unfortunately it is impossible to close the connection such that FIN is sent and not RST
<-s.recvDoneCh
return
}
s.exitErr(err)
}
}
Expand Down Expand Up @@ -781,18 +804,15 @@ func (s *Session) handleGoAway(hdr header) error {
code := hdr.Length()
switch code {
case goAwayNormal:
atomic.SwapInt32(&s.remoteGoAway, 1)
// Non error termination. Don't log.
case goAwayProtoErr:
s.logger.Printf("[ERR] yamux: received protocol error go away")
return fmt.Errorf("yamux protocol error")
case goAwayInternalErr:
s.logger.Printf("[ERR] yamux: received internal error go away")
return fmt.Errorf("remote yamux internal error")
default:
s.logger.Printf("[ERR] yamux: received unexpected go away")
return fmt.Errorf("unexpected go away received")
// application error code, let the application log
}
return nil
return &ErrorGoAway{ErrorCode: code, Remote: true}
}

// incomingStream is used to create a new incoming stream
Expand Down
52 changes: 44 additions & 8 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package yamux
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -627,28 +628,63 @@ func TestSendData_Large(t *testing.T) {
}
}

func testTCPConns(t *testing.T) (*Session, *Session) {
ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
if err != nil {
t.Fatal(err)
}
serverConnCh := make(chan net.Conn, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
serverConnCh <- conn
}()

clientConn, err := net.DialTCP("tcp", nil, ln.Addr().(*net.TCPAddr))
if err != nil {
ln.Close()
t.Fatal(err)
return nil, nil
}

client, _ := Client(clientConn, testConf(), nil)
server, _ := Server(<-serverConnCh, testConf(), nil)
return client, server

}

func TestGoAway(t *testing.T) {
// This test is noisy.
conf := testConf()
conf.LogOutput = io.Discard

client, server := testClientServerConfig(conf)
client, server := testTCPConns(t)
defer client.Close()
defer server.Close()

if err := server.GoAway(); err != nil {
if err := server.CloseWithError(42); err != nil {
t.Fatalf("err: %v", err)
}

for i := 0; i < 100; i++ {
s, err := client.Open(context.Background())
switch err {
case nil:
s.Close()
case ErrRemoteGoAway:
if err != nil {
if !errors.Is(err, ErrRemoteGoAway) {
t.Fatal("expected error to be ErrRemoteGoAway, got", err)
}
errExpected := &ErrorGoAway{Remote: true, ErrorCode: 42}
errGot, ok := err.(*ErrorGoAway)
if !ok {
t.Fatalf("expected type *ErrorGoAway, got %T", err)
}
if *errGot != *errExpected {
t.Fatalf("invalid error, expected %v, got %v", errExpected, errGot)
}
return
default:
t.Fatalf("err: %v", err)
} else {
s.Close()
}
}
t.Fatalf("expected GoAway error")
Expand Down

0 comments on commit 273d2b4

Please sign in to comment.