From 5ab0f69fb665672d556932f8e9b937d370a2f941 Mon Sep 17 00:00:00 2001 From: Lucas ONeil Date: Fri, 20 Dec 2024 15:22:31 -0800 Subject: [PATCH] Refactor Signed-off-by: Lucas ONeil --- oidc-controller/api/routers/acapy_handler.py | 17 ++----- oidc-controller/api/routers/oidc.py | 7 +-- .../api/routers/presentation_request.py | 9 +--- oidc-controller/api/routers/socketio.py | 48 +++++++++---------- .../api/templates/verified_credentials.html | 10 ++++ 5 files changed, 43 insertions(+), 48 deletions(-) diff --git a/oidc-controller/api/routers/acapy_handler.py b/oidc-controller/api/routers/acapy_handler.py index a4a8a534..74714ac8 100644 --- a/oidc-controller/api/routers/acapy_handler.py +++ b/oidc-controller/api/routers/acapy_handler.py @@ -11,7 +11,7 @@ from ..db.session import get_db from ..core.config import settings -from ..routers.socketio import sio, connections_reload +from ..routers.socketio import buffered_emit, connections_reload logger: structlog.typing.FilteringBoundLogger = structlog.getLogger(__name__) @@ -39,9 +39,6 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db # Get the saved websocket session pid = str(auth_session.id) - connections = connections_reload() - sid = connections.get(pid) - logger.debug(f"sid: {sid} found for pid: {pid}") if webhook_body["state"] == "presentation-received": logger.info("presentation-received") @@ -51,12 +48,10 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db if webhook_body["verified"] == "true": auth_session.proof_status = AuthSessionState.VERIFIED auth_session.presentation_exchange = webhook_body["by_format"] - if sid: - await sio.emit("status", {"status": "verified"}, to=sid) + await buffered_emit("status", {"status": "verified"}, to_pid=pid) else: auth_session.proof_status = AuthSessionState.FAILED - if sid: - await sio.emit("status", {"status": "failed"}, to=sid) + await buffered_emit("status", {"status": "failed"}, to_pid=pid) await AuthSessionCRUD(db).patch( str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) @@ -67,8 +62,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db logger.info("ABANDONED") logger.info(webhook_body["error_msg"]) auth_session.proof_status = AuthSessionState.ABANDONED - if sid: - await sio.emit("status", {"status": "abandoned"}, to=sid) + await buffered_emit("status", {"status": "abandoned"}, to_pid=pid) await AuthSessionCRUD(db).patch( str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) @@ -93,8 +87,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db ): logger.info("EXPIRED") auth_session.proof_status = AuthSessionState.EXPIRED - if sid: - await sio.emit("status", {"status": "expired"}, to=sid) + await buffered_emit("status", {"status": "expired"}, to_pid=pid) await AuthSessionCRUD(db).patch( str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) diff --git a/oidc-controller/api/routers/oidc.py b/oidc-controller/api/routers/oidc.py index 2d4d7073..84b589d5 100644 --- a/oidc-controller/api/routers/oidc.py +++ b/oidc-controller/api/routers/oidc.py @@ -31,7 +31,7 @@ from ..db.session import get_db # Access to the websocket -from ..routers.socketio import connections_reload, sio +from ..routers.socketio import buffered_emit, connections_reload from ..verificationConfigs.crud import VerificationConfigCRUD from ..verificationConfigs.helpers import VariableSubstitutionError @@ -58,8 +58,6 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)): auth_session = await AuthSessionCRUD(db).get(pid) pid = str(auth_session.id) - connections = connections_reload() - sid = connections.get(pid) """ Check if proof is expired. But only if the proof has not been started. @@ -75,8 +73,7 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)): str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) ) # Send message through the websocket. - if sid: - await sio.emit("status", {"status": "expired"}, to=sid) + await buffered_emit("status", {"status": "expired"}, to_pid=pid) return {"proof_status": auth_session.proof_status} diff --git a/oidc-controller/api/routers/presentation_request.py b/oidc-controller/api/routers/presentation_request.py index 9bc753bb..8331efeb 100644 --- a/oidc-controller/api/routers/presentation_request.py +++ b/oidc-controller/api/routers/presentation_request.py @@ -9,7 +9,7 @@ from ..authSessions.models import AuthSession, AuthSessionState from ..core.config import settings -from ..routers.socketio import sio, connections_reload +from ..routers.socketio import buffered_emit, connections_reload from ..routers.oidc import gen_deep_link from ..db.session import get_db @@ -49,16 +49,11 @@ async def send_connectionless_proof_req( pres_exch_id ) - # Get the websocket session - connections = connections_reload() - sid = connections.get(str(auth_session.id)) - # If the qrcode has been scanned, toggle the verified flag if auth_session.proof_status is AuthSessionState.NOT_STARTED: auth_session.proof_status = AuthSessionState.PENDING await AuthSessionCRUD(db).patch(auth_session.id, auth_session) - if sid: - await sio.emit("status", {"status": "pending"}, to=sid) + await buffered_emit("status", {"status": "pending"}, to_pid=auth_session.id) msg = auth_session.presentation_request_msg diff --git a/oidc-controller/api/routers/socketio.py b/oidc-controller/api/routers/socketio.py index 0a67c943..53898519 100644 --- a/oidc-controller/api/routers/socketio.py +++ b/oidc-controller/api/routers/socketio.py @@ -12,7 +12,7 @@ sio_app = socketio.ASGIApp(socketio_server=sio, socketio_path="/ws/socket.io") @sio.event -async def connect(sid): +async def connect(sid, socket): logger.info(f">>> connect : sid={sid}") @sio.event @@ -34,51 +34,51 @@ async def disconnect(sid): # Remove pid from connections del connections[pid] -@sio.event -async def send_message(sid, data): +async def buffered_emit(event, data, to_pid=None): global connections, message_buffers - pid = data.get("pid") - if pid in connections: - target_sid = connections[pid] + + connections = connections_reload() + sid = connections.get(to_pid) + logger.debug(f"sid: {sid} found for pid: {to_pid}") + + if sid: try: - await sio.emit('message', data['message'], room=target_sid) + await sio.emit(event, data, room=sid) except: # If send fails, buffer the message - buffer_message(pid, data['message']) + buffer_message(to_pid, event, data) else: - # Buffer the message if pid is not connected - buffer_message(pid, data['message']) + # Buffer the message if the target is not connected + buffer_message(to_pid, event, data) -def buffer_message(pid, message): +def buffer_message(pid, event, data): global message_buffers current_time = time.time() if pid not in message_buffers: message_buffers[pid] = [] - # Add message with timestamp - message_buffers[pid].append((message, current_time)) + # Add message with timestamp and event name + message_buffers[pid].append((event, data, current_time)) # Clean up old messages message_buffers[pid] = [ - (msg, timestamp) for msg, timestamp in message_buffers[pid] + (msg_event, msg_data, timestamp) for msg_event, msg_data, timestamp in message_buffers[pid] if current_time - timestamp <= buffer_timeout ] @sio.event -async def fetch_buffered_messages(sid, data): - pid = data.get("pid") +async def fetch_buffered_messages(sid, pid): global message_buffers current_time = time.time() if pid in message_buffers: + # Filter messages that are still valid (i.e., within the buffer_timeout) valid_messages = [ - msg for msg, timestamp in message_buffers[pid] - if current_time - timestamp <= buffer_timeout - ] - for message in valid_messages: - await sio.emit('message', message, room=sid) - # Clean up messages after sending - message_buffers[pid] = [ - (msg, timestamp) for msg, timestamp in message_buffers[pid] + (msg_event, msg_data, timestamp) for msg_event, msg_data, timestamp in message_buffers[pid] if current_time - timestamp <= buffer_timeout ] + # Emit each valid message + for event, data, _ in valid_messages: + await sio.emit(event, data, room=sid) + # Reassign the valid_messages back to message_buffers[pid] to clean up old messages + message_buffers[pid] = valid_messages def connections_reload(): global connections diff --git a/oidc-controller/api/templates/verified_credentials.html b/oidc-controller/api/templates/verified_credentials.html index 474d560d..97cf2c35 100644 --- a/oidc-controller/api/templates/verified_credentials.html +++ b/oidc-controller/api/templates/verified_credentials.html @@ -112,6 +112,14 @@

Continue with:

> DEBUG Disconnect Web Socket + +
@@ -383,6 +391,8 @@
`Socket connecting. SID: ${this.socket.id}. PID: {{pid}}. Recovered? ${this.socket.recovered} ` ); this.socket.emit("initialize", { pid: "{{pid}}" }); + // Emit the `fetch_buffered_messages` event with `pid` as a string using Jinja templating + this.socket.emit('fetch_buffered_messages', '{{ pid }}'); }); this.socket.on("connect_error", (error) => {