From e9f340c38f3b762fc04bc2990dec14e73f8c3b86 Mon Sep 17 00:00:00 2001 From: Husy <350980310@qq.com> Date: Wed, 29 Mar 2023 09:45:12 +0800 Subject: [PATCH] fix conn.Close() in function handleConnection (#180) Co-authored-by: husyhu --- broker/broker.go | 61 +++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/broker/broker.go b/broker/broker.go index 53246af0..a32dff52 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -2,6 +2,7 @@ package broker import ( "crypto/tls" + "errors" "fmt" "net" "net/http" @@ -203,7 +204,10 @@ func (b *Broker) StartWebsocketListening() { func (b *Broker) wsHandler(ws *websocket.Conn) { // io.Copy(ws, ws) ws.PayloadType = websocket.BinaryFrame - b.handleConnection(CLIENT, ws) + err:=b.handleConnection(CLIENT, ws) + if err!=nil{ + ws.Close() + } } func (b *Broker) StartClientListening(Tls bool) { @@ -254,7 +258,12 @@ func (b *Broker) StartClientListening(Tls bool) { } tmpDelay = ACCEPT_MIN_SLEEP - go b.handleConnection(CLIENT, conn) + go func(){ + err :=b.handleConnection(CLIENT, conn) + if err!=nil{ + conn.Close() + } + }() } } @@ -291,7 +300,12 @@ func (b *Broker) StartClusterListening() { } tmpDelay = ACCEPT_MIN_SLEEP - go b.handleConnection(ROUTER, conn) + go func(){ + err :=b.handleConnection(ROUTER, conn) + if err!=nil{ + conn.Close() + } + }() } } @@ -307,22 +321,18 @@ func (b *Broker) DisConnClientByClientId(clientId string) { conn.Close() } -func (b *Broker) handleConnection(typ int, conn net.Conn) { +func (b *Broker) handleConnection(typ int, conn net.Conn) error{ //process connect packet packet, err := packets.ReadPacket(conn) if err != nil { - log.Error("read connect packet error", zap.Error(err)) - conn.Close() - return + return errors.New(fmt.Sprintln("read connect packet error:%v",err)) } if packet == nil { - log.Error("received nil packet") - return + return errors.New("received nil packet") } msg, ok := packet.(*packets.ConnectPacket) if !ok { - log.Error("received msg that was not Connect") - return + return errors.New("received msg that was not Connect") } log.Info("read connect from ", getAdditionalLogFields(msg.ClientIdentifier, conn)...) @@ -332,29 +342,22 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) { connack.ReturnCode = msg.Validate() if connack.ReturnCode != packets.Accepted { - func() { - defer conn.Close() - if err := connack.Write(conn); err != nil { - log.Error("send connack error", getAdditionalLogFields(msg.ClientIdentifier, conn, zap.Error(err))...) - } - }() - return + if err := connack.Write(conn); err != nil { + return errors.New(fmt.Sprintln("send connack error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn)) + } + return errors.New(fmt.Sprintln("connect packet validate failed with connack.ReturnCode%v",connack.ReturnCode)) } if typ == CLIENT && !b.CheckConnectAuth(msg.ClientIdentifier, msg.Username, string(msg.Password)) { connack.ReturnCode = packets.ErrRefusedNotAuthorised - func() { - defer conn.Close() - if err := connack.Write(conn); err != nil { - log.Error("send connack error", getAdditionalLogFields(msg.ClientIdentifier, conn, zap.Error(err))...) - } - }() - return + if err := connack.Write(conn); err != nil { + return errors.New(fmt.Sprintln("send connack error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn)) + } + return errors.New(fmt.Sprintln("connect packet CheckConnectAuth failed with connack.ReturnCode%v",connack.ReturnCode)) } if err := connack.Write(conn); err != nil { - log.Error("send connack error", getAdditionalLogFields(msg.ClientIdentifier, conn, zap.Error(err))...) - return + return errors.New(fmt.Sprintln("send connack error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn)) } willmsg := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) @@ -385,8 +388,7 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) { c.init() if err := b.getSession(c, msg, connack); err != nil { - log.Error("get session error", getAdditionalLogFields(c.info.clientID, conn, zap.Error(err))...) - return + return errors.New(fmt.Sprintln("get session error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn)) } cid := c.info.clientID @@ -426,6 +428,7 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) { } c.readLoop() + return nil } func (b *Broker) ConnectToDiscovery() {