diff --git a/const.go b/const.go index 3ecba41..08199af 100644 --- a/const.go +++ b/const.go @@ -45,7 +45,7 @@ func (e *GoAwayError) Temporary() bool { func (e *GoAwayError) Is(target error) bool { // to maintain compatibility with errors returned by previous versions - if e.Remote && target == ErrRemoteGoAwayNormal { + if e.Remote && target == ErrRemoteGoAway { return true } else if !e.Remote && target == ErrSessionShutdown { return true @@ -114,8 +114,9 @@ var ( // ErrUnexpectedFlag is set when we get an unexpected flag ErrUnexpectedFlag = &Error{msg: "unexpected flag"} - // ErrRemoteGoAwayNormal is used when we get a go away from the other side - ErrRemoteGoAwayNormal = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} + // ErrRemoteGoAway is used when we get a go away from the other side with error code + // goAwayNormal(0). + ErrRemoteGoAway = &GoAwayError{Remote: true, ErrorCode: goAwayNormal} // ErrStreamReset is sent if a stream is reset. This can happen // if the backlog is exceeded, or if there was a remote GoAway. diff --git a/session.go b/session.go index 06d20fa..6fb6731 100644 --- a/session.go +++ b/session.go @@ -46,10 +46,6 @@ var nullMemoryManager = &nullMemoryManagerImpl{} type Session struct { rtt int64 // to be accessed atomically, in nanoseconds - // remoteGoAwayNormal indicates the remote side does - // not want futher connections. Must be first for alignment. - remoteGoAwayNormal int32 - // localGoAway indicates that we should stop // accepting futher connections. Must be first for alignment. localGoAway int32 @@ -205,9 +201,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { if s.IsClosed() { return nil, s.shutdownErr } - if atomic.LoadInt32(&s.remoteGoAwayNormal) == 1 { - return nil, ErrRemoteGoAwayNormal - } // Block if we have too many inflight SYNs select { @@ -535,8 +528,14 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err // send is a long running goroutine that sends data func (s *Session) send() { if err := s.sendLoop(); err != nil { - // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code - // received in a GoAway frame received just before the TCP RST that closed the sendLoop + // If we are shutting down because remote closed the connection, prefer the recvLoop error + // over the sendLoop error. The receive loop might have error code received in a GoAway frame, + // which was received just before the TCP RST that closed the sendLoop. + // + // If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop. + // We hold the shutdownLock, close the connection, and wait for the receive loop to finish and + // use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close + // but the sendLoop does. s.shutdownLock.Lock() if s.shutdownErr == nil { s.conn.Close() @@ -815,10 +814,7 @@ func (s *Session) handleGoAway(hdr header) error { code := hdr.Length() switch code { case goAwayNormal: - atomic.SwapInt32(&s.remoteGoAwayNormal, 1) - // Don't close connection on normal go away. Let the existing streams - // complete gracefully. - return nil + return ErrRemoteGoAway case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") case goAwayInternalErr: diff --git a/session_test.go b/session_test.go index dc6c3f0..6d3bce0 100644 --- a/session_test.go +++ b/session_test.go @@ -648,15 +648,16 @@ func TestGoAway(t *testing.T) { for i := 0; i < 100; i++ { s, err := client.Open(context.Background()) - switch err { - case nil: + if err == nil { s.Close() - case ErrRemoteGoAwayNormal: + time.Sleep(50 * time.Millisecond) + continue + } + if err != ErrRemoteGoAway { + t.Fatalf("expected %s, got %s", ErrRemoteGoAway, err) + } else { return - default: - t.Fatalf("err: %v", err) } - time.Sleep(50 * time.Millisecond) } t.Fatalf("expected GoAway error") } @@ -1578,7 +1579,7 @@ func TestStreamResetWithError(t *testing.T) { defer server.Close() wc := new(sync.WaitGroup) - wc.Add(2) + wc.Add(1) go func() { defer wc.Done() stream, err := server.AcceptStream() @@ -1589,7 +1590,7 @@ func TestStreamResetWithError(t *testing.T) { se := &StreamError{} _, err = io.ReadAll(stream) if !errors.As(err, &se) { - t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) + t.Errorf("expected StreamError, got type:%T, err: %s", err, err) return } expected := &StreamError{Remote: true, ErrorCode: 42} @@ -1601,24 +1602,19 @@ func TestStreamResetWithError(t *testing.T) { t.Error(err) } - go func() { - defer wc.Done() - - se := &StreamError{} - _, err := io.ReadAll(stream) - if !errors.As(err, &se) { - t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) - return - } - expected := &StreamError{Remote: false, ErrorCode: 42} - assert.Equal(t, se, expected) - }() - time.Sleep(1 * time.Second) err = stream.ResetWithError(42) if err != nil { t.Fatal(err) } + se := &StreamError{} + _, err = io.ReadAll(stream) + if !errors.As(err, &se) { + t.Errorf("expected StreamError, got type:%T, err: %s", err, err) + return + } + expected := &StreamError{Remote: false, ErrorCode: 42} + assert.Equal(t, se, expected) wc.Wait() } diff --git a/stream.go b/stream.go index 0835165..15a8b56 100644 --- a/stream.go +++ b/stream.go @@ -395,7 +395,7 @@ func (s *Stream) cleanup() { // processFlags is used to update the state of the stream // based on set flags, if any. Lock must be held -func (s *Stream) processFlags(flags uint16, hdr header) { +func (s *Stream) processFlags(hdr header, flags uint16) { // Close the stream without holding the state lock var closeStream bool defer func() { @@ -459,7 +459,7 @@ func (s *Stream) notifyWaiting() { // incrSendWindow updates the size of our send window func (s *Stream) incrSendWindow(hdr header, flags uint16) { - s.processFlags(flags, hdr) + s.processFlags(hdr, flags) // Increase window, unblock a sender atomic.AddUint32(&s.sendWindow, hdr.Length()) asyncNotify(s.sendNotifyCh) @@ -467,7 +467,7 @@ func (s *Stream) incrSendWindow(hdr header, flags uint16) { // readData is used to handle a data frame func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { - s.processFlags(flags, hdr) + s.processFlags(hdr, flags) // Check that our recv window is not exceeded length := hdr.Length()