Skip to content

Commit

Permalink
Fix unordered state handler call on reconnect failure (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Apr 19, 2023
1 parent 0d2d878 commit e787a90
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 14 deletions.
9 changes: 6 additions & 3 deletions internal/filteredpipe/close.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@ import (
func DetectAndClosePipe(h0, h1 func([]byte) bool) (io.ReadWriteCloser, io.ReadWriteCloser) {
ch0 := make(chan []byte, 1000)
ch1 := make(chan []byte, 1000)
closed, fnClose := mewCloseCh()
return &detectAndCloseConn{
baseFilterConn: &baseFilterConn{
rCh: ch0,
wCh: ch1,
handler: h0,
closed: make(chan struct{}),
closed: closed,
fnClose: fnClose,
},
}, &detectAndCloseConn{
baseFilterConn: &baseFilterConn{
rCh: ch1,
wCh: ch0,
handler: h1,
closed: make(chan struct{}),
closed: closed,
fnClose: fnClose,
},
}
}
Expand All @@ -46,7 +49,7 @@ type detectAndCloseConn struct {

func (c *detectAndCloseConn) Write(data []byte) (n int, err error) {
if c.handler(data) {
c.closeOnce.Do(func() { close(c.closed) })
c.fnClose()
return 0, io.ErrClosedPipe
}
select {
Expand Down
7 changes: 5 additions & 2 deletions internal/filteredpipe/drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,23 @@ import (
func DetectAndDropPipe(h0, h1 func([]byte) bool) (io.ReadWriteCloser, io.ReadWriteCloser) {
ch0 := make(chan []byte, 1000)
ch1 := make(chan []byte, 1000)
closed, fnClose := mewCloseCh()
return &detectAndDropConn{
baseFilterConn: &baseFilterConn{
rCh: ch0,
wCh: ch1,
handler: h0,
closed: make(chan struct{}),
closed: closed,
fnClose: fnClose,
},
dropping: make(chan struct{}),
}, &detectAndDropConn{
baseFilterConn: &baseFilterConn{
rCh: ch1,
wCh: ch0,
handler: h1,
closed: make(chan struct{}),
closed: closed,
fnClose: fnClose,
},
dropping: make(chan struct{}),
}
Expand Down
15 changes: 13 additions & 2 deletions internal/filteredpipe/pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ type baseFilterConn struct {
rCh chan []byte
wCh chan []byte
handler func([]byte) bool
closed chan struct{}
closed <-chan struct{}
fnClose func()
closeOnce sync.Once
remain io.Reader
}
Expand Down Expand Up @@ -77,6 +78,16 @@ func (c *baseFilterConn) Read(data []byte) (n int, err error) {
}

func (c *baseFilterConn) Close() error {
c.closeOnce.Do(func() { close(c.closed) })
c.fnClose()
return nil
}

func mewCloseCh() (<-chan struct{}, func()) {
ch := make(chan struct{})
var once sync.Once
return ch, func() {
once.Do(func() {
close(ch)
})
}
}
20 changes: 13 additions & 7 deletions reconnclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,7 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ...
if baseCli, err := c.dialer.DialContext(ctx); err == nil {
c.RetryClient.SetClient(ctx, baseCli)

var ctxConnect context.Context
var cancelConnect func()
if c.options.Timeout == 0 {
ctxConnect, cancelConnect = ctx, func() {}
} else {
ctxConnect, cancelConnect = context.WithTimeout(ctx, c.options.Timeout)
}
ctxConnect, cancelConnect := c.options.timeoutContext(ctx)

if sessionPresent, err := c.RetryClient.Connect(ctxConnect, clientID, opts...); err == nil {
cancelConnect()
Expand Down Expand Up @@ -147,6 +141,11 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ...
errConnect.Store(err) // Hold first connect error excepting context cancel.
}
cancelConnect()

// Close baseCli to avoid unordered state callback
baseCli.Close()
// baseCli.Done() should be returned immediately if no incoming message callback is not blocked
<-baseCli.Done()
} else if err != ctx.Err() {
errDial.Store(err) // Hold first dial error excepting context cancel.
}
Expand Down Expand Up @@ -206,6 +205,13 @@ type ReconnectOptions struct {
AlwaysResubscribe bool
}

func (c *ReconnectOptions) timeoutContext(ctx context.Context) (context.Context, func()) {
if c.Timeout == 0 {
return ctx, func() {}
}
return context.WithTimeout(ctx, c.Timeout)
}

// ReconnectOption sets option for Connect.
type ReconnectOption func(*ReconnectOptions) error

Expand Down
95 changes: 95 additions & 0 deletions reconnclient_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,3 +952,98 @@ func TestIntegration_ReconnectClient_RepeatedDisconnect(t *testing.T) {
})
}
}

func TestIntegration_ReconnectClient_WithConnStateHandler(t *testing.T) {
for name, url := range urls {
url := url
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

chState := make(chan ConnState, 1)
var dialCnt int32

cli, err := NewReconnectClient(
DialerFunc(func(ctx context.Context) (*BaseClient, error) {
cli, err := DialContext(ctx, url,
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
WithConnStateHandler(func(state ConnState, err error) {
chState <- state
}),
)
if err != nil {
return nil, err
}
cnt := atomic.AddInt32(&dialCnt, 1)
ca, cb := filteredpipe.DetectAndClosePipe(
newFilterBase(func(msg []byte) bool {
if cnt == 2 && msg[0]&0xf0 == 0x20 {
time.Sleep(150 * time.Millisecond)
return true
}
return false
}),
newFilterBase(func(msg []byte) bool {
if cnt == 1 && msg[0]&0xf0 == 0x30 {
return true
}
return false
}),
)
filteredpipe.Connect(ca, cli.Transport)
cli.Transport = cb
return cli, nil
}),
WithRetryClient(&RetryClient{
ResponseTimeout: 100 * time.Millisecond,
}),
WithPingInterval(time.Second),
WithTimeout(100*time.Millisecond),
WithReconnectWait(10*time.Millisecond, 10*time.Millisecond),
)
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
if _, err = cli.Connect(
ctx,
"ReconnectClientErrDuringReconnect"+name,
WithKeepAlive(10),
WithCleanSession(true),
); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if err := cli.Publish(ctx, &Message{
Topic: "error_during_reconnect",
QoS: QoS1,
Payload: []byte{},
}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

assertStateChange := func(expected ConnState) {
select {
case <-ctx.Done():
t.Error("Timeout")
case s := <-chState:
if s != expected {
t.Errorf("Expected %s, got %s", expected, s)
}
}
}
assertStateChange(StateActive)
assertStateChange(StateClosed)
assertStateChange(StateClosed)
assertStateChange(StateActive)

select {
case <-time.After(300 * time.Millisecond):
case s := <-chState:
t.Errorf("Unexpected state change to %s", s)
}

cli.Disconnect(ctx)
assertStateChange(StateDisconnected)
})
}
}

0 comments on commit e787a90

Please sign in to comment.