diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 82e95cc..580b85e 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -1,11 +1,13 @@ from __future__ import annotations +import select import selectors import warnings from collections import deque from contextlib import AsyncExitStack from functools import partial from itertools import chain +from socket import socketpair from threading import Event from typing import ( Any, @@ -162,10 +164,11 @@ class Socket(zmq.Socket): _shadow_sock: zmq.Socket _poller_class = _AsyncPoller _fd = None - _selector = None _exit_stack = None _task_group = None _stop_event = None + _select_socket_r = None + _select_socket_w = None def __init__( self, @@ -191,8 +194,15 @@ def __init__( self._state = 0 self._fd = self._shadow_sock.FD self._stop_event = Event() + self._select_socket_r, self._select_socket_w = socketpair() + self._select_socket_r.setblocking(False) + self._select_socket_w.setblocking(False) def close(self, linger: int | None = None) -> None: + assert self._stop_event is not None + assert self._select_socket_w is not None + self._stop_event.set() + self._select_socket_w.send(b"") if not self.closed and self._fd is not None: event_list: list[_FutureEvent] = list( chain(self._recv_futures or [], self._send_futures or []) @@ -206,14 +216,6 @@ def close(self, linger: int | None = None) -> None: pass self._clear_io_state() super().close(linger=linger) - assert self._stop_event is not None - self._stop_event.set() - try: - assert self._selector is not None - self._selector.unregister(self._shadow_sock) - self._selector.close() - except BaseException: - pass close.__doc__ = zmq.Socket.close.__doc__ @@ -686,17 +688,15 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) async def start(self): - if self._selector is not None: - raise RuntimeError("Socket already started") - - self._selector = selectors.DefaultSelector() - self._selector.register(self._shadow_sock, selectors.EVENT_READ) await to_thread.run_sync(self._reader, abandon_on_cancel=True) #create_task(self._handle_events(task_group), task_group) def _reader(self): while True: - events = self._selector.select(0.1) + try: + rs, ws, xs = select.select([self._shadow_sock, self._select_socket_r.fileno()], [], [self._shadow_sock, self._select_socket_r.fileno()]) + except OSError as e: + return if self._stop_event.is_set(): return self._read()