diff --git a/README.md b/README.md index 774605d..290f156 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@ [![Build Status](https://github.com/davidbrochart/zmq-anyio/actions/workflows/test.yml/badge.svg?query=branch%3Amain++)](https://github.com/davidbrochart/zmq-anyio/actions/workflows/test.yml/badge.svg?query=branch%3Amain++) -[![Code Coverage](https://img.shields.io/badge/coverage-100%25-green)](https://img.shields.io/badge/coverage-100%25-green) # zmq-anyio diff --git a/pyproject.toml b/pyproject.toml index 8cda646..8280fcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ requires-python = ">= 3.9" dependencies = [ "anyio", - "anyioutils", + "anyioutils >=0.4.6", "pyzmq >=26.0.0,<27.0.0", ] diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 0d179eb..9876083 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -1,11 +1,14 @@ from __future__ import annotations +import select import selectors import warnings from collections import deque from contextlib import AsyncExitStack from functools import partial from itertools import chain +from socket import socketpair +from threading import Event from typing import ( Any, Awaitable, @@ -35,7 +38,7 @@ class _FutureEvent(NamedTuple): class _AsyncPoller(zmq.Poller): """Poller that returns a Future on poll, instead of blocking.""" - _socket_class: type[_AsyncSocket] + _socket_class: type[Socket] raw_sockets: list[Any] def _watch_raw_socket(self, socket: Any, evt: int, f: Callable) -> None: @@ -71,7 +74,7 @@ def wake_raw(*args): watcher.add_done_callback(lambda f: self._unwatch_raw_sockets(*raw_sockets)) - wrapped_sockets: list[_AsyncSocket] = [] + wrapped_sockets: list[Socket] = [] def _clear_wrapper_io(f): for s in wrapped_sockets: @@ -81,7 +84,7 @@ def _clear_wrapper_io(f): if isinstance(socket, zmq.Socket): if not isinstance(socket, self._socket_class): # it's a blocking zmq.Socket, wrap it in async - socket = self._socket_class.from_socket(socket) + socket = self._socket_class(socket) wrapped_sockets.append(socket) if mask & zmq.POLLIN: create_task( @@ -161,9 +164,11 @@ class Socket(zmq.Socket): _shadow_sock: zmq.Socket _poller_class = _AsyncPoller _fd = None - _selector = None _exit_stack = None _task_group = None + _stop_event = None + _select_socket_r = None + _select_socket_w = None def __init__( self, @@ -188,9 +193,16 @@ def __init__( self._send_futures = deque() self._state = 0 self._fd = self._shadow_sock.FD + self._stop_event = Event() + self._select_socket_r, self._select_socket_w = socketpair() + self._select_socket_r.setblocking(False) + self._select_socket_w.setblocking(False) def close(self, linger: int | None = None) -> None: - self._selector.unregister(self._shadow_sock) + assert self._stop_event is not None + assert self._select_socket_w is not None + self._stop_event.set() + self._select_socket_w.send(b"a") if not self.closed and self._fd is not None: event_list: list[_FutureEvent] = list( chain(self._recv_futures or [], self._send_futures or []) @@ -265,7 +277,7 @@ async def asend_json( obj: Any, flags: int = 0, **kwargs, - ) -> None: + ): send_kwargs = {} for key in ("routing_id", "group"): if key in kwargs: @@ -275,7 +287,7 @@ async def asend_json( async def asend_multipart( self, - msg_parts: Sequence[bytes], + msg_parts: list[bytes], flags: int = 0, copy: bool = True, track: bool = False, @@ -668,26 +680,26 @@ async def __aenter__(self) -> Socket: self._exit_stack = exit_stack.pop_all() self._task_group.start_soon(self.start) + return self + async def __aexit__(self, exc_type, exc_value, exc_tb): self._task_group.cancel_scope.cancel() self.close() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) async def start(self): - if self._selector is not None: - raise RuntimeError("Socket already started") - - self._selector = selectors.DefaultSelector() - self._selector.register(self._shadow_sock, selectors.EVENT_READ, self._read) await to_thread.run_sync(self._reader, abandon_on_cancel=True) #create_task(self._handle_events(task_group), task_group) def _reader(self): while True: - events = self._selector.select() - for key, mask in events: - callback = key.data - callback() + try: + rs, ws, xs = select.select([self._shadow_sock, self._select_socket_r.fileno()], [], [self._shadow_sock, self._select_socket_r.fileno()]) + except OSError as e: + return + if self._stop_event.is_set(): + return + self._read() def _read(self): from_thread.run(self._handle_events)