diff --git a/blacksheep/server/application.py b/blacksheep/server/application.py index 7295dd35..34ba8feb 100644 --- a/blacksheep/server/application.py +++ b/blacksheep/server/application.py @@ -64,7 +64,7 @@ ) from blacksheep.server.routing import router as default_router from blacksheep.server.routing import validate_router -from blacksheep.server.websocket import WebSocket +from blacksheep.server.websocket import WebSocket, format_reason from blacksheep.sessions import SessionMiddleware, SessionSerializer from blacksheep.settings.di import di_settings from blacksheep.utils import ensure_bytes, join_fragments @@ -749,7 +749,7 @@ async def _handle_websocket(self, scope, receive, send): # If WebSocket connection accepted, close # the connection using WebSocket Internal error code. if ws.accepted: - return await ws.close(1011, reason=str(exc)) + return await ws.close(1011, reason=format_reason(str(exc))) # Otherwise, just close the connection, the ASGI server # will anyway respond 403 to the client. diff --git a/blacksheep/server/websocket.py b/blacksheep/server/websocket.py index 7fa12a0d..d0d3deb4 100644 --- a/blacksheep/server/websocket.py +++ b/blacksheep/server/websocket.py @@ -6,6 +6,8 @@ from blacksheep.server.asgi import get_full_path from blacksheep.settings.json import json_settings +MAX_REASON_SIZE = 123 + class WebSocketState(Enum): CONNECTING = 0 @@ -180,3 +182,18 @@ async def disconnect(): async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: await self._send({"type": "websocket.close", "code": code, "reason": reason}) + + +def format_reason(reason: str) -> str: + """ + Ensures that the close reason is no longer than the max reason size. + (https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#reason) + """ + reason_bytes = reason.encode() + + if len(reason_bytes) <= MAX_REASON_SIZE: + return reason + + ellipsis_ = b"..." + truncated_reason = reason_bytes[: MAX_REASON_SIZE - len(ellipsis_)] + ellipsis_ + return truncated_reason.decode() diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 4b95a387..2f113288 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -7,6 +7,7 @@ WebSocket, WebSocketDisconnectError, WebSocketState, + format_reason, ) from blacksheep.testing.messages import MockReceive, MockSend from tests.utils.application import FakeApplication @@ -442,3 +443,20 @@ async def websocket_handler(my_ws: WebSocket): mock_receive, mock_send, ) + +LONG_REASON = "WRY" * 41 +QIN = "秦" # Qyn dynasty in Chinese, 3 bytes. +TOO_LONG_REASON = QIN * 42 +TOO_LONG_REASON_TRUNC = TOO_LONG_REASON[:40] + "..." + + +@pytest.mark.parametrize( + "inp,out", + [ + ("Short reason", "Short reason"), + (LONG_REASON, LONG_REASON), + (TOO_LONG_REASON, TOO_LONG_REASON_TRUNC) + ] +) +def test_format_reason(inp, out): + assert format_reason(inp) == out