From 7f9081591a19dee5142ad8f9da530d8f71ff55b2 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Sat, 9 Nov 2024 20:52:32 +0100 Subject: [PATCH] Use ThreadSelectorEventLoop on Windows with ProactorEventLoop --- pyproject.toml | 1 + src/zmq_anyio/_selector_thread.py | 387 ++++++++++++++++++++++++++++++ src/zmq_anyio/_socket.py | 51 ++-- tests/test_socket.py | 7 +- 4 files changed, 412 insertions(+), 34 deletions(-) create mode 100644 src/zmq_anyio/_selector_thread.py diff --git a/pyproject.toml b/pyproject.toml index 889488a..de7718a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ requires-python = ">= 3.9" dependencies = [ "anyio", + "sniffio", "anyioutils >=0.4.6", "pyzmq >=26.0.0,<27.0.0", ] diff --git a/src/zmq_anyio/_selector_thread.py b/src/zmq_anyio/_selector_thread.py new file mode 100644 index 0000000..68bca92 --- /dev/null +++ b/src/zmq_anyio/_selector_thread.py @@ -0,0 +1,387 @@ +"""Ensure asyncio selector methods (add_reader, etc.) are available. +Running select in a thread and defining these methods on the running event loop. +Originally in tornado.platform.asyncio. +Redistributed under license Apache-2.0 +""" + +from __future__ import annotations + +import asyncio +import atexit +import errno +import functools +import select +import socket +import sys +import threading +import typing +from typing import ( + Any, + Callable, + Union, +) +from weakref import WeakKeyDictionary + +from sniffio import current_async_library + +if typing.TYPE_CHECKING: + from typing_extensions import Protocol + + class _HasFileno(Protocol): + def fileno(self) -> int: + pass + + _FileDescriptorLike = Union[int, _HasFileno] + + +# Collection of selector thread event loops to shut down on exit. +_selector_loops: set[SelectorThread] = set() + + +def _atexit_callback() -> None: + for loop in _selector_loops: + with loop._select_cond: + loop._closing_selector = True + loop._select_cond.notify() + try: + loop._waker_w.send(b"a") + except BlockingIOError: + pass + # If we don't join our (daemon) thread here, we may get a deadlock + # during interpreter shutdown. I don't really understand why. This + # deadlock happens every time in CI (both travis and appveyor) but + # I've never been able to reproduce locally. + assert loop._thread is not None + loop._thread.join() + _selector_loops.clear() + + +atexit.register(_atexit_callback) + + +# SelectorThread from tornado 6.4.0 + + +class SelectorThread: + """Define ``add_reader`` methods to be called in a background select thread. + + Instances of this class start a second thread to run a selector. + This thread is completely hidden from the user; + all callbacks are run on the wrapped event loop's thread. + + Typically used via ``AddThreadSelectorEventLoop``, + but can be attached to a running asyncio loop. + """ + + _closed = False + + def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: + self._real_loop = real_loop + + self._select_cond = threading.Condition() + self._select_args: ( + tuple[list[_FileDescriptorLike], list[_FileDescriptorLike]] | None + ) = None + self._closing_selector = False + self._thread: threading.Thread | None = None + self._thread_manager_handle = self._thread_manager() + + async def thread_manager_anext() -> None: + # the anext builtin wasn't added until 3.10. We just need to iterate + # this generator one step. + await self._thread_manager_handle.__anext__() + + # When the loop starts, start the thread. Not too soon because we can't + # clean up if we get to this point but the event loop is closed without + # starting. + self._real_loop.call_soon( + lambda: self._real_loop.create_task(thread_manager_anext()) + ) + + self._readers: dict[_FileDescriptorLike, Callable] = {} + self._writers: dict[_FileDescriptorLike, Callable] = {} + + # Writing to _waker_w will wake up the selector thread, which + # watches for _waker_r to be readable. + self._waker_r, self._waker_w = socket.socketpair() + self._waker_r.setblocking(False) + self._waker_w.setblocking(False) + _selector_loops.add(self) + self.add_reader(self._waker_r, self._consume_waker) + + def close(self) -> None: + if self._closed: + return + with self._select_cond: + self._closing_selector = True + self._select_cond.notify() + self._wake_selector() + if self._thread is not None: + self._thread.join() + _selector_loops.discard(self) + self.remove_reader(self._waker_r) + self._waker_r.close() + self._waker_w.close() + self._closed = True + + async def _thread_manager(self) -> typing.AsyncGenerator[None, None]: + # Create a thread to run the select system call. We manage this thread + # manually so we can trigger a clean shutdown from an atexit hook. Note + # that due to the order of operations at shutdown, only daemon threads + # can be shut down in this way (non-daemon threads would require the + # introduction of a new hook: https://bugs.python.org/issue41962) + self._thread = threading.Thread( + name="Tornado selector", + daemon=True, + target=self._run_select, + ) + self._thread.start() + self._start_select() + try: + # The presense of this yield statement means that this coroutine + # is actually an asynchronous generator, which has a special + # shutdown protocol. We wait at this yield point until the + # event loop's shutdown_asyncgens method is called, at which point + # we will get a GeneratorExit exception and can shut down the + # selector thread. + yield + except GeneratorExit: + self.close() + raise + + def _wake_selector(self) -> None: + if self._closed: + return + try: + self._waker_w.send(b"a") + except BlockingIOError: + pass + + def _consume_waker(self) -> None: + try: + self._waker_r.recv(1024) + except BlockingIOError: + pass + + def _start_select(self) -> None: + # Capture reader and writer sets here in the event loop + # thread to avoid any problems with concurrent + # modification while the select loop uses them. + with self._select_cond: + assert self._select_args is None + self._select_args = (list(self._readers.keys()), list(self._writers.keys())) + self._select_cond.notify() + + def _run_select(self) -> None: + while True: + with self._select_cond: + while self._select_args is None and not self._closing_selector: + self._select_cond.wait() + if self._closing_selector: + return + assert self._select_args is not None + to_read, to_write = self._select_args + self._select_args = None + + # We use the simpler interface of the select module instead of + # the more stateful interface in the selectors module because + # this class is only intended for use on windows, where + # select.select is the only option. The selector interface + # does not have well-documented thread-safety semantics that + # we can rely on so ensuring proper synchronization would be + # tricky. + try: + # On windows, selecting on a socket for write will not + # return the socket when there is an error (but selecting + # for reads works). Also select for errors when selecting + # for writes, and merge the results. + # + # This pattern is also used in + # https://github.com/python/cpython/blob/v3.8.0/Lib/selectors.py#L312-L317 + rs, ws, xs = select.select(to_read, to_write, to_write) + ws = ws + xs + except OSError as e: + # After remove_reader or remove_writer is called, the file + # descriptor may subsequently be closed on the event loop + # thread. It's possible that this select thread hasn't + # gotten into the select system call by the time that + # happens in which case (at least on macOS), select may + # raise a "bad file descriptor" error. If we get that + # error, check and see if we're also being woken up by + # polling the waker alone. If we are, just return to the + # event loop and we'll get the updated set of file + # descriptors on the next iteration. Otherwise, raise the + # original error. + if e.errno == getattr(errno, "WSAENOTSOCK", errno.EBADF): + rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0) + if rs: + ws = [] + else: + raise + else: + raise + + try: + self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws) + except RuntimeError: + # "Event loop is closed". Swallow the exception for + # consistency with PollIOLoop (and logical consistency + # with the fact that we can't guarantee that an + # add_callback that completes without error will + # eventually execute). + pass + except AttributeError: + # ProactorEventLoop may raise this instead of RuntimeError + # if call_soon_threadsafe races with a call to close(). + # Swallow it too for consistency. + pass + + def _handle_select( + self, rs: list[_FileDescriptorLike], ws: list[_FileDescriptorLike] + ) -> None: + for r in rs: + self._handle_event(r, self._readers) + for w in ws: + self._handle_event(w, self._writers) + self._start_select() + + def _handle_event( + self, + fd: _FileDescriptorLike, + cb_map: dict[_FileDescriptorLike, Callable], + ) -> None: + try: + callback = cb_map[fd] + except KeyError: + return + callback() + + def add_reader( + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + self._readers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def add_writer( + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + self._writers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def remove_reader(self, fd: _FileDescriptorLike) -> bool: + try: + del self._readers[fd] + except KeyError: + return False + self._wake_selector() + return True + + def remove_writer(self, fd: _FileDescriptorLike) -> bool: + try: + del self._writers[fd] + except KeyError: + return False + self._wake_selector() + return True + + +# AddThreadSelectorEventLoop: unmodified from tornado 6.4.0 +class AddThreadSelectorEventLoop(asyncio.AbstractEventLoop): + """Wrap an event loop to add implementations of the ``add_reader`` method family. + + Instances of this class start a second thread to run a selector. + This thread is completely hidden from the user; all callbacks are + run on the wrapped event loop's thread. + + This class is used automatically by Tornado; applications should not need + to refer to it directly. + + It is safe to wrap any event loop with this class, although it only makes sense + for event loops that do not implement the ``add_reader`` family of methods + themselves (i.e. ``WindowsProactorEventLoop``) + + Closing the ``AddThreadSelectorEventLoop`` also closes the wrapped event loop. + """ + + # This class is a __getattribute__-based proxy. All attributes other than those + # in this set are proxied through to the underlying loop. + MY_ATTRIBUTES = { + "_real_loop", + "_selector", + "add_reader", + "add_writer", + "close", + "remove_reader", + "remove_writer", + } + + def __getattribute__(self, name: str) -> Any: + if name in AddThreadSelectorEventLoop.MY_ATTRIBUTES: + return super().__getattribute__(name) + return getattr(self._real_loop, name) + + def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: + self._real_loop = real_loop + self._selector = SelectorThread(real_loop) + + def close(self) -> None: + self._selector.close() + self._real_loop.close() + + def add_reader( # type: ignore[override] + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + return self._selector.add_reader(fd, callback, *args) + + def add_writer( # type: ignore[override] + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + return self._selector.add_writer(fd, callback, *args) + + def remove_reader(self, fd: _FileDescriptorLike) -> bool: + return self._selector.remove_reader(fd) + + def remove_writer(self, fd: _FileDescriptorLike) -> bool: + return self._selector.remove_writer(fd) + + +# registry of asyncio loop : selector thread +_selectors: WeakKeyDictionary = WeakKeyDictionary() + + +def _set_selector_windows() -> None: + """Set selector-compatible loop. + Sets ``add_reader`` family of methods on the asyncio loop. + Workaround Windows proactor removal of *reader methods. + """ + if not ( + sys.platform == "win32" + and current_async_library() == "asyncio" + and asyncio.get_event_loop_policy().__class__.__name__ + == "WindowsProactorEventLoopPolicy" + ): + return + + asyncio_loop = asyncio.get_running_loop() + if asyncio_loop in _selectors: + return + + from ._selector_thread import AddThreadSelectorEventLoop + + selector_loop = _selectors[asyncio_loop] = AddThreadSelectorEventLoop( # type: ignore[abstract] + asyncio_loop + ) + + # patch loop.close to also close the selector thread + loop_close = asyncio_loop.close + + def _close_selector_and_loop() -> None: + # restore original before calling selector.close, + # which in turn calls eventloop.close! + asyncio_loop.close = loop_close # type: ignore[method-assign] + _selectors.pop(asyncio_loop, None) + selector_loop.close() + + asyncio_loop.close = _close_selector_and_loop # type: ignore[method-assign] + asyncio_loop.add_reader = selector_loop.add_reader # type: ignore[assignment] + asyncio_loop.remove_reader = selector_loop.remove_reader # type: ignore[method-assign] diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index f728c26..71a3686 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -2,7 +2,6 @@ import select import selectors -import threading import warnings from collections import deque from contextlib import AsyncExitStack @@ -18,7 +17,7 @@ cast, ) -from anyio import Event, TASK_STATUS_IGNORED, create_task_group, from_thread, sleep, to_thread, wait_socket_readable +from anyio import Event, TASK_STATUS_IGNORED, create_task_group, sleep, wait_socket_readable from anyio.abc import TaskGroup, TaskStatus from anyioutils import Future, Task, create_task @@ -26,6 +25,8 @@ from zmq import EVENTS, POLLIN, POLLOUT from zmq.utils import jsonapi +from ._selector_thread import _set_selector_windows + class _FutureEvent(NamedTuple): future: Future @@ -166,9 +167,6 @@ class Socket(zmq.Socket): _fd = None _exit_stack = None _task_group = None - _select_socket_r = None - _select_socket_w = None - _stopped = None started = None def __init__( @@ -194,17 +192,9 @@ def __init__( self._send_futures = deque() self._state = 0 self._fd = self._shadow_sock.FD - self._select_socket_r, self._select_socket_w = socketpair() - self._select_socket_r.setblocking(False) - self._select_socket_w.setblocking(False) self.started = Event() - self._stopped = threading.Event() def close(self, linger: int | None = None) -> None: - assert self._stopped is not None - assert self._select_socket_w is not None - self._stopped.set() - self._select_socket_w.send(b"a") if not self.closed and self._fd is not None: event_list: list[_FutureEvent] = list( chain(self._recv_futures or [], self._send_futures or []) @@ -689,29 +679,32 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): self.close() except BaseException: pass - self._task_group.cancel_scope.cancel() + await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + if self._task_group is None: + async with create_task_group() as self._task_group: + await self._task_group.start(self._start) + else: + await self._task_group.start(self._start) + task_status.started() + + async def stop(self): + assert self._task_group is not None + self._task_group.cancel_scope.cancel() + + async def _start(self, *, task_status: TaskStatus[None]): + _set_selector_windows() assert self._task_group is not None assert self.started is not None - self._task_group.start_soon(partial(to_thread.run_sync, self._reader, abandon_on_cancel=True)) - await self.started.wait() + if self.started.is_set(): + raise RuntimeError("Socket already started") + self.started.set() task_status.started() - - def _reader(self): - from_thread.run_sync(self.started.set) while True: - try: - rs, ws, xs = select.select([self._shadow_sock, self._select_socket_r.fileno()], [], [self._shadow_sock, self._select_socket_r.fileno()]) - except OSError as e: - return - if self._stopped.is_set(): - return - self._read() - - def _read(self): - from_thread.run(self._handle_events) + await wait_socket_readable(self._shadow_sock.FD) # type: ignore[arg-type] + await self._handle_events() def _clear_io_state(self): """unregister the ioloop event handler diff --git a/tests/test_socket.py b/tests/test_socket.py index a35620e..3feb6a4 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -50,7 +50,7 @@ async def recv(messages): async def test_arecv_send(context, create_bound_pair): a, b = create_bound_pair(zmq.REQ, zmq.REP) a, b = Socket(a), Socket(b) - async with a, b, create_task_group() as tg: + async with b, a, create_task_group() as tg: async def recv(messages): for message in messages: @@ -100,9 +100,6 @@ async def test_start_socket(total_threads, create_bound_pair): raise RuntimeError assert b_started - if total_threads == 1: - assert not a_started - else: - assert a_started + assert a_started to_thread.current_default_thread_limiter().total_tokens = 40