From 5b3f2cdfa546f89c26b45d74b91be2f88d8905a0 Mon Sep 17 00:00:00 2001 From: Dean Lee Date: Sat, 18 Jan 2025 11:26:13 +0800 Subject: [PATCH] msgq: refactor blocking recv for improved robustness and performance (#616) * improve blocking receive * Update msgq/impl_msgq.cc * 1s timeout * comment * typo --------- Co-authored-by: Shane Smiskol --- SConscript | 5 +- msgq/impl_msgq.cc | 104 +++++++++++++++-------------------- msgq/ipc.pxd | 2 +- msgq/ipc_pyx.pyx | 9 +-- msgq/tests/test_messaging.py | 26 +++++++++ 5 files changed, 78 insertions(+), 68 deletions(-) diff --git a/SConscript b/SConscript index 147eb3042..cc2dcb842 100644 --- a/SConscript +++ b/SConscript @@ -15,7 +15,7 @@ msgq_objects = env.SharedObject([ 'msgq/msgq.cc', ]) msgq = env.Library('msgq', msgq_objects) -msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", common]) +msgq_python = envCython.Program('msgq/ipc_pyx.so', 'msgq/ipc_pyx.pyx', LIBS=envCython["LIBS"]+[msgq, "zmq", 'pthread', common]) # Build Vision IPC vipc_files = ['visionipc.cc', 'visionipc_server.cc', 'visionipc_client.cc', 'visionbuf.cc'] @@ -31,7 +31,7 @@ visionipc = env.Library('visionipc', vipc_objects) vipc_frameworks = [] -vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq"] +vipc_libs = envCython["LIBS"] + [visionipc, msgq, common, "zmq", 'pthread'] if arch == "Darwin": vipc_frameworks.append('OpenCL') else: @@ -45,4 +45,5 @@ if GetOption('extras'): [f'{visionipc_dir.abspath}/test_runner.cc', f'{visionipc_dir.abspath}/visionipc_tests.cc'], LIBS=['pthread'] + vipc_libs, FRAMEWORKS=vipc_frameworks) +msgq = [msgq, 'pthread'] Export('visionipc', 'msgq', 'msgq_python') diff --git a/msgq/impl_msgq.cc b/msgq/impl_msgq.cc index b23991351..06eb903a9 100644 --- a/msgq/impl_msgq.cc +++ b/msgq/impl_msgq.cc @@ -1,20 +1,12 @@ #include #include #include -#include +#include #include -#include #include "msgq/impl_msgq.h" - -volatile sig_atomic_t msgq_do_exit = 0; - -void 
sig_handler(int signal) { - assert(signal == SIGINT || signal == SIGTERM); - msgq_do_exit = 1; -} - +using namespace std::chrono; MSGQContext::MSGQContext() { } @@ -70,61 +62,55 @@ int MSGQSubSocket::connect(Context *context, std::string endpoint, std::string a return 0; } - -Message * MSGQSubSocket::receive(bool non_blocking){ - msgq_do_exit = 0; - - void (*prev_handler_sigint)(int); - void (*prev_handler_sigterm)(int); - if (!non_blocking){ - prev_handler_sigint = std::signal(SIGINT, sig_handler); - prev_handler_sigterm = std::signal(SIGTERM, sig_handler); - } - - msgq_msg_t msg; - - MSGQMessage *r = NULL; - +Message *MSGQSubSocket::receive(bool non_blocking) { + msgq_msg_t msg{}; int rc = msgq_msg_recv(&msg, q); - // Hack to implement blocking read with a poller. Don't use this - while (!non_blocking && rc == 0 && msgq_do_exit == 0){ - msgq_pollitem_t items[1]; - items[0].q = q; - - int t = (timeout != -1) ? timeout : 100; - - int n = msgq_poll(items, 1, t); - rc = msgq_msg_recv(&msg, q); - - // The poll indicated a message was ready, but the receive failed. Try again - if (n == 1 && rc == 0){ - continue; - } - - if (timeout != -1){ - break; + if (rc == 0 && !non_blocking) { + sigset_t mask; + sigset_t old_mask; + sigemptyset(&mask); + sigaddset(&mask, SIGINT); + sigaddset(&mask, SIGTERM); + sigaddset(&mask, SIGUSR2); // notification from publisher + + pthread_sigmask(SIG_BLOCK, &mask, &old_mask); + + int64_t timeout_ns = ((timeout != -1) ? 
timeout : 100) * 1000000; + auto start = steady_clock::now(); + + // Continue receiving messages until timeout or interruption by SIGINT or SIGTERM + while (rc == 0 && timeout_ns > 0) { + struct timespec ts { + timeout_ns / 1000000000, + timeout_ns % 1000000000, + }; + + int ret = sigtimedwait(&mask, nullptr, &ts); + if (ret == SIGINT || ret == SIGTERM) { + // Ensure signal handling is not missed + raise(ret); + break; + } else if (ret == -1 && errno == EAGAIN && timeout != -1) { + break; // Timed out + } + + rc = msgq_msg_recv(&msg, q); + + if (timeout != -1) { + timeout_ns -= duration_cast<nanoseconds>(steady_clock::now() - start).count(); + start = steady_clock::now(); // Update start time + } } + pthread_sigmask(SIG_SETMASK, &old_mask, nullptr); } - - if (!non_blocking){ - std::signal(SIGINT, prev_handler_sigint); - std::signal(SIGTERM, prev_handler_sigterm); - } - - errno = msgq_do_exit ? EINTR : 0; - - if (rc > 0){ - if (msgq_do_exit){ - msgq_msg_close(&msg); // Free unused message on exit - } else { - r = new MSGQMessage; - r->takeOwnership(msg.data, msg.size); - } + if (rc > 0) { + MSGQMessage *r = new MSGQMessage; + r->takeOwnership(msg.data, msg.size); + return r; } - - return (Message*)r; + return nullptr; } void MSGQSubSocket::setTimeout(int t){ diff --git a/msgq/ipc.pxd b/msgq/ipc.pxd index 2c7ac963e..ca33ea0f8 100644 --- a/msgq/ipc.pxd +++ b/msgq/ipc.pxd @@ -50,7 +50,7 @@ cdef extern from "msgq/ipc.h": @staticmethod SubSocket * create() int connect(Context *, string, string, bool) - Message * receive(bool) + Message * receive(bool) nogil void setTimeout(int) cdef cppclass PubSocket: diff --git a/msgq/ipc_pyx.pyx b/msgq/ipc_pyx.pyx index d8797f395..a9a9a422f 100644 --- a/msgq/ipc_pyx.pyx +++ b/msgq/ipc_pyx.pyx @@ -196,14 +196,11 @@ cdef class SubSocket: self.socket.setTimeout(timeout) def receive(self, bool non_blocking=False): - msg = self.socket.receive(non_blocking) + cdef cppMessage *msg + with nogil: + msg = self.socket.receive(non_blocking) if msg == NULL: - 
# If a blocking read returns no message check errno if SIGINT was caught in the C++ code - if errno.errno == errno.EINTR: - print("SIGINT received, exiting") - sys.exit(1) - return None else: sz = msg.getSize() diff --git a/msgq/tests/test_messaging.py b/msgq/tests/test_messaging.py index 40dfd7f00..c9b1c646a 100644 --- a/msgq/tests/test_messaging.py +++ b/msgq/tests/test_messaging.py @@ -1,5 +1,8 @@ import os +import pytest import random +import signal +import threading import time import string import msgq @@ -67,3 +70,26 @@ def test_receive_timeout(self): recvd = sub_sock.receive() assert (time.monotonic() - start_time) < 0.2 assert recvd is None + + def test_receive_interrupts_on_sigint(self): + sock = random_sock() + sub_sock = msgq.sub_sock(sock) + sub_sock.setTimeout(1000) + + # Send SIGINT after a short delay + pid = os.getpid() + def send_sigint(): + time.sleep(.5) + os.kill(pid, signal.SIGINT) + + # Start a thread to send SIGINT + thread = threading.Thread(target=send_sigint) + thread.start() + + with pytest.raises(KeyboardInterrupt): + start_time = time.monotonic() + recvd = sub_sock.receive() + assert (time.monotonic() - start_time) < 0.5 + assert recvd is None + + thread.join()