diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 60bcab8..82e95cc 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -6,6 +6,7 @@ from contextlib import AsyncExitStack from functools import partial from itertools import chain +from threading import Event from typing import ( Any, Awaitable, @@ -164,6 +165,7 @@ class Socket(zmq.Socket): _selector = None _exit_stack = None _task_group = None + _stop_event = None def __init__( self, @@ -188,10 +190,9 @@ def __init__( self._send_futures = deque() self._state = 0 self._fd = self._shadow_sock.FD + self._stop_event = Event() def close(self, linger: int | None = None) -> None: - assert self._selector is not None - self._selector.unregister(self._shadow_sock) if not self.closed and self._fd is not None: event_list: list[_FutureEvent] = list( chain(self._recv_futures or [], self._send_futures or []) @@ -205,6 +206,14 @@ def close(self, linger: int | None = None) -> None: pass self._clear_io_state() super().close(linger=linger) + assert self._stop_event is not None + self._stop_event.set() + try: + assert self._selector is not None + self._selector.unregister(self._shadow_sock) + self._selector.close() + except BaseException: + pass close.__doc__ = zmq.Socket.close.__doc__ @@ -681,16 +690,16 @@ async def start(self): raise RuntimeError("Socket already started") self._selector = selectors.DefaultSelector() - self._selector.register(self._shadow_sock, selectors.EVENT_READ, self._read) + self._selector.register(self._shadow_sock, selectors.EVENT_READ) await to_thread.run_sync(self._reader, abandon_on_cancel=True) #create_task(self._handle_events(task_group), task_group) def _reader(self): while True: - events = self._selector.select() - for key, mask in events: - callback = key.data - callback() + events = self._selector.select(0.1) + if self._stop_event.is_set(): + return + self._read() def _read(self): from_thread.run(self._handle_events)