Skip to content

Commit

Permalink
fix conn.Close() in function handleConnection (#180)
Browse files Browse the repository at this point in the history
Co-authored-by: husyhu <[email protected]>
  • Loading branch information
Husy and husy-dev authored Mar 29, 2023
1 parent 0daf8bf commit e9f340c
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package broker

import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -203,7 +204,10 @@ func (b *Broker) StartWebsocketListening() {
func (b *Broker) wsHandler(ws *websocket.Conn) {
// io.Copy(ws, ws)
ws.PayloadType = websocket.BinaryFrame
b.handleConnection(CLIENT, ws)
err:=b.handleConnection(CLIENT, ws)
if err!=nil{
ws.Close()
}
}

func (b *Broker) StartClientListening(Tls bool) {
Expand Down Expand Up @@ -254,7 +258,12 @@ func (b *Broker) StartClientListening(Tls bool) {
}

tmpDelay = ACCEPT_MIN_SLEEP
go b.handleConnection(CLIENT, conn)
go func(){
err :=b.handleConnection(CLIENT, conn)
if err!=nil{
conn.Close()
}
}()
}
}

Expand Down Expand Up @@ -291,7 +300,12 @@ func (b *Broker) StartClusterListening() {
}
tmpDelay = ACCEPT_MIN_SLEEP

go b.handleConnection(ROUTER, conn)
go func(){
err :=b.handleConnection(ROUTER, conn)
if err!=nil{
conn.Close()
}
}()
}
}

Expand All @@ -307,22 +321,18 @@ func (b *Broker) DisConnClientByClientId(clientId string) {
conn.Close()
}

func (b *Broker) handleConnection(typ int, conn net.Conn) {
func (b *Broker) handleConnection(typ int, conn net.Conn) error{
//process connect packet
packet, err := packets.ReadPacket(conn)
if err != nil {
log.Error("read connect packet error", zap.Error(err))
conn.Close()
return
return errors.New(fmt.Sprintln("read connect packet error:%v",err))
}
if packet == nil {
log.Error("received nil packet")
return
return errors.New("received nil packet")
}
msg, ok := packet.(*packets.ConnectPacket)
if !ok {
log.Error("received msg that was not Connect")
return
return errors.New("received msg that was not Connect")
}

log.Info("read connect from ", getAdditionalLogFields(msg.ClientIdentifier, conn)...)
Expand All @@ -332,29 +342,22 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) {
connack.ReturnCode = msg.Validate()

if connack.ReturnCode != packets.Accepted {
func() {
defer conn.Close()
if err := connack.Write(conn); err != nil {
log.Error("send connack error", getAdditionalLogFields(msg.ClientIdentifier, conn, zap.Error(err))...)
}
}()
return
if err := connack.Write(conn); err != nil {
return errors.New(fmt.Sprintln("send connack error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn))
}
return errors.New(fmt.Sprintln("connect packet validate failed with connack.ReturnCode%v",connack.ReturnCode))
}

if typ == CLIENT && !b.CheckConnectAuth(msg.ClientIdentifier, msg.Username, string(msg.Password)) {
connack.ReturnCode = packets.ErrRefusedNotAuthorised
func() {
defer conn.Close()
if err := connack.Write(conn); err != nil {
log.Error("send connack error", getAdditionalLogFields(msg.ClientIdentifier, conn, zap.Error(err))...)
}
}()
return
if err := connack.Write(conn); err != nil {
return errors.New(fmt.Sprintln("send connack error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn))
}
return errors.New(fmt.Sprintln("connect packet CheckConnectAuth failed with connack.ReturnCode%v",connack.ReturnCode))
}

if err := connack.Write(conn); err != nil {
log.Error("send connack error", getAdditionalLogFields(msg.ClientIdentifier, conn, zap.Error(err))...)
return
return errors.New(fmt.Sprintln("send connack error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn))
}

willmsg := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
Expand Down Expand Up @@ -385,8 +388,7 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) {
c.init()

if err := b.getSession(c, msg, connack); err != nil {
log.Error("get session error", getAdditionalLogFields(c.info.clientID, conn, zap.Error(err))...)
return
return errors.New(fmt.Sprintln("get session error:%v,clientID:%v,conn:%v",err,msg.ClientIdentifier,conn))
}

cid := c.info.clientID
Expand Down Expand Up @@ -426,6 +428,7 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) {
}

c.readLoop()
return nil
}

func (b *Broker) ConnectToDiscovery() {
Expand Down

0 comments on commit e9f340c

Please sign in to comment.