Skip to content

Commit

Permalink
Buffer messages to wait for reconnect
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas ONeil <[email protected]>
  • Loading branch information
loneil committed Dec 20, 2024
1 parent f6c4c6c commit a27342c
Showing 1 changed file with 61 additions and 13 deletions.
74 changes: 61 additions & 13 deletions oidc-controller/api/routers/socketio.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,84 @@
import socketio # For using websockets
import logging
import time

logger = logging.getLogger(__name__)


connections = {}
message_buffers = {}
buffer_timeout = 60 # Timeout in seconds

sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")

sio_app = socketio.ASGIApp(socketio_server=sio, socketio_path="/ws/socket.io")


@sio.event
async def connect(sid, socket):
async def connect(sid):
logger.info(f">>> connect : sid={sid}")


@sio.event
async def initialize(sid, data):
global connections
# Store websocket session matched to the presentation exchange id
connections[data.get("pid")] = sid

global connections, message_buffers
pid = data.get("pid")
connections[pid] = sid
# Initialize buffer if it doesn't exist
if pid not in message_buffers:
message_buffers[pid] = []

@sio.event
async def disconnect(sid):
global connections
global connections, message_buffers
logger.info(f">>> disconnect : sid={sid}")
# Remove websocket session from the store
if len(connections) > 0:
connections = {k: v for k, v in connections.items() if v != sid}
# Find the pid associated with the sid
pid = next((k for k, v in connections.items() if v == sid), None)
if pid:
# Remove pid from connections
del connections[pid]

@sio.event
async def send_message(sid, data):
global connections, message_buffers
pid = data.get("pid")
if pid in connections:
target_sid = connections[pid]
try:
await sio.emit('message', data['message'], room=target_sid)
except:
# If send fails, buffer the message
buffer_message(pid, data['message'])
else:
# Buffer the message if pid is not connected
buffer_message(pid, data['message'])

def buffer_message(pid, message):
global message_buffers
current_time = time.time()
if pid not in message_buffers:
message_buffers[pid] = []
# Add message with timestamp
message_buffers[pid].append((message, current_time))
# Clean up old messages
message_buffers[pid] = [
(msg, timestamp) for msg, timestamp in message_buffers[pid]
if current_time - timestamp <= buffer_timeout
]

@sio.event
async def fetch_buffered_messages(sid, data):
pid = data.get("pid")
global message_buffers
current_time = time.time()
if pid in message_buffers:
valid_messages = [
msg for msg, timestamp in message_buffers[pid]
if current_time - timestamp <= buffer_timeout
]
for message in valid_messages:
await sio.emit('message', message, room=sid)
# Clean up messages after sending
message_buffers[pid] = [
(msg, timestamp) for msg, timestamp in message_buffers[pid]
if current_time - timestamp <= buffer_timeout
]

def connections_reload():
global connections
Expand Down

0 comments on commit a27342c

Please sign in to comment.