diff --git a/runusb/__main__.py b/runusb/__main__.py index dc5e3b2..5cadfbd 100755 --- a/runusb/__main__.py +++ b/runusb/__main__.py @@ -10,8 +10,9 @@ import signal import subprocess import sys -import time +import uuid from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field from enum import Enum, IntEnum, unique from threading import Thread from typing import IO, Iterator, NamedTuple, Type @@ -26,6 +27,7 @@ try: from logger_extras import MQTTHandler # type: ignore[attr-defined] + from paho.mqtt.client import Client as MQTTClient # type: ignore[import-untyped,unused-ignore] except ImportError: MQTTHandler = None @@ -45,14 +47,23 @@ # the directory under which all USBs will be mounted MOUNTPOINT_DIR = os.environ.get('RUNUSB_MOUNTPOINT_DIR', '/media') -# This will be populated if we have the config file -# url format: mqtt[s]://[[:]@][:]/ -MQTT_URL = None -MQTT_TOPIC_ROOT = '' -MQTT_CLIENT = None MQTT_CONFIG_FILE = '/etc/sbot/mqtt.conf' +@dataclass +class MqttSettings: + # url format: mqtt[s]://[[:]@][:]/ + url: str | None = None + active_config: MQTTVariables | None = None + client: MQTTClient | None = None + active_usercode: RobotUSBHandler | None = None + extra_data: dict[str, str] = field(default_factory=lambda: {"run_uuid": ""}) + + +# This will be populated if we have the config file +MQTT_SETTINGS = MqttSettings() + + class MQTTVariables(NamedTuple): host: str port: int | None @@ -159,9 +170,11 @@ def set_status(self, value: LedStatus) -> None: GPIO.output(self.LEDs.STATUS_BLUE, GPIO.HIGH if value.value[2] else GPIO.LOW) # Also send the status over MQTT - if MQTT_CLIENT is not None: - MQTT_CLIENT.publish( - f'{MQTT_TOPIC_ROOT}/state', + mqtt_client = MQTT_SETTINGS.client + if mqtt_client is not None and MQTT_SETTINGS.active_config is not None: + topic_prefix = MQTT_SETTINGS.active_config.topic_prefix + mqtt_client.publish( + f'{topic_prefix}/state', json.dumps({"state": value.name}), qos=1, retain=True, @@ -171,6 +184,56 @@ def set_status(self, value: LedStatus) -> None: LED_CONTROLLER = LEDController() +def mqtt_on_stop_action(client, userdata, message): + LOGGER.info("Received stop action") + try: + payload = json.loads(message.payload) + except json.JSONDecodeError: + LOGGER.warning("Failed to decode stop action message.") + return + + if payload.get('pressed') is not True: + LOGGER.info("Stop action had incorrect payload, ignoring.") + return + + if MQTT_SETTINGS.active_usercode is not None: + # Run the cleanup function to stop the usercode but allow it to be + # restarted without reinserting the USB + MQTT_SETTINGS.active_usercode.killed = True + MQTT_SETTINGS.active_usercode.cleanup() + + +def mqtt_on_reset_action(client, userdata, message): + LOGGER.info("Received reset action") + try: + payload = json.loads(message.payload) + except json.JSONDecodeError: + LOGGER.warning("Failed to decode reset action message.") + return + + if payload.get('pressed') is not True: + LOGGER.info("Reset action had incorrect payload, ignoring.") + return + + if MQTT_SETTINGS.active_usercode is not None: + # The reset function will stop the usercode and wait for it to finish, + # if it was running, before restarting it + MQTT_SETTINGS.active_usercode.reset() + + +def mqtt_connected_actions(): + """Actions to perform when the MQTT client connects.""" + LED_CONTROLLER.set_wifi(True) + if MQTT_SETTINGS.client is not None: + mqtt_client = MQTT_SETTINGS.client + assert MQTT_SETTINGS.active_config is not None + topic_prefix = MQTT_SETTINGS.active_config.topic_prefix + mqtt_client.message_callback_add(f"{topic_prefix}/stop", mqtt_on_stop_action) + mqtt_client.message_callback_add(f"{topic_prefix}/reset", mqtt_on_reset_action) + mqtt_client.subscribe(f"{topic_prefix}/stop", qos=1) + mqtt_client.subscribe(f"{topic_prefix}/reset", qos=1) + + @unique class USBType(Enum): ROBOT = 'ROBOT' @@ -246,27 +309,42 @@ def close(self) -> None: class RobotUSBHandler(USBHandler): def __init__(self, mountpoint_path: str) -> None: + self.mountpoint_path = mountpoint_path + + if MQTT_SETTINGS.active_usercode is not None: + raise RuntimeError("There is already a usercode running") + else: + MQTT_SETTINGS.active_usercode = self + self._setup_logging(mountpoint_path) LED_CONTROLLER.set_code(True) - LED_CONTROLLER.set_status(LedStatus.Running) - env = dict(os.environ) - env["SBOT_METADATA_PATH"] = MOUNTPOINT_DIR - if MQTT_URL is not None: + self.env = dict(os.environ) + self.env["SBOT_METADATA_PATH"] = MOUNTPOINT_DIR + if MQTT_SETTINGS.url is not None: # pass the mqtt url to the robot for camera images - env["SBOT_MQTT_URL"] = MQTT_URL + self.env["SBOT_MQTT_URL"] = MQTT_SETTINGS.url + + self.start() + + def start(self) -> None: + run_uuid = uuid.uuid4().hex + MQTT_SETTINGS.extra_data["run_uuid"] = run_uuid + self.env["run_uuid"] = run_uuid + self.killed = False + REL_TIME_FILTER.reset_time_reference() # type: ignore[union-attr] + LED_CONTROLLER.set_status(LedStatus.Running) self.process = subprocess.Popen( [sys.executable, '-u', ROBOT_FILE], stdin=subprocess.DEVNULL, stderr=subprocess.STDOUT, stdout=subprocess.PIPE, bufsize=1, # line buffered - cwd=mountpoint_path, - env=env, + cwd=self.mountpoint_path, + env=self.env, text=True, start_new_session=True, # Put the process in a new process group ) - self.process_start_time = time.time() self.thread = Thread(target=self._watch_process) self.thread.start() @@ -288,6 +366,15 @@ def close(self) -> None: LED_CONTROLLER.set_status(LedStatus.NoUSB) LED_CONTROLLER.set_code(False) USERCODE_LOGGER.removeHandler(self.handler) + MQTT_SETTINGS.extra_data["run_uuid"] = "" # Reset the run UUID + MQTT_SETTINGS.active_usercode = None + + def reset(self) -> None: + self.cleanup() + # Wait for the process to finish + self.process.wait() + self.log_thread.join() + self.start() def _send_signal(self, sig: int) -> None: if self.process.poll() is not None: @@ -298,20 +385,16 @@ def _send_signal(self, sig: int) -> None: def _watch_process(self) -> None: # Wait for the process to complete self.process.wait() - if self.process.returncode != 0: + if self.killed: + USERCODE_LOGGER.warning("Your code was stopped.") + LED_CONTROLLER.set_status(LedStatus.Killed) + elif self.process.returncode != 0: USERCODE_LOGGER.warning(f"Process exited with code {self.process.returncode}") LED_CONTROLLER.set_status(LedStatus.Crashed) else: USERCODE_LOGGER.info("Your code finished successfully.") LED_CONTROLLER.set_status(LedStatus.Finished) - process_lifetime = time.time() - self.process_start_time - - # If the process was alive for less than a second, delay the clean-up. - # This ensures the LEDs stay on for a noticeable amount of time. - if process_lifetime < 1: - time.sleep(1 - process_lifetime) - # Start clean-up self.cleanup() @@ -398,9 +481,13 @@ def _detect_new_mountpoint_path(self, path: str) -> None: usb_type = detect_usb_type(path) LOGGER.info(f"Found new mountpoint: {path} ({usb_type})") handler_class = self.TYPE_HANDLERS[usb_type] - handler = handler_class(path) - LOGGER.info(" -> launched handler") - self.mountpoint_handlers[path] = handler + try: + handler = handler_class(path) + except RuntimeError as e: + LOGGER.error(f"Failed to launch handler: {e}") + else: + LOGGER.info(" -> launched handler") + self.mountpoint_handlers[path] = handler def _detect_dead_mountpoint_path(self, path: str) -> None: LOGGER.info(f"Lost mountpoint: {path}") @@ -423,7 +510,6 @@ def _is_viable_mountpoint(self, mountpoint: Mountpoint) -> bool: def set_mqtt_url(config: MQTTVariables) -> None: - global MQTT_URL if config.username is not None and config.password is not None: auth = f"{config.username}:{config.password}@" elif config.username is not None: @@ -434,7 +520,7 @@ def set_mqtt_url(config: MQTTVariables) -> None: port_str = (f":{config.port}" if config.port is not None else "") scheme = 'mqtts' if config.use_tls else 'mqtt' - MQTT_URL = ( + MQTT_SETTINGS.url = ( f"{scheme}://{auth}{config.host}{port_str}/{config.topic_prefix}" ) @@ -467,7 +553,7 @@ def read_mqtt_config_file() -> MQTTVariables | None: def setup_usercode_logging() -> None: - global REL_TIME_FILTER, MQTT_CLIENT, MQTT_TOPIC_ROOT + global REL_TIME_FILTER REL_TIME_FILTER = RelativeTimeFilter() USERCODE_LOGGER.addFilter(REL_TIME_FILTER) USERCODE_LOGGER.setLevel(logging.DEBUG) @@ -475,6 +561,7 @@ def setup_usercode_logging() -> None: if MQTTHandler is not None: # If we have relative logging, we should also have the MQTT handler mqtt_config = read_mqtt_config_file() + MQTT_SETTINGS.active_config = mqtt_config if mqtt_config is not None: handler = MQTTHandler( @@ -485,11 +572,11 @@ def setup_usercode_logging() -> None: username=mqtt_config.username, password=mqtt_config.password, connected_topic=f"{mqtt_config.topic_prefix}/connected", - connected_callback=lambda: LED_CONTROLLER.set_wifi(True), + connected_callback=mqtt_connected_actions, disconnected_callback=lambda: LED_CONTROLLER.set_wifi(False), + extra_data=MQTT_SETTINGS.extra_data, ) - MQTT_CLIENT = handler.mqtt - MQTT_TOPIC_ROOT = mqtt_config.topic_prefix + MQTT_SETTINGS.client = handler.mqtt handler.setLevel(logging.INFO) handler.setFormatter(TieredFormatter( diff --git a/setup.cfg b/setup.cfg index 37f9686..66d1ddf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,12 +40,13 @@ ignore = W503 # try to keep it below 80, but this allows us to push it a bit when needed. -max_line_length = 90 +max_line_length = 100 [isort] atomic = true balanced_wrapping = true +line_length = 100 default_section = THIRDPARTY sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER