diff --git a/neurobooth_os/gui.py b/neurobooth_os/gui.py index c937877c..e0e7af67 100644 --- a/neurobooth_os/gui.py +++ b/neurobooth_os/gui.py @@ -125,10 +125,9 @@ def _create_session_dict(window, log_task, staff_id, subject: Subject, tasks): ########## Task-related functions ############ -def _start_task_presentation(window, tasks: List[str], subject_id: str, session_id: int, steps): +def _start_task_presentation(window, tasks: List[str], subject_id: str, session_id: int, steps, conn): """Present tasks""" global last_task - conn = meta.get_database_connection() window['Start'].update(disabled=True) write_output(window, "\nSession started") last_task = tasks[-1] @@ -139,7 +138,7 @@ def _start_task_presentation(window, tasks: List[str], subject_id: str, session_ destination='STM', body=msg_body ) - meta.post_message(msg, conn), conn + meta.post_message(msg, conn) steps.append("task_started") else: sg.PopupError("No task selected", location=get_popup_location(window)) @@ -320,76 +319,76 @@ def _start_ctr_server(window, logger): def _start_ctr_msg_reader(logger, window): - db_conn = meta.get_database_connection() - while True: - message: Message = meta.read_next_message("CTR", conn=db_conn) - if message is None: - time.sleep(.25) - continue - msg_body: Optional[MsgBody] = None - logger.info(f'MESSAGE RECEIVED: {message.model_dump_json()}') - logger.info(f'MESSAGE RECEIVED: {message.body.model_dump_json()}') - - if "DeviceInitialization" == message.msg_type: - msg_body: DeviceInitialization = message.body - outlet_name = msg_body.stream_name - outlet_id = msg_body.outlet_id - outlet_values = f"['{outlet_name}', '{outlet_id}']" - window.write_event_value("-OUTLETID-", outlet_values) - elif "SessionPrepared" == message.msg_type: - window.write_event_value("devices_connected", True) - elif "ServerStarted" == message.msg_type: - window.write_event_value("server_started", message.source) - elif "TasksCreated" == message.msg_type: - window.write_event_value("tasks_created", "") - elif "TaskInitialization" == message.msg_type: - msg_body: TaskInitialization = message.body - task_id = msg_body.task_id - log_task_id = msg_body.log_task_id - tsk_strt_time = msg_body.tsk_start_time - window.write_event_value( - "task_initiated", - f"['{task_id}', '{task_id}', '{log_task_id}', '{tsk_strt_time}']", - ) - - elif "TaskCompletion" == message.msg_type: - msg_body: TaskCompletion = message.body - task_id = msg_body.task_id - has_lsl_stream = msg_body.has_lsl_stream - event_value = f"['{task_id}', '{has_lsl_stream}']" - logger.debug(f"TaskCompletion msg for {task_id}") - window.write_event_value("task_finished", event_value) - - elif "NewVideoFile" == message.msg_type: - msg_body: NewVideoFile = message.body - event = msg_body.event - stream_name = msg_body.stream_name - filename = msg_body.filename - window.write_event_value(event, f"{stream_name},{filename}") - - elif "NoEyetracker" == message.msg_type: - window.write_event_value( - "no_eyetracker", - "Eyetracker not found! \nServers will be " - + "terminated, wait until servers are closed.\nThen, connect the eyetracker and start again", - ) - - elif "MbientDisconnected" == message.msg_type: - msg_body: MbientDisconnected = message.body - window.write_event_value( - "mbient_disconnected", f"{msg_body.warning}, \nconsider repeating the task" - ) - elif "StatusMessage" == message.msg_type: - write_message_to_output(logger, message, window) - - elif "ErrorMessage" == message.msg_type: - write_message_to_output(logger, message, window) - - elif "FramePreviewReply" == message.msg_type: - frame_reply: FramePreviewReply = message.body - handle_frame_preview_reply(window, frame_reply) - else: - logger.debug(f"Unhandled message: {message.msg_type}") + with meta.get_database_connection() as db_conn: + while True: + message: Message = meta.read_next_message("CTR", conn=db_conn) + if message is None: + time.sleep(.25) + continue + msg_body: Optional[MsgBody] = None + logger.info(f'MESSAGE RECEIVED: {message.model_dump_json()}') + logger.info(f'MESSAGE RECEIVED: {message.body.model_dump_json()}') + + if "DeviceInitialization" == message.msg_type: + msg_body: DeviceInitialization = message.body + outlet_name = msg_body.stream_name + outlet_id = msg_body.outlet_id + outlet_values = f"['{outlet_name}', '{outlet_id}']" + window.write_event_value("-OUTLETID-", outlet_values) + elif "SessionPrepared" == message.msg_type: + window.write_event_value("devices_connected", True) + elif "ServerStarted" == message.msg_type: + window.write_event_value("server_started", message.source) + elif "TasksCreated" == message.msg_type: + window.write_event_value("tasks_created", "") + elif "TaskInitialization" == message.msg_type: + msg_body: TaskInitialization = message.body + task_id = msg_body.task_id + log_task_id = msg_body.log_task_id + tsk_strt_time = msg_body.tsk_start_time + window.write_event_value( + "task_initiated", + f"['{task_id}', '{task_id}', '{log_task_id}', '{tsk_strt_time}']", + ) + + elif "TaskCompletion" == message.msg_type: + msg_body: TaskCompletion = message.body + task_id = msg_body.task_id + has_lsl_stream = msg_body.has_lsl_stream + event_value = f"['{task_id}', '{has_lsl_stream}']" + logger.debug(f"TaskCompletion msg for {task_id}") + window.write_event_value("task_finished", event_value) + + elif "NewVideoFile" == message.msg_type: + msg_body: NewVideoFile = message.body + event = msg_body.event + stream_name = msg_body.stream_name + filename = msg_body.filename + window.write_event_value(event, f"{stream_name},{filename}") + + elif "NoEyetracker" == message.msg_type: + window.write_event_value( + "no_eyetracker", + "Eyetracker not found! \nServers will be " + + "terminated, wait until servers are closed.\nThen, connect the eyetracker and start again", + ) + + elif "MbientDisconnected" == message.msg_type: + msg_body: MbientDisconnected = message.body + window.write_event_value( + "mbient_disconnected", f"{msg_body.warning}, \nconsider repeating the task" + ) + elif "StatusMessage" == message.msg_type: + write_message_to_output(logger, message, window) + + elif "ErrorMessage" == message.msg_type: + write_message_to_output(logger, message, window) + + elif "FramePreviewReply" == message.msg_type: + frame_reply: FramePreviewReply = message.body + handle_frame_preview_reply(window, frame_reply) + else: + logger.debug(f"Unhandled message: {message.msg_type}") def write_message_to_output(logger, message: Request, window): @@ -451,7 +450,7 @@ def _request_frame_preview(conn): meta.post_message(req, conn) -def _prepare_devices(window, nodes: List[str], collection_id: str, log_task: Dict, database, tasks: str): +def _prepare_devices(window, nodes: List[str], collection_id: str, log_task: Dict, database, tasks: str, conn): """Prepare devices. Mainly ensuring devices are connected""" # disable button so it can't be pushed twice @@ -480,7 +479,7 @@ def _prepare_devices(window, nodes: List[str], collection_id: str, log_task: Dic destination=dest, body=body) - meta.post_message(msg, conn=meta.get_database_connection()) + meta.post_message(msg, conn) return video_marker_stream, event, values @@ -501,228 +500,228 @@ def gui(logger): subject: Subject tasks = None - conn = meta.get_database_connection() - meta.clear_msg_queue(conn) - - window = _win_gen(_init_layout, conn) - - plttr = stream_plotter() - log_task = meta._new_tech_log_dict() - log_sess = meta._new_session_log_dict() - stream_ids, inlets = {}, {} - plot_elem, inlet_keys = [], [] - steps = list() # keep track of steps done - event, values = window.read(0.1) - sess_info = None - while True: - event, values = window.read(0.5) - ############################################################ - # Initial Window -> Select subject, study and tasks - ############################################################ - if event == "study_id": - study_id = values[event] - log_sess["study_id"] = study_id - collection_ids = _get_collections(window, study_id) - - elif event == "find_subject": - subject: Subject = _get_subject_by_id(window, log_sess, conn, values["subject_id"]) - - elif event == "collection_id": - collection_id: str = values[event] - log_sess["collection_id"] = collection_id - tasks = _get_tasks(window, collection_id) - - elif event == "_init_sess_save_": - if values["study_id"] == "" or values['collection_id'] == "": - sg.PopupError("Study and Collection are required fields", location=get_popup_location(window)) - elif values["staff_id"] == "": - sg.PopupError("Staff ID is required", location=get_popup_location(window)) - elif window["subject_info"].get() == "": - sg.PopupError("Please select a Subject", location=get_popup_location(window)) - else: - log_sess["staff_id"] = values["staff_id"] - sess_info = _create_session_dict( - window, - log_task, - values["staff_id"], - subject, - tasks, - ) - # Open new layout with main window - window = _win_gen(_main_layout, sess_info) - _start_ctr_server(window, logger) - logger.debug(f"ctr msg reader started") - - ############################################################ - # Main Window -> Run neurobooth session - ############################################################ - - # Start servers on STM, ACQ - elif event == "-init_servs-": - _start_servers(window, nodes) - - # Turn on devices - elif event == "-Connect-": - video_marker_stream, event, values = _prepare_devices(window, - nodes, collection_id, log_task, database, tasks) - - elif event == "plot": - _plot_realtime(window, plttr, inlets) - - elif event == "Start": - if not start_pressed: - window["Start"].Update(disabled=True) - start_pressed = True - session_id = meta._make_session_id(conn, log_sess) - tasks = [k for k, v in values.items() if "obs" in k and v is True] - _start_task_presentation(window, tasks, sess_info["subject_id"], session_id, steps) - - elif event == "tasks_created": - _session_button_state(window, disabled=False) - for task_id in tasks: - msg_body = PerformTaskRequest(task_id=task_id) + with meta.get_database_connection() as conn: + meta.clear_msg_queue(conn) + + window = _win_gen(_init_layout, conn) + + plttr = stream_plotter() + log_task = meta._new_tech_log_dict() + log_sess = meta._new_session_log_dict() + stream_ids, inlets = {}, {} + plot_elem, inlet_keys = [], [] + steps = list() # keep track of steps done + event, values = window.read(0.1) + sess_info = None + while True: + event, values = window.read(0.5) + ############################################################ + # Initial Window -> Select subject, study and tasks + ############################################################ + if event == "study_id": + study_id = values[event] + log_sess["study_id"] = study_id + collection_ids = _get_collections(window, study_id) + + elif event == "find_subject": + subject: Subject = _get_subject_by_id(window, log_sess, conn, values["subject_id"]) + + elif event == "collection_id": + collection_id: str = values[event] + log_sess["collection_id"] = collection_id + tasks = _get_tasks(window, collection_id) + + elif event == "_init_sess_save_": + if values["study_id"] == "" or values['collection_id'] == "": + sg.PopupError("Study and Collection are required fields", location=get_popup_location(window)) + elif values["staff_id"] == "": + sg.PopupError("Staff ID is required", location=get_popup_location(window)) + elif window["subject_info"].get() == "": + sg.PopupError("Please select a Subject", location=get_popup_location(window)) + else: + log_sess["staff_id"] = values["staff_id"] + sess_info = _create_session_dict( + window, + log_task, + values["staff_id"], + subject, + tasks, + ) + # Open new layout with main window + window = _win_gen(_main_layout, sess_info) + _start_ctr_server(window, logger) + logger.debug(f"ctr msg reader started") + + ############################################################ + # Main Window -> Run neurobooth session + ############################################################ + + # Start servers on STM, ACQ + elif event == "-init_servs-": + _start_servers(window, nodes) + + # Turn on devices + elif event == "-Connect-": + video_marker_stream, event, values = _prepare_devices(window, + nodes, collection_id, log_task, database, tasks, conn) + + elif event == "plot": + _plot_realtime(window, plttr, inlets) + + elif event == "Start": + if not start_pressed: + window["Start"].Update(disabled=True) + start_pressed = True + session_id = meta._make_session_id(conn, log_sess) + tasks = [k for k, v in values.items() if "obs" in k and v is True] + _start_task_presentation(window, tasks, sess_info["subject_id"], session_id, steps, conn) + + elif event == "tasks_created": + _session_button_state(window, disabled=False) + for task_id in tasks: + msg_body = PerformTaskRequest(task_id=task_id) + msg = Request(source="CTR", destination="STM", body=msg_body) + meta.post_message(msg, conn) + # PerformTask Messages queued for all tasks, now queue a TasksFinished message + msg_body = TasksFinished() msg = Request(source="CTR", destination="STM", body=msg_body) meta.post_message(msg, conn) - # PerformTask Messages queued for all tasks, now queue a TasksFinished message - msg_body = TasksFinished() - msg = Request(source="CTR", destination="STM", body=msg_body) - meta.post_message(msg, conn) - - elif event == "Pause tasks": - _pause_tasks(window, steps, conn=conn) - - elif event == "Stop tasks": - _stop_task_dialog(window, conn=conn, resume_on_cancel=False) - - elif event == "Calibrate": - _calibrate(window, conn=conn) - - # Save notes to a txt - elif event == "_save_notes_": - if values["_notes_taskname_"] != "": - _save_session_notes(sess_info, values, window) - else: - sg.PopupError( - "Pressed save notes without task, select one in the dropdown list", - location=get_popup_location(window) - ) - continue - elif event == sg.WIN_CLOSED: - break - - # Shut down the other servers and stops plotting - elif event == "Shut Down" or event == sg.WINDOW_CLOSE_ATTEMPTED_EVENT: - if (values is not None - and ('notes' in values and values['notes'] != '') - and ("_notes_taskname_" not in values or values['_notes_taskname_'] == '')): - sg.PopupError( - "Unsaved notes without task. Before exiting, " - "select a task in the dropdown list or delete the note text.", - location=get_popup_location(window) + elif event == "Pause tasks": + _pause_tasks(window, steps, conn=conn) + + elif event == "Stop tasks": + _stop_task_dialog(window, conn=conn, resume_on_cancel=False) + + elif event == "Calibrate": + _calibrate(window, conn=conn) + + # Save notes to a txt + elif event == "_save_notes_": + if values["_notes_taskname_"] != "": + _save_session_notes(sess_info, values, window) + else: + sg.PopupError( + "Pressed save notes without task, select one in the dropdown list", + location=get_popup_location(window) + ) + continue + + elif event == sg.WIN_CLOSED: + break + + # Shut down the other servers and stops plotting + elif event == "Shut Down" or event == sg.WINDOW_CLOSE_ATTEMPTED_EVENT: + if (values is not None + and ('notes' in values and values['notes'] != '') + and ("_notes_taskname_" not in values or values['_notes_taskname_'] == '')): + sg.PopupError( + "Unsaved notes without task. Before exiting, " + "select a task in the dropdown list or delete the note text.", + location=get_popup_location(window) + ) + continue + else: + response = sg.popup_ok_cancel("System will terminate! \n\n" + "Please ensure that any task in progress is completed and that STM and " + "ACQ shut down properly.\n", title="Warning", + location=get_popup_location(window)) + if response == "OK": + write_output(window, "System termination scheduled. " + "Servers will shut down after the current task.") + + terminate_system(conn, plttr, sess_info, values, window) + break + + ################################################################################## + # Thread events from process_received_data -> received messages from other servers + ################################################################################## + + # Signal a task started: record LSL data and update gui + elif event == "task_initiated": + # event values -> f"['{task_id}', '{t_obs_id}', '{log_task_id}, '{tsk_strt_time}'] + window["-frame_preview-"].update(visible=False) + task_id, t_obs_id, obs_log_id, tsk_strt_time = eval(values[event]) + write_output(window, f"\nTask initiated: {task_id}") + + logger.debug(f"Starting LSL for task: {t_obs_id}") + rec_fname = _record_lsl( + window, + session, + sess_info["subject_id_date"], + task_id, + t_obs_id, + obs_log_id, + tsk_strt_time, + conn ) - continue - else: - response = sg.popup_ok_cancel("System will terminate! \n\n" - "Please ensure that any task in progress is completed and that STM and " - "ACQ shut down properly.\n", title="Warning", - location=get_popup_location(window)) - if response == "OK": - write_output(window, "System termination scheduled. " - "Servers will shut down after the current task.") - - terminate_system(conn, plttr, sess_info, values, window) - break - - ################################################################################## - # Thread events from process_received_data -> received messages from other servers - ################################################################################## - - # Signal a task started: record LSL data and update gui - elif event == "task_initiated": - # event values -> f"['{task_id}', '{t_obs_id}', '{log_task_id}, '{tsk_strt_time}'] - window["-frame_preview-"].update(visible=False) - task_id, t_obs_id, obs_log_id, tsk_strt_time = eval(values[event]) - write_output(window, f"\nTask initiated: {task_id}") - - logger.debug(f"Starting LSL for task: {t_obs_id}") - rec_fname = _record_lsl( - window, - session, - sess_info["subject_id_date"], - task_id, - t_obs_id, - obs_log_id, - tsk_strt_time, - conn - ) - - # Signal a task ended: stop LSL recording and update gui - elif event == "task_finished": - task_id, has_lsl_stream = eval(values['task_finished']) - boolean_value = has_lsl_stream.lower() == 'true' - if boolean_value: - logger.debug(f"Stopping LSL for task: {task_id}") - handle_task_finished(conn, obs_log_id, rec_fname, sess_info, session, task_id, values, window) - if task_id == last_task: - _session_button_state(window, disabled=True) - write_output(window, "\nSession complete: OK to terminate", 'blue') - - # Send a marker string with the name of the new video file created - elif event == "-new_filename-": - video_marker_stream.push_sample([values[event]]) - - elif event == 'devices_connected': - global session_prepared_count - session_prepared_count += 1 - if session_prepared_count == len(_get_nodes()): - session = _start_lsl_session(window, inlets, sess_info["subject_id_date"]) - window["-frame_preview-"].update(visible=True) - if not start_pressed: - window['Start'].update(disabled=False) - write_output(window, "Device connection complete. OK to start session") - - # Create LSL inlet stream - elif event == "-OUTLETID-": - _create_lsl_inlet(stream_ids, values[event], inlets) - elif event == "server_started": - server = values[event] - write_output(window, f"{server} server started") - - if server == "ACQ": - node_name = "acquisition" - elif server == "STM": - node_name = "presentation" - else: - raise RuntimeError(f"Unknown server type: {server} as source of ServerStarted message") - - running_servers.append(node_name) - expected_servers = _get_nodes() - check = all(e in running_servers for e in expected_servers) - if check: - write_output(window, "Servers initiated. OK to connect devices.") - window["-Connect-"].Update(disabled=False) - - elif event == "no_eyetracker": - result = sg.PopupError(values[event], location=get_popup_location(window)) - if result == 'Error': - window.write_event_value("Shut Down", "Shut Down") - - elif event == "mbient_disconnected": - sg.PopupError(values[event], non_blocking=True, location=get_popup_location(window)) - - ################################################################################## - # Conditionals handling inlets for plotting and recording - ################################################################################## - - elif event == "-frame_preview-": - _request_frame_preview(conn) - - # Print LSL inlet names in GUI - if inlet_keys != list(inlets): - inlet_keys = list(inlets) - window["inlet_State"].update("\n".join(inlet_keys)) + # Signal a task ended: stop LSL recording and update gui + elif event == "task_finished": + task_id, has_lsl_stream = eval(values['task_finished']) + boolean_value = has_lsl_stream.lower() == 'true' + if boolean_value: + logger.debug(f"Stopping LSL for task: {task_id}") + handle_task_finished(conn, obs_log_id, rec_fname, sess_info, session, task_id, values, window) + if task_id == last_task: + _session_button_state(window, disabled=True) + write_output(window, "\nSession complete: OK to terminate", 'blue') + + # Send a marker string with the name of the new video file created + elif event == "-new_filename-": + video_marker_stream.push_sample([values[event]]) + + elif event == 'devices_connected': + global session_prepared_count + session_prepared_count += 1 + if session_prepared_count == len(_get_nodes()): + session = _start_lsl_session(window, inlets, sess_info["subject_id_date"]) + window["-frame_preview-"].update(visible=True) + if not start_pressed: + window['Start'].update(disabled=False) + write_output(window, "Device connection complete. OK to start session") + + # Create LSL inlet stream + elif event == "-OUTLETID-": + _create_lsl_inlet(stream_ids, values[event], inlets) + + elif event == "server_started": + server = values[event] + write_output(window, f"{server} server started") + + if server == "ACQ": + node_name = "acquisition" + elif server == "STM": + node_name = "presentation" + else: + raise RuntimeError(f"Unknown server type: {server} as source of ServerStarted message") + + running_servers.append(node_name) + expected_servers = _get_nodes() + check = all(e in running_servers for e in expected_servers) + if check: + write_output(window, "Servers initiated. OK to connect devices.") + window["-Connect-"].Update(disabled=False) + + elif event == "no_eyetracker": + result = sg.PopupError(values[event], location=get_popup_location(window)) + if result == 'Error': + window.write_event_value("Shut Down", "Shut Down") + + elif event == "mbient_disconnected": + sg.PopupError(values[event], non_blocking=True, location=get_popup_location(window)) + + ################################################################################## + # Conditionals handling inlets for plotting and recording + ################################################################################## + + elif event == "-frame_preview-": + _request_frame_preview(conn) + + # Print LSL inlet names in GUI + if inlet_keys != list(inlets): + inlet_keys = list(inlets) + window["inlet_State"].update("\n".join(inlet_keys)) close(window)