Skip to content

Commit

Permalink
Fix WS exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Klavionik authored Dec 15, 2023
1 parent 5f9bc45 commit 6936f1f
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 32 deletions.
45 changes: 18 additions & 27 deletions blacksheep/server/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from blacksheep.common import extend
from blacksheep.common.files.asyncfs import FilesHandler
from blacksheep.contents import ASGIContent
from blacksheep.exceptions import HTTPException
from blacksheep.messages import Request, Response
from blacksheep.middlewares import get_middlewares_chain
from blacksheep.scribe import send_asgi_response
Expand Down Expand Up @@ -65,7 +64,7 @@
)
from blacksheep.server.routing import router as default_router
from blacksheep.server.routing import validate_router
from blacksheep.server.websocket import WebSocket
from blacksheep.server.websocket import WebSocket, format_reason
from blacksheep.sessions import SessionMiddleware, SessionSerializer
from blacksheep.settings.di import di_settings
from blacksheep.utils import ensure_bytes, join_fragments
Expand Down Expand Up @@ -739,31 +738,23 @@ async def _handle_websocket(self, scope, receive, send):
RouteMethod.GET_WS, scope["path"]
)

if route:
ws.route_values = route.values
try:
return await route.handler(ws)
except UnauthorizedError as unauthorized_error:
# If the WebSocket connection was not accepted yet, we close the
# connection with an HTTP Status Code, otherwise we close the connection
# with a WebSocket status code
if ws.accepted:
# Use a WebSocket error code, not an HTTP error code
await ws.close(1005, "Unauthorized")
else:
# Still in handshake phase, we close with an HTTP Status Code
# https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
await ws.close(403, str(unauthorized_error))
except HTTPException as http_exception:
# Same like above
if ws.accepted:
# Use a WebSocket error code, not an HTTP error code
await ws.close(1005, str(http_exception))
else:
# Still in handshake phase, we close with an HTTP Status Code
# https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
await ws.close(http_exception.status, str(http_exception))
await ws.close()
if route is None:
return await ws.close()

ws.route_values = route.values

try:
return await route.handler(ws)
except Exception as exc:
logging.exception("Exception while handling WebSocket")
# If WebSocket connection accepted, close
# the connection using WebSocket Internal error code.
if ws.accepted:
return await ws.close(1011, reason=format_reason(str(exc)))

# Otherwise, just close the connection, the ASGI server
# will anyway respond 403 to the client.
return await ws.close()

async def _handle_http(self, scope, receive, send):
assert scope["type"] == "http"
Expand Down
17 changes: 17 additions & 0 deletions blacksheep/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from blacksheep.server.asgi import get_full_path
from blacksheep.settings.json import json_settings

MAX_REASON_SIZE = 123


class WebSocketState(Enum):
CONNECTING = 0
Expand Down Expand Up @@ -180,3 +182,18 @@ async def disconnect():

async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
await self._send({"type": "websocket.close", "code": code, "reason": reason})


def format_reason(reason: str) -> str:
"""
Ensures that the close reason is no longer than the max reason size.
(https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/close#reason)
"""
reason_bytes = reason.encode()

if len(reason_bytes) <= MAX_REASON_SIZE:
return reason

ellipsis_ = b"..."
truncated_reason = reason_bytes[: MAX_REASON_SIZE - len(ellipsis_)] + ellipsis_
return truncated_reason.decode()
11 changes: 8 additions & 3 deletions itests/app_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pydantic import BaseModel

from blacksheep import Response, TextContent, WebSocket
from blacksheep.exceptions import BadRequest
from blacksheep.server import Application
from blacksheep.server.authentication import AuthenticationHandler
from blacksheep.server.authorization import Policy, Requirement, auth
Expand Down Expand Up @@ -150,9 +149,15 @@ async def echo_text_admin_users(websocket: WebSocket):
await websocket.send_text(msg)


@app_2.router.ws("/websocket-echo-text-http-exp")
@app_2.router.ws("/websocket-error-before-accept")
async def echo_text_http_exp(websocket: WebSocket):
raise BadRequest("Example")
raise RuntimeError("Error before accept")


@app_2.router.ws("/websocket-server-error")
async def websocket_server_error(websocket: WebSocket):
await websocket.accept()
raise RuntimeError("Server error")


@auth("authenticated")
Expand Down
17 changes: 15 additions & 2 deletions itests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import websockets
import yaml
from websockets.exceptions import InvalidStatusCode
from websockets.exceptions import ConnectionClosedError, InvalidStatusCode

from .client_fixtures import get_static_path
from .server_fixtures import * # NoQA
Expand Down Expand Up @@ -793,7 +793,7 @@ async def test_websocket(server_host, server_port_4, route, data):
"route",
[
"websocket-echo-text-auth",
"websocket-echo-text-http-exp",
"websocket-error-before-accept",
],
)
async def test_websocket_auth(server_host, server_port_2, route):
Expand All @@ -805,3 +805,16 @@ async def test_websocket_auth(server_host, server_port_2, route):

assert error.value.status_code == 403
assert "server rejected" in str(error.value)


@pytest.mark.asyncio
async def test_websocket_server_error(server_host, server_port_2):
uri = f"ws://{server_host}:{server_port_2}/websocket-server-error"

with pytest.raises(ConnectionClosedError) as error:
async with websockets.connect(uri) as ws:
async for _message in ws:
pass

assert error.value.code == 1011
assert error.value.reason == "Server error"
19 changes: 19 additions & 0 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
WebSocket,
WebSocketDisconnectError,
WebSocketState,
format_reason,
)
from blacksheep.testing.messages import MockReceive, MockSend
from tests.utils.application import FakeApplication
Expand Down Expand Up @@ -442,3 +443,21 @@ async def websocket_handler(my_ws: WebSocket):
mock_receive,
mock_send,
)


LONG_REASON = "WRY" * 41
QIN = "秦" # Qyn dynasty in Chinese, 3 bytes.
TOO_LONG_REASON = QIN * 42
TOO_LONG_REASON_TRUNC = TOO_LONG_REASON[:40] + "..."


@pytest.mark.parametrize(
"inp,out",
[
("Short reason", "Short reason"),
(LONG_REASON, LONG_REASON),
(TOO_LONG_REASON, TOO_LONG_REASON_TRUNC),
],
)
def test_format_reason(inp, out):
assert format_reason(inp) == out

0 comments on commit 6936f1f

Please sign in to comment.