Skip to content

Commit

Permalink
Rework how services are started/stopped (jupyter-server#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Apr 24, 2024
1 parent 0773726 commit 3f9cd15
Show file tree
Hide file tree
Showing 10 changed files with 439 additions and 326 deletions.
90 changes: 51 additions & 39 deletions pycrdt_websocket/websocket_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from anyio import (
TASK_STATUS_IGNORED,
Event,
Lock,
create_memory_object_stream,
create_task_group,
)
from anyio.abc import TaskGroup, TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pycrdt import Doc
from pycrdt import Doc, Subscription

from .websocket import Websocket
from .yutils import (
Expand All @@ -30,9 +31,10 @@ class WebsocketProvider:
_ydoc: Doc
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_started: Event | None
_starting: bool
_task_group: TaskGroup | None
_subscription: Subscription
_started: Event | None = None
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -> None:
"""Initialize the object.
Expand All @@ -47,7 +49,7 @@ def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -
task = asyncio.create_task(websocket_provider.start())
await websocket_provider.started.wait()
...
websocket_provider.stop()
await websocket_provider.stop()
```
Arguments:
Expand All @@ -61,10 +63,6 @@ def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
self._started = None
self._starting = False
self._task_group = None
ydoc.observe(partial(put_updates, self._update_send_stream))

@property
def started(self) -> Event:
Expand All @@ -73,26 +71,13 @@ def started(self) -> Event:
self._started = Event()
return self._started

async def __aenter__(self) -> WebsocketProvider:
if self._task_group is not None:
raise RuntimeError("WebsocketProvider already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
tg.start_soon(self._run)
self.started.set()

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
if self._task_group is None:
raise RuntimeError("WebsocketProvider not running")

self._task_group.cancel_scope.cancel()
self._task_group = None
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)
@property
def _start_lock(self) -> Lock:
if self.__start_lock is None:
self.__start_lock = Lock()
return self.__start_lock

async def _run(self):
await sync(self._ydoc, self._websocket, self.log)
Expand All @@ -110,30 +95,57 @@ async def _send(self):
except Exception:
pass

async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
async def __aenter__(self) -> WebsocketProvider:
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketProvider already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def start(
self,
*,
task_status: TaskStatus[None] = TASK_STATUS_IGNORED,
from_context_manager: bool = False,
):
"""Start the WebSocket provider.
Arguments:
task_status: The status to set when the task has started.
"""
if self._starting:
self._subscription = self._ydoc.observe(partial(put_updates, self._update_send_stream))

if from_context_manager:
task_status.started()
self.started.set()
assert self._task_group is not None
self._task_group.start_soon(self._run)
return
else:
self._starting = True

if self._task_group is not None:
raise RuntimeError("WebsocketProvider already running")
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketProvider already running")

async with create_task_group() as self._task_group:
self._task_group.start_soon(self._run)
self.started.set()
self._starting = False
task_status.started()
async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
self._task_group.start_soon(self._run)

def stop(self):
async def stop(self):
"""Stop the WebSocket provider."""
if self._task_group is None:
raise RuntimeError("WebsocketProvider not running")

self._task_group.cancel_scope.cancel()
self._task_group = None
self._ydoc.unobserve(self._subscription)
83 changes: 46 additions & 37 deletions pycrdt_websocket/websocket_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from functools import partial
from logging import Logger, getLogger

from anyio import TASK_STATUS_IGNORED, Event, create_task_group
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus

from .websocket import Websocket
Expand All @@ -15,9 +16,9 @@ class WebsocketServer:

auto_clean_rooms: bool
rooms: dict[str, YRoom]
_started: Event | None
_starting: bool
_task_group: TaskGroup | None
_started: Event | None = None
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

def __init__(
self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None
Expand All @@ -34,7 +35,7 @@ def __init__(
task = asyncio.create_task(websocket_server.start())
await websocket_server.started.wait()
...
websocket_server.stop()
await websocket_server.stop()
```
Arguments:
Expand All @@ -46,9 +47,6 @@ def __init__(
self.auto_clean_rooms = auto_clean_rooms
self.log = log or getLogger(__name__)
self.rooms = {}
self._started = None
self._starting = False
self._task_group = None

@property
def started(self) -> Event:
Expand All @@ -57,6 +55,12 @@ def started(self) -> Event:
self._started = Event()
return self._started

@property
def _start_lock(self) -> Lock:
if self.__start_lock is None:
self.__start_lock = Lock()
return self.__start_lock

async def get_room(self, name: str) -> YRoom:
"""Get or create a room with the given name, and start it.
Expand Down Expand Up @@ -115,20 +119,20 @@ def rename_room(
from_name = self.get_room_name(from_room)
self.rooms[to_name] = self.rooms.pop(from_name)

def delete_room(self, *, name: str | None = None, room: YRoom | None = None) -> None:
async def delete_room(self, *, name: str | None = None, room: YRoom | None = None) -> None:
"""Delete a room.
Arguments:
name: The name of the room to delete (if `room` is not passed).
room: The room to delete ( if `name` is not passed).
room: The room to delete (if `name` is not passed).
"""
if name is not None and room is not None:
raise RuntimeError("Cannot pass name and room")
if name is None:
assert room is not None
name = self.get_room_name(room)
room = self.rooms.pop(name)
room.stop()
await room.stop()

async def serve(self, websocket: Websocket) -> None:
"""Serve a client through a WebSocket.
Expand All @@ -151,51 +155,56 @@ async def _serve(self, websocket: Websocket, tg: TaskGroup):
await room.serve(websocket)

if self.auto_clean_rooms and not room.clients:
self.delete_room(room=room)
await self.delete_room(room=room)
tg.cancel_scope.cancel()

async def __aenter__(self) -> WebsocketServer:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
self.started.set()
async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
if self._task_group is None:
raise RuntimeError("WebsocketServer not running")

self._task_group.cancel_scope.cancel()
self._task_group = None
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
async def start(
self,
*,
task_status: TaskStatus[None] = TASK_STATUS_IGNORED,
from_context_manager: bool = False,
):
"""Start the WebSocket server.
Arguments:
task_status: The status to set when the task has started.
"""
if self._starting:
if from_context_manager:
task_status.started()
self.started.set()
assert self._task_group is not None
# wait forever
self._task_group.start_soon(Event().wait)
return
else:
self._starting = True

if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

# create the task group and wait forever
async with create_task_group() as self._task_group:
self._task_group.start_soon(Event().wait)
self.started.set()
self._starting = False
task_status.started()
async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
# wait forever
self._task_group.start_soon(Event().wait)

def stop(self) -> None:
async def stop(self) -> None:
"""Stop the WebSocket server."""
if self._task_group is None:
raise RuntimeError("WebsocketServer not running")
Expand Down
Loading

0 comments on commit 3f9cd15

Please sign in to comment.