diff --git a/Makefile b/Makefile index 618db45ce..7f4fb6346 100644 --- a/Makefile +++ b/Makefile @@ -48,8 +48,15 @@ else TAGPARAM=--tags $(TAGS) endif +DEBUG ?= +ifeq ($(DEBUG),1) + DEBUGFLAGS=-gcflags=all="-N -l" +else + DEBUGFLAGS= +endif + receptor: $(shell find pkg -type f -name '*.go') ./cmd/receptor-cl/receptor.go - CGO_ENABLED=0 go build -o receptor -ldflags "-X 'github.com/ansible/receptor/internal/version.Version=$(APPVER)'" $(TAGPARAM) ./cmd/receptor-cl + CGO_ENABLED=0 go build -o receptor $(DEBUGFLAGS) -ldflags "-X 'github.com/ansible/receptor/internal/version.Version=$(APPVER)'" $(TAGPARAM) ./cmd/receptor-cl lint: @golint cmd/... pkg/... example/... diff --git a/pkg/controlsvc/connect.go b/pkg/controlsvc/connect.go index 6c3884880..a1b255e8a 100644 --- a/pkg/controlsvc/connect.go +++ b/pkg/controlsvc/connect.go @@ -1,6 +1,7 @@ package controlsvc import ( + "context" "fmt" "strings" @@ -73,7 +74,7 @@ func (t *connectCommandType) InitFromJSON(config map[string]interface{}) (Contro return c, nil } -func (c *connectCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { +func (c *connectCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { tlscfg, err := nc.GetClientTLSConfig(c.tlsConfigName, c.targetNode, "receptor") if err != nil { return nil, err diff --git a/pkg/controlsvc/controlsvc.go b/pkg/controlsvc/controlsvc.go index c950b09dc..f73a2a9d1 100644 --- a/pkg/controlsvc/controlsvc.go +++ b/pkg/controlsvc/controlsvc.go @@ -125,12 +125,14 @@ func (s *Server) AddControlFunc(name string, cType ControlCommandType) error { // RunControlSession runs the server protocol on the given connection. func (s *Server) RunControlSession(conn net.Conn) { - logger.Info("Client connected to control service\n") + logger.Info("Client connected to control service %s\n", conn.RemoteAddr().String()) defer func() { - logger.Info("Client disconnected from control service\n") - err := conn.Close() - if err != nil { - logger.Error("Error closing connection: %s\n", err) + logger.Info("Client disconnected from control service %s\n", conn.RemoteAddr().String()) + if conn != nil { + err := conn.Close() + if err != nil { + logger.Error("Error closing connection: %s\n", err) + } } }() _, err := conn.Write([]byte(fmt.Sprintf("Receptor Control, node %s\n", s.nc.NodeID()))) @@ -224,7 +226,9 @@ func (s *Server) RunControlSession(conn net.Conn) { cc, err = ct.InitFromJSON(jsonData) } if err == nil { - cfr, err = cc.ControlFunc(s.nc, cfo) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfr, err = cc.ControlFunc(ctx, s.nc, cfo) } if err != nil { logger.Error(err.Error()) @@ -331,6 +335,11 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls. if ctx.Err() != nil { return } + if err != nil { + if strings.HasSuffix(err.Error(), "normal close") { + continue + } + } if err != nil { logger.Error("Error accepting connection: %s. Closing listener.\n", err) _ = listener.Close() @@ -338,27 +347,25 @@ func (s *Server) RunControlSvc(ctx context.Context, service string, tlscfg *tls. return } go func() { + defer conn.Close() tlsConn, ok := conn.(*tls.Conn) if ok { // Explicitly run server TLS handshake so we can deal with timeout and errors here err = conn.SetDeadline(time.Now().Add(10 * time.Second)) if err != nil { logger.Error("Error setting timeout: %s. Closing socket.\n", err) - _ = conn.Close() return } err = tlsConn.Handshake() if err != nil { logger.Error("TLS handshake error: %s. Closing socket.\n", err) - _ = conn.Close() return } err = conn.SetDeadline(time.Time{}) if err != nil { logger.Error("Error clearing timeout: %s. Closing socket.\n", err) - _ = conn.Close() return } diff --git a/pkg/controlsvc/interfaces.go b/pkg/controlsvc/interfaces.go index 3ae181bd2..9975a9a72 100644 --- a/pkg/controlsvc/interfaces.go +++ b/pkg/controlsvc/interfaces.go @@ -1,6 +1,7 @@ package controlsvc import ( + "context" "io" "net" @@ -15,7 +16,7 @@ type ControlCommandType interface { // ControlCommand is an instance of a command that is being run from the control service. type ControlCommand interface { - ControlFunc(*netceptor.Netceptor, ControlFuncOperations) (map[string]interface{}, error) + ControlFunc(context.Context, *netceptor.Netceptor, ControlFuncOperations) (map[string]interface{}, error) } // ControlFuncOperations provides callbacks for control services to take actions. diff --git a/pkg/controlsvc/ping.go b/pkg/controlsvc/ping.go index 36d25f507..bac443202 100644 --- a/pkg/controlsvc/ping.go +++ b/pkg/controlsvc/ping.go @@ -55,17 +55,16 @@ func ping(nc *netceptor.Netceptor, target string, hopsToLive byte) (time.Duratio _ = pc.Close() }() pc.SetHopsToLive(hopsToLive) - unrCh := pc.SubscribeUnreachable() + doneChan := make(chan struct{}) + unrCh := pc.SubscribeUnreachable(doneChan) + defer close(doneChan) type errorResult struct { err error fromNode string } errorChan := make(chan errorResult) go func() { - select { - case <-ctx.Done(): - return - case msg := <-unrCh: + for msg := range unrCh { errorChan <- errorResult{ err: fmt.Errorf(msg.Problem), fromNode: msg.ReceivedFromNode, @@ -111,7 +110,7 @@ func ping(nc *netceptor.Netceptor, target string, hopsToLive byte) (time.Duratio } } -func (c *pingCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { +func (c *pingCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { pingTime, pingRemote, err := ping(nc, c.target, nc.MaxForwardingHops()) cfr := make(map[string]interface{}) if err == nil { diff --git a/pkg/controlsvc/reload.go b/pkg/controlsvc/reload.go index 690e0cc6a..a7b233c4e 100644 --- a/pkg/controlsvc/reload.go +++ b/pkg/controlsvc/reload.go @@ -1,6 +1,7 @@ package controlsvc import ( + "context" "fmt" "io/ioutil" "strings" @@ -157,7 +158,7 @@ func handleError(err error, errorcode int) (map[string]interface{}, error) { return cfr, nil } -func (c *reloadCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { +func (c *reloadCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { // Reload command stops all backends, and re-runs the ParseAndRun() on the // initial config file logger.Debug("Reloading") diff --git a/pkg/controlsvc/status.go b/pkg/controlsvc/status.go index 86741c010..271c033d3 100644 --- a/pkg/controlsvc/status.go +++ b/pkg/controlsvc/status.go @@ -1,6 +1,7 @@ package controlsvc import ( + "context" "fmt" "github.com/ansible/receptor/internal/version" @@ -46,7 +47,7 @@ func (t *statusCommandType) InitFromJSON(config map[string]interface{}) (Control return c, nil } -func (c *statusCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { +func (c *statusCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { status := nc.Status() statusGetters := make(map[string]func() interface{}) statusGetters["Version"] = func() interface{} { return version.Version } diff --git a/pkg/controlsvc/traceroute.go b/pkg/controlsvc/traceroute.go index 3c46300e6..183087ee8 100644 --- a/pkg/controlsvc/traceroute.go +++ b/pkg/controlsvc/traceroute.go @@ -1,6 +1,7 @@ package controlsvc import ( + "context" "fmt" "strconv" @@ -41,7 +42,7 @@ func (t *tracerouteCommandType) InitFromJSON(config map[string]interface{}) (Con return c, nil } -func (c *tracerouteCommand) ControlFunc(nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { +func (c *tracerouteCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo ControlFuncOperations) (map[string]interface{}, error) { cfr := make(map[string]interface{}) for i := 0; i <= int(nc.MaxForwardingHops()); i++ { thisResult := make(map[string]interface{}) diff --git a/pkg/netceptor/conn.go b/pkg/netceptor/conn.go index 59f04fef4..05fdb29c6 100644 --- a/pkg/netceptor/conn.go +++ b/pkg/netceptor/conn.go @@ -16,7 +16,7 @@ import ( "sync" "time" - "github.com/ansible/receptor/pkg/utils" + "github.com/ansible/receptor/pkg/logger" "github.com/lucas-clemente/quic-go" ) @@ -197,7 +197,7 @@ func (li *Listener) acceptLoop() { return } doneChan := make(chan struct{}, 1) - cctx, ccancel := utils.ContextWithCancelWithErr(li.s.context) + cctx, ccancel := context.WithCancel(li.s.context) conn := &Conn{ s: li.s, pc: li.pc, @@ -296,7 +296,7 @@ func (s *Netceptor) DialContext(ctx context.Context, node string, service string _ = pc.Close() }) } - cctx, ccancel := utils.ContextWithCancelWithErr(ctx) + cctx, ccancel := context.WithCancel(ctx) go func() { select { case <-okChan: @@ -370,18 +370,18 @@ func (s *Netceptor) DialContext(ctx context.Context, node string, service string // monitorUnreachable receives unreachable messages from the underlying PacketConn, and ends the connection // if the remote service has gone away. -func monitorUnreachable(pc *PacketConn, doneChan chan struct{}, remoteAddr Addr, cancel utils.CancelWithErrFunc) { - msgCh := pc.SubscribeUnreachable() - for { - select { - case <-pc.context.Done(): - return - case <-doneChan: - return - case msg := <-msgCh: - if msg.Problem == ProblemServiceUnknown && msg.ToNode == remoteAddr.node && msg.ToService == remoteAddr.service { - cancel(fmt.Errorf("remote service unreachable")) - } +func monitorUnreachable(pc *PacketConn, doneChan chan struct{}, remoteAddr Addr, cancel context.CancelFunc) { + msgCh := pc.SubscribeUnreachable(doneChan) + if msgCh == nil { + cancel() + + return + } + // read from channel until closed + for msg := range msgCh { + if msg.Problem == ProblemServiceUnknown && msg.ToNode == remoteAddr.node && msg.ToService == remoteAddr.service { + logger.Error("remote service unreachable") + cancel() } } } @@ -410,6 +410,16 @@ func (c *Conn) Close() error { return c.qs.Close() } +func (c *Conn) CloseConnection() error { + c.pc.cancel() + c.doneOnce.Do(func() { + close(c.doneChan) + }) + logger.Debug("closing connection from service %s to %s", c.pc.localService, c.RemoteAddr().String()) + + return c.qc.CloseWithError(0, "normal close") +} + // LocalAddr returns the local address of this connection. func (c *Conn) LocalAddr() net.Addr { return c.qc.LocalAddr() diff --git a/pkg/netceptor/netceptor.go b/pkg/netceptor/netceptor.go index 61e45d1d9..5e70a9c59 100644 --- a/pkg/netceptor/netceptor.go +++ b/pkg/netceptor/netceptor.go @@ -1622,7 +1622,13 @@ func (s *Netceptor) sendInitialConnectMessage(ci *connInfo, initDoneChan chan bo return } logger.Debug("Sending initial connection message\n") - ci.WriteChan <- ri + select { + case ci.WriteChan <- ri: + case <-ci.Context.Done(): + return + case <-initDoneChan: + return + } count++ if count > 10 { logger.Warning("Giving up on connection initialization\n") @@ -1643,15 +1649,18 @@ func (s *Netceptor) sendInitialConnectMessage(ci *connInfo, initDoneChan chan bo } } -func (s *Netceptor) sendRejectMessage(writeChan chan []byte) { +func (s *Netceptor) sendRejectMessage(ci *connInfo) { rejMsg, err := s.translateStructToNetwork(MsgTypeReject, make([]string, 0)) if err != nil { - writeChan <- rejMsg + select { + case <-ci.Context.Done(): + case ci.WriteChan <- rejMsg: + } } } func (s *Netceptor) sendAndLogConnectionRejection(remoteNodeID string, ci *connInfo, reason string) error { - s.sendRejectMessage(ci.WriteChan) + s.sendRejectMessage(ci) return fmt.Errorf("rejected connection with node %s because %s", remoteNodeID, reason) } diff --git a/pkg/netceptor/netceptor_test.go b/pkg/netceptor/netceptor_test.go index 583bd334b..01db077bc 100644 --- a/pkg/netceptor/netceptor_test.go +++ b/pkg/netceptor/netceptor_test.go @@ -556,18 +556,19 @@ func TestFirewalling(t *testing.T) { } // Subscribe for unreachable messages - unreach2chan := pc2.SubscribeUnreachable() + doneChan := make(chan struct{}) + unreach2chan := pc2.SubscribeUnreachable(doneChan) // Save received unreachable messages to a variable var lastUnreachMsg *UnreachableNotification go func() { - for { - select { - case <-timeout.Done(): - return - case unreach := <-unreach2chan: - lastUnreachMsg = &unreach - } + <-timeout.Done() + close(doneChan) + }() + go func() { + for unreach := range unreach2chan { + unreach := unreach + lastUnreachMsg = &unreach } }() @@ -715,18 +716,19 @@ func TestAllowedPeers(t *testing.T) { } // Subscribe for unreachable messages - unreach2chan := pc2.SubscribeUnreachable() + doneChan := make(chan struct{}) + unreach2chan := pc2.SubscribeUnreachable(doneChan) // Save received unreachable messages to a variable var lastUnreachMsg *UnreachableNotification go func() { - for { - select { - case <-timeout.Done(): - return - case unreach := <-unreach2chan: - lastUnreachMsg = &unreach - } + <-timeout.Done() + close(doneChan) + }() + go func() { + for unreach := range unreach2chan { + unreach := unreach + lastUnreachMsg = &unreach } }() diff --git a/pkg/netceptor/packetconn.go b/pkg/netceptor/packetconn.go index eb371a148..08722d9de 100644 --- a/pkg/netceptor/packetconn.go +++ b/pkg/netceptor/packetconn.go @@ -12,18 +12,17 @@ import ( // PacketConn implements the net.PacketConn interface via the Receptor network. type PacketConn struct { - s *Netceptor - localService string - recvChan chan *MessageData - readDeadline time.Time - advertise bool - adTags map[string]string - connType byte - hopsToLive byte - unreachableMsgChan chan interface{} - unreachableSubs *utils.Broker - context context.Context - cancel context.CancelFunc + s *Netceptor + localService string + recvChan chan *MessageData + readDeadline time.Time + advertise bool + adTags map[string]string + connType byte + hopsToLive byte + unreachableSubs *utils.Broker + context context.Context + cancel context.CancelFunc } // ListenPacket returns a datagram connection compatible with Go's net.PacketConn. @@ -76,50 +75,59 @@ func (s *Netceptor) ListenPacketAndAdvertise(service string, tags map[string]str func (pc *PacketConn) startUnreachable() { pc.context, pc.cancel = context.WithCancel(pc.s.context) pc.unreachableSubs = utils.NewBroker(pc.context, reflect.TypeOf(UnreachableNotification{})) - pc.unreachableMsgChan = pc.s.unreachableBroker.Subscribe() + iChan := pc.s.unreachableBroker.Subscribe() go func() { - for { - select { - case <-pc.context.Done(): - return - case msgIf := <-pc.unreachableMsgChan: - msg, ok := msgIf.(UnreachableNotification) - if !ok { - continue - } - FromNode := msg.FromNode - FromService := msg.FromService - if FromNode == pc.s.nodeID && FromService == pc.localService { - _ = pc.unreachableSubs.Publish(msg) - } + <-pc.context.Done() + pc.s.unreachableBroker.Unsubscribe(iChan) + }() + go func() { + for msgIf := range iChan { + msg, ok := msgIf.(UnreachableNotification) + if !ok { + continue + } + FromNode := msg.FromNode + FromService := msg.FromService + if FromNode == pc.s.nodeID && FromService == pc.localService { + _ = pc.unreachableSubs.Publish(msg) } } }() } // SubscribeUnreachable subscribes for unreachable messages relevant to this PacketConn. -func (pc *PacketConn) SubscribeUnreachable() chan UnreachableNotification { +func (pc *PacketConn) SubscribeUnreachable(doneChan chan struct{}) chan UnreachableNotification { iChan := pc.unreachableSubs.Subscribe() + if iChan == nil { + return nil + } uChan := make(chan UnreachableNotification) + // goroutine 1 + // if doneChan is selected, this will unsubscribe the channel, which should + // eventually close out the go routine 2 + go func() { + select { + case <-doneChan: + pc.unreachableSubs.Unsubscribe(iChan) + case <-pc.context.Done(): + } + }() + // goroutine 2 + // this will exit when either the broker closes iChan, or the broker + // returns via pc.context.Done() go func() { for { - select { - case msgIf, ok := <-iChan: - if !ok { - close(uChan) - - return - } - msg, ok := msgIf.(UnreachableNotification) - if !ok { - continue - } - uChan <- msg - case <-pc.context.Done(): + msgIf, ok := <-iChan + if !ok { close(uChan) return } + msg, ok := msgIf.(UnreachableNotification) + if !ok { + continue + } + uChan <- msg } }() @@ -130,7 +138,11 @@ func (pc *PacketConn) SubscribeUnreachable() chan UnreachableNotification { func (pc *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { var m *MessageData if pc.readDeadline.IsZero() { - m = <-pc.recvChan + select { + case m = <-pc.recvChan: + case <-pc.context.Done(): + return 0, nil, fmt.Errorf("connection context closed") + } } else { select { case m = <-pc.recvChan: diff --git a/pkg/utils/broker.go b/pkg/utils/broker.go index b4bf6630f..bae858ae9 100644 --- a/pkg/utils/broker.go +++ b/pkg/utils/broker.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "sync" ) // Broker code adapted from https://stackoverflow.com/questions/36417199/how-to-broadcast-message-using-channel @@ -23,7 +24,7 @@ func NewBroker(ctx context.Context, msgType reflect.Type) *Broker { b := &Broker{ ctx: ctx, msgType: msgType, - publishCh: make(chan interface{}, 1), + publishCh: make(chan interface{}), subCh: make(chan chan interface{}), unsubCh: make(chan chan interface{}), } @@ -47,40 +48,41 @@ func (b *Broker) start() { subs[msgCh] = struct{}{} case msgCh := <-b.unsubCh: delete(subs, msgCh) + close(msgCh) case msg := <-b.publishCh: + wg := sync.WaitGroup{} for msgCh := range subs { + wg.Add(1) go func(msgCh chan interface{}) { + defer wg.Done() select { case msgCh <- msg: case <-b.ctx.Done(): } }(msgCh) } + wg.Wait() } } } // Subscribe registers to receive messages from the broker. func (b *Broker) Subscribe() chan interface{} { - if b == nil || b.ctx == nil { - fmt.Printf("foo\n") - } - if b.ctx.Err() == nil { - msgCh := make(chan interface{}, 1) - b.subCh <- msgCh - + msgCh := make(chan interface{}) + select { + case <-b.ctx.Done(): + return nil + case b.subCh <- msgCh: return msgCh } - - return nil } // Unsubscribe de-registers a message receiver. func (b *Broker) Unsubscribe(msgCh chan interface{}) { - if b.ctx.Err() == nil { - b.unsubCh <- msgCh + select { + case <-b.ctx.Done(): + case b.unsubCh <- msgCh: } - close(msgCh) } // Publish sends a message to all subscribers. @@ -88,8 +90,9 @@ func (b *Broker) Publish(msg interface{}) error { if reflect.TypeOf(msg) != b.msgType { return fmt.Errorf("messages to broker must be of type %s", b.msgType.String()) } - if b.ctx.Err() == nil { - b.publishCh <- msg + select { + case <-b.ctx.Done(): + case b.publishCh <- msg: } return nil diff --git a/pkg/utils/cancel_with_err_context.go b/pkg/utils/cancel_with_err_context.go deleted file mode 100644 index 6259db260..000000000 --- a/pkg/utils/cancel_with_err_context.go +++ /dev/null @@ -1,82 +0,0 @@ -package utils - -import ( - "context" - "sync" - "time" -) - -// CancelWithErrFunc is like a regular context.CancelFunc, but you can specify an error to return. -type CancelWithErrFunc func(err error) - -// CancelWithErrContext is a context that can be cancelled with a specific error return. -type CancelWithErrContext struct { - parentCtx context.Context - errChan chan error - doneChan chan struct{} - closeOnce sync.Once - err error -} - -// ContextWithCancelWithErr returns a context and a CancelWithErrFunc. This functions like a normal -// context cancel function, except you can specify what error should be returned. -func ContextWithCancelWithErr(parent context.Context) (*CancelWithErrContext, CancelWithErrFunc) { - cwe := &CancelWithErrContext{ - parentCtx: parent, - errChan: make(chan error), - doneChan: make(chan struct{}), - closeOnce: sync.Once{}, - } - go func() { - for { - select { - case <-parent.Done(): - cwe.closeDoneChan() - - return - case err := <-cwe.errChan: - cwe.err = err - if err != nil { - cwe.closeDoneChan() - - return - } - } - } - }() - - return cwe, func(err error) { - cwe.err = err - cwe.closeDoneChan() - } -} - -func (cwe *CancelWithErrContext) closeDoneChan() { - cwe.closeOnce.Do(func() { - close(cwe.doneChan) - }) -} - -// Done implements Context.Done(). -func (cwe *CancelWithErrContext) Done() <-chan struct{} { - return cwe.doneChan -} - -// Err implements Context.Err(). -func (cwe *CancelWithErrContext) Err() error { - if cwe.err != nil { - return cwe.err - } - - return cwe.parentCtx.Err() -} - -// Deadline implements Context.Deadline(). -func (cwe *CancelWithErrContext) Deadline() (time time.Time, ok bool) { - return cwe.parentCtx.Deadline() -} - -// Value implements Context.Value(). -func (cwe *CancelWithErrContext) Value(key interface{}) interface{} { - return cwe.parentCtx.Value(key) -} diff --git a/pkg/utils/readstring_context.go b/pkg/utils/readstring_context.go index f839d496a..1b2f98f53 100644 --- a/pkg/utils/readstring_context.go +++ b/pkg/utils/readstring_context.go @@ -15,7 +15,7 @@ type readStringResult = struct { // important for callers to error out of further use of the bufio. Also, the goroutine will not // exit until the bufio's underlying connection is closed. func ReadStringContext(ctx context.Context, reader *bufio.Reader, delim byte) (string, error) { - result := make(chan *readStringResult) + result := make(chan *readStringResult, 1) go func() { str, err := reader.ReadString(delim) result <- &readStringResult{ diff --git a/pkg/workceptor/command.go b/pkg/workceptor/command.go index d42b571a0..7c3d2d924 100644 --- a/pkg/workceptor/command.go +++ b/pkg/workceptor/command.go @@ -95,6 +95,7 @@ func commandRunner(command string, params string, unitdir string) error { } doneChan := make(chan bool, 1) go cmdWaiter(cmd, doneChan) + writeStatusFailures := 0 loop: for { select { @@ -111,6 +112,13 @@ loop: err = status.UpdateBasicStatus(statusFilename, WorkStateRunning, fmt.Sprintf("Running: PID %d", cmd.Process.Pid), stdoutSize(unitdir)) if err != nil { logger.Error("Error updating status file %s: %s", statusFilename, err) + writeStatusFailures++ + if writeStatusFailures > 3 { + logger.Error("Exceeded retries for updating status file %s: %s", statusFilename, err) + os.Exit(-1) + } + } else { + writeStatusFailures = 0 } } } diff --git a/pkg/workceptor/controlsvc.go b/pkg/workceptor/controlsvc.go index 167794c36..f0b3a6d0e 100644 --- a/pkg/workceptor/controlsvc.go +++ b/pkg/workceptor/controlsvc.go @@ -4,6 +4,7 @@ package workceptor import ( + "context" "fmt" "os" "path" @@ -211,7 +212,7 @@ func (c *workceptorCommand) processSignature(workType, signature string, connIsU } // Worker function called by the control service to process a "work" command. -func (c *workceptorCommand) ControlFunc(nc *netceptor.Netceptor, cfo controlsvc.ControlFuncOperations) (map[string]interface{}, error) { +func (c *workceptorCommand) ControlFunc(ctx context.Context, nc *netceptor.Netceptor, cfo controlsvc.ControlFuncOperations) (map[string]interface{}, error) { addr := cfo.RemoteAddr() connIsUnix := false if addr.Network() == "unix" { @@ -411,11 +412,8 @@ func (c *workceptorCommand) ControlFunc(nc *netceptor.Netceptor, cfo controlsvc. if err != nil { return nil, err } - doneChan := make(chan struct{}) - defer func() { - close(doneChan) - }() - resultChan, err := c.w.GetResults(unitid, startPos, doneChan) + + resultChan, err := c.w.GetResults(ctx, unitid, startPos) if err != nil { return nil, err } diff --git a/pkg/workceptor/kubernetes.go b/pkg/workceptor/kubernetes.go index 4eb273538..ae4878336 100644 --- a/pkg/workceptor/kubernetes.go +++ b/pkg/workceptor/kubernetes.go @@ -209,6 +209,9 @@ func (kw *kubeUnit) createPod(env map[string]string) error { ctxPodReady, _ = context.WithTimeout(kw.ctx, kw.podPendingTimeout) } ev, err := watch2.UntilWithSync(ctxPodReady, lw, &corev1.Pod{}, nil, podRunningAndReady) + if ev == nil || ev.Object == nil { + return fmt.Errorf("did not return an event while watching pod for work unit %s", kw.ID()) + } var ok bool kw.pod, ok = ev.Object.(*corev1.Pod) if !ok { @@ -234,9 +237,6 @@ func (kw *kubeUnit) createPod(env map[string]string) error { return err } - if ev == nil { - return fmt.Errorf("pod disappeared during watch") - } return nil } diff --git a/pkg/workceptor/remote_work.go b/pkg/workceptor/remote_work.go index 2c1d8013b..26a8109cd 100644 --- a/pkg/workceptor/remote_work.go +++ b/pkg/workceptor/remote_work.go @@ -14,7 +14,6 @@ import ( "path" "regexp" "strings" - "sync" "time" "github.com/ansible/receptor/pkg/logger" @@ -62,12 +61,12 @@ func (rw *remoteUnit) connectToRemote(ctx context.Context) (net.Conn, *bufio.Rea ctxChild, _ := context.WithTimeout(ctx, 5*time.Second) hello, err := utils.ReadStringContext(ctxChild, reader, '\n') if err != nil { - conn.Close() + conn.CloseConnection() return nil, nil, err } if !strings.Contains(hello, red.RemoteNode) { - conn.Close() + conn.CloseConnection() return nil, nil, fmt.Errorf("while expecting node ID %s, got message: %s", red.RemoteNode, strings.TrimRight(hello, "\n")) @@ -144,16 +143,7 @@ func (rw *remoteUnit) getConnectionAndRun(ctx context.Context, firstTimeSync boo // startRemoteUnit makes a single attempt to start a remote unit. func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader *bufio.Reader) error { - closeOnce := sync.Once{} - doClose := func() error { - var err error - closeOnce.Do(func() { - err = conn.Close() - }) - - return err - } - defer doClose() + defer conn.(interface{ CloseConnection() error }).CloseConnection() red := rw.UnredactedStatus().ExtraData.(*remoteExtraData) workSubmitCmd := make(map[string]interface{}) for k, v := range red.RemoteParams { @@ -184,8 +174,6 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader } response, err := utils.ReadStringContext(ctx, reader, '\n') if err != nil { - conn.Close() - return fmt.Errorf("read error reading from %s: %s", red.RemoteNode, err) } submitIDRegex := regexp.MustCompile(`with ID ([a-zA-Z0-9]+)\.`) @@ -206,14 +194,12 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader if err != nil { return fmt.Errorf("error sending stdin file: %s", err) } - err = doClose() + err = conn.Close() if err != nil { return fmt.Errorf("error closing stdin file: %s", err) } response, err = utils.ReadStringContext(ctx, reader, '\n') if err != nil { - conn.Close() - return fmt.Errorf("read error reading from %s: %s", red.RemoteNode, err) } resultErrorRegex := regexp.MustCompile("ERROR: (.*)") @@ -232,7 +218,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader // cancelOrReleaseRemoteUnit makes a single attempt to cancel or release a remote unit. func (rw *remoteUnit) cancelOrReleaseRemoteUnit(ctx context.Context, conn net.Conn, reader *bufio.Reader, release bool, force bool) error { - defer conn.Close() + defer conn.(interface{ CloseConnection() error }).CloseConnection() red := rw.Status().ExtraData.(*remoteExtraData) var workCmd string if release { @@ -262,8 +248,6 @@ func (rw *remoteUnit) cancelOrReleaseRemoteUnit(ctx context.Context, conn net.Co } response, err := utils.ReadStringContext(ctx, reader, '\n') if err != nil { - conn.Close() - return fmt.Errorf("read error reading from %s: %s", red.RemoteNode, err) } if response[:5] == "ERROR" { @@ -289,9 +273,15 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) remoteNode := red.RemoteNode remoteUnitID := red.RemoteUnitID conn, reader := rw.getConnection(mw) + defer func() { + if conn != nil { + conn.(interface{ CloseConnection() error }).CloseConnection() + } + }() if conn == nil { return } + writeStatusFailures := 0 for { if conn == nil { conn, reader = rw.getConnection(mw) @@ -302,7 +292,7 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) _, err := conn.Write([]byte(fmt.Sprintf("work status %s\n", remoteUnitID))) if err != nil { logger.Debug("Write error sending to %s: %s\n", remoteUnitID, err) - _ = conn.Close() + _ = conn.(interface{ CloseConnection() error }).CloseConnection() conn = nil continue @@ -310,7 +300,7 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) status, err := utils.ReadStringContext(mw, reader, '\n') if err != nil { logger.Debug("Read error reading from %s: %s\n", remoteNode, err) - _ = conn.Close() + _ = conn.(interface{ CloseConnection() error }).CloseConnection() conn = nil continue @@ -339,6 +329,16 @@ func (rw *remoteUnit) monitorRemoteStatus(mw *utils.JobContext, forRelease bool) return } rw.UpdateBasicStatus(si.State, si.Detail, si.StdoutSize) + if rw.LastUpdateError() != nil { + writeStatusFailures++ + if writeStatusFailures > 3 { + logger.Error("Exceeded retries for updating status file for work unit %s", rw.unitID) + + return + } + } else { + writeStatusFailures = 0 + } if err != nil { logger.Error("Error saving local status file: %s\n", err) @@ -393,10 +393,18 @@ func (rw *remoteUnit) monitorRemoteStdout(mw *utils.JobContext) { status := rw.Status() diskStdoutSize := stdoutSize(rw.UnitDir()) remoteStdoutSize := status.StdoutSize + if status.State == WorkStateFailed { + return + } if IsComplete(status.State) && diskStdoutSize >= remoteStdoutSize { return } else if diskStdoutSize < remoteStdoutSize { conn, reader := rw.getConnection(mw) + defer func() { + if conn != nil { + _ = conn.(interface{ CloseConnection() error }).CloseConnection() + } + }() if conn == nil { return } @@ -454,7 +462,7 @@ func (rw *remoteUnit) monitorRemoteStdout(mw *utils.JobContext) { if ok { cr.CancelRead() } - _ = conn.Close() + _ = conn.(interface{ CloseConnection() error }).CloseConnection() return } @@ -649,7 +657,7 @@ func (rw *remoteUnit) cancelOrRelease(release bool, force bool) error { } rw.topJC.NewJob(rw.w.ctx, 1, false) - return rw.runAndMonitor(rw.topJC, true, func(ctx context.Context, conn net.Conn, reader *bufio.Reader) error { + return rw.runAndMonitor(rw.topJC, release, func(ctx context.Context, conn net.Conn, reader *bufio.Reader) error { return rw.cancelOrReleaseRemoteUnit(ctx, conn, reader, release, false) }) } diff --git a/pkg/workceptor/workceptor.go b/pkg/workceptor/workceptor.go index c9af66ebc..705566c57 100644 --- a/pkg/workceptor/workceptor.go +++ b/pkg/workceptor/workceptor.go @@ -445,29 +445,44 @@ func sleepOrDone(doneChan <-chan struct{}, interval time.Duration) bool { } // GetResults returns a live stream of the results of a unit. -func (w *Workceptor) GetResults(unitID string, startPos int64, doneChan chan struct{}) (chan []byte, error) { +func (w *Workceptor) GetResults(ctx context.Context, unitID string, startPos int64) (chan []byte, error) { unit, err := w.findUnit(unitID) if err != nil { return nil, err } resultChan := make(chan []byte) + closeOnce := sync.Once{} + resultClose := func() { + closeOnce.Do(func() { + close(resultChan) + }) + } + unitdir := path.Join(w.dataDir, unitID) + stdoutFilename := path.Join(unitdir, "stdout") + var stdout *os.File + ctxChild, cancel := context.WithCancel(ctx) go func() { - unitdir := path.Join(w.dataDir, unitID) - stdoutFilename := path.Join(unitdir, "stdout") + defer func() { + err = stdout.Close() + if err != nil { + logger.Error("Error closing stdout %s", stdoutFilename) + } + resultClose() + cancel() + }() + // Wait for stdout file to exist for { - _, err := os.Stat(stdoutFilename) - + stdout, err = os.Open(stdoutFilename) switch { case err == nil: case os.IsNotExist(err): if IsComplete(unit.Status().State) { - close(resultChan) logger.Warning("Unit completed without producing any stdout\n") return } - if sleepOrDone(doneChan, 250*time.Millisecond) { + if sleepOrDone(ctx.Done(), 500*time.Millisecond) { return } @@ -480,57 +495,76 @@ func (w *Workceptor) GetResults(unitID string, startPos int64, doneChan chan str break } - var stdout *os.File - var err error filePos := startPos + statChan := make(chan struct{}, 1) + go func() { + failures := 0 + for { + select { + case <-ctxChild.Done(): + return + case <-time.After(1 * time.Second): + _, err := os.Stat(stdoutFilename) + if os.IsNotExist(err) { + failures++ + if failures > 3 { + logger.Error("Exceeded retries for reading stdout %s", stdoutFilename) + statChan <- struct{}{} + + return + } + } else { + failures = 0 + } + } + } + }() for { - if sleepOrDone(doneChan, 250*time.Millisecond) { + if sleepOrDone(ctx.Done(), 250*time.Millisecond) { return } - if stdout == nil { - stdout, err = os.Open(stdoutFilename) - if err != nil { - continue - } - } - for err == nil { - var newPos int64 - newPos, err = stdout.Seek(filePos, 0) - if err != nil { - logger.Warning("Seek error processing stdout: %s\n", err) - + for { + select { + case <-ctx.Done(): return - } - if newPos != filePos { - logger.Warning("Seek error processing stdout\n") - + case <-statChan: return + default: + var newPos int64 + newPos, err = stdout.Seek(filePos, 0) + if err != nil { + logger.Warning("Seek error processing stdout: %s\n", err) + + return + } + if newPos != filePos { + logger.Warning("Seek error processing stdout\n") + + return + } + var n int + buf := make([]byte, utils.NormalBufferSize) + n, err = stdout.Read(buf) + if n > 0 { + filePos += int64(n) + select { + case <-ctx.Done(): + return + case resultChan <- buf[:n]: + } + } } - var n int - buf := make([]byte, utils.NormalBufferSize) - n, err = stdout.Read(buf) - if n > 0 { - filePos += int64(n) - resultChan <- buf[:n] + if err != nil { + break } } if err == io.EOF { - err = stdout.Close() - if err != nil { - logger.Error("Error closing stdout\n") - - return - } - stdout = nil stdoutSize := stdoutSize(unitdir) if IsComplete(unit.Status().State) && stdoutSize >= unit.Status().StdoutSize { - close(resultChan) logger.Info("Stdout complete - closing channel for: %s \n", unitID) return } - - continue } else if err != nil { logger.Error("Error reading stdout: %s\n", err)