diff --git a/src/py/reactpy/pyproject.toml b/src/py/reactpy/pyproject.toml index 67189808b..309248507 100644 --- a/src/py/reactpy/pyproject.toml +++ b/src/py/reactpy/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ + "exceptiongroup >=1.0", "typing-extensions >=3.10", "mypy-extensions >=0.4.3", "anyio >=3", diff --git a/src/py/reactpy/reactpy/backend/starlette.py b/src/py/reactpy/reactpy/backend/starlette.py index 2953b97b3..cd1b5b7c6 100644 --- a/src/py/reactpy/reactpy/backend/starlette.py +++ b/src/py/reactpy/reactpy/backend/starlette.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any, Callable +from exceptiongroup import BaseExceptionGroup from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request @@ -137,8 +138,6 @@ async def serve_index(request: Request) -> HTMLResponse: def _setup_single_view_dispatcher_route( options: Options, app: Starlette, component: RootComponentConstructor ) -> None: - @app.websocket_route(str(STREAM_PATH)) - @app.websocket_route(f"{STREAM_PATH}/{{path:path}}") async def model_stream(socket: WebSocket) -> None: await socket.accept() send, recv = _make_send_recv_callbacks(socket) @@ -162,8 +161,16 @@ async def model_stream(socket: WebSocket) -> None: send, recv, ) - except WebSocketDisconnect as error: - logger.info(f"WebSocket disconnect: {error.code}") + except BaseExceptionGroup as egroup: + for e in egroup.exceptions: + if isinstance(e, WebSocketDisconnect): + logger.info(f"WebSocket disconnect: {e.code}") + break + else: + raise + + app.add_websocket_route(str(STREAM_PATH), model_stream) + app.add_websocket_route(f"{STREAM_PATH}/{{path:path}}", model_stream) def _make_send_recv_callbacks(