Skip to content

Commit

Permalink
Add stop event
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 7, 2024
1 parent 1bb478c commit 3c33c80
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import AsyncExitStack
from functools import partial
from itertools import chain
from threading import Event
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -164,6 +165,7 @@ class Socket(zmq.Socket):
_selector = None
_exit_stack = None
_task_group = None
_stop_event = None

def __init__(
self,
Expand All @@ -188,10 +190,9 @@ def __init__(
self._send_futures = deque()
self._state = 0
self._fd = self._shadow_sock.FD
self._stop_event = Event()

def close(self, linger: int | None = None) -> None:
assert self._selector is not None
self._selector.unregister(self._shadow_sock)
if not self.closed and self._fd is not None:
event_list: list[_FutureEvent] = list(
chain(self._recv_futures or [], self._send_futures or [])
Expand All @@ -205,6 +206,14 @@ def close(self, linger: int | None = None) -> None:
pass
self._clear_io_state()
super().close(linger=linger)
assert self._stop_event is not None
self._stop_event.set()
try:
assert self._selector is not None
self._selector.unregister(self._shadow_sock)
self._selector.close()
except BaseException:
pass

close.__doc__ = zmq.Socket.close.__doc__

Expand Down Expand Up @@ -681,16 +690,16 @@ async def start(self):
raise RuntimeError("Socket already started")

self._selector = selectors.DefaultSelector()
self._selector.register(self._shadow_sock, selectors.EVENT_READ, self._read)
self._selector.register(self._shadow_sock, selectors.EVENT_READ)
await to_thread.run_sync(self._reader, abandon_on_cancel=True)
#create_task(self._handle_events(task_group), task_group)

def _reader(self):
while True:
events = self._selector.select()
for key, mask in events:
callback = key.data
callback()
events = self._selector.select(0.1)
if self._stop_event.is_set():
return
self._read()

def _read(self):
from_thread.run(self._handle_events)
Expand Down

0 comments on commit 3c33c80

Please sign in to comment.