Skip to content

Commit

Permalink
Fix tests (#1)
Browse files Browse the repository at this point in the history
* Fix tests

* Add stop event

* Use select instead of selectors
  • Loading branch information
davidbrochart authored Nov 7, 2024
1 parent bca1a71 commit 3dab269
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[![Build Status](https://github.com/davidbrochart/zmq-anyio/actions/workflows/test.yml/badge.svg?query=branch%3Amain++)](https://github.com/davidbrochart/zmq-anyio/actions/workflows/test.yml/badge.svg?query=branch%3Amain++)
[![Code Coverage](https://img.shields.io/badge/coverage-100%25-green)](https://img.shields.io/badge/coverage-100%25-green)

# zmq-anyio

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
requires-python = ">= 3.9"
dependencies = [
"anyio",
"anyioutils",
"anyioutils >=0.4.6",
"pyzmq >=26.0.0,<27.0.0",
]

Expand Down
44 changes: 28 additions & 16 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import select
import selectors
import warnings
from collections import deque
from contextlib import AsyncExitStack
from functools import partial
from itertools import chain
from socket import socketpair
from threading import Event
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -35,7 +38,7 @@ class _FutureEvent(NamedTuple):
class _AsyncPoller(zmq.Poller):
"""Poller that returns a Future on poll, instead of blocking."""

_socket_class: type[_AsyncSocket]
_socket_class: type[Socket]
raw_sockets: list[Any]

def _watch_raw_socket(self, socket: Any, evt: int, f: Callable) -> None:
Expand Down Expand Up @@ -71,7 +74,7 @@ def wake_raw(*args):

watcher.add_done_callback(lambda f: self._unwatch_raw_sockets(*raw_sockets))

wrapped_sockets: list[_AsyncSocket] = []
wrapped_sockets: list[Socket] = []

def _clear_wrapper_io(f):
for s in wrapped_sockets:
Expand All @@ -81,7 +84,7 @@ def _clear_wrapper_io(f):
if isinstance(socket, zmq.Socket):
if not isinstance(socket, self._socket_class):
# it's a blocking zmq.Socket, wrap it in async
socket = self._socket_class.from_socket(socket)
socket = self._socket_class(socket)
wrapped_sockets.append(socket)
if mask & zmq.POLLIN:
create_task(
Expand Down Expand Up @@ -161,9 +164,11 @@ class Socket(zmq.Socket):
_shadow_sock: zmq.Socket
_poller_class = _AsyncPoller
_fd = None
_selector = None
_exit_stack = None
_task_group = None
_stop_event = None
_select_socket_r = None
_select_socket_w = None

def __init__(
self,
Expand All @@ -188,9 +193,16 @@ def __init__(
self._send_futures = deque()
self._state = 0
self._fd = self._shadow_sock.FD
self._stop_event = Event()
self._select_socket_r, self._select_socket_w = socketpair()
self._select_socket_r.setblocking(False)
self._select_socket_w.setblocking(False)

def close(self, linger: int | None = None) -> None:
self._selector.unregister(self._shadow_sock)
assert self._stop_event is not None
assert self._select_socket_w is not None
self._stop_event.set()
self._select_socket_w.send(b"a")
if not self.closed and self._fd is not None:
event_list: list[_FutureEvent] = list(
chain(self._recv_futures or [], self._send_futures or [])
Expand Down Expand Up @@ -265,7 +277,7 @@ async def asend_json(
obj: Any,
flags: int = 0,
**kwargs,
) -> None:
):
send_kwargs = {}
for key in ("routing_id", "group"):
if key in kwargs:
Expand All @@ -275,7 +287,7 @@ async def asend_json(

async def asend_multipart(
self,
msg_parts: Sequence[bytes],
msg_parts: list[bytes],
flags: int = 0,
copy: bool = True,
track: bool = False,
Expand Down Expand Up @@ -668,26 +680,26 @@ async def __aenter__(self) -> Socket:
self._exit_stack = exit_stack.pop_all()
self._task_group.start_soon(self.start)

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
self.close()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def start(self):
if self._selector is not None:
raise RuntimeError("Socket already started")

self._selector = selectors.DefaultSelector()
self._selector.register(self._shadow_sock, selectors.EVENT_READ, self._read)
await to_thread.run_sync(self._reader, abandon_on_cancel=True)
#create_task(self._handle_events(task_group), task_group)

def _reader(self):
while True:
events = self._selector.select()
for key, mask in events:
callback = key.data
callback()
try:
rs, ws, xs = select.select([self._shadow_sock, self._select_socket_r.fileno()], [], [self._shadow_sock, self._select_socket_r.fileno()])
except OSError as e:
return
if self._stop_event.is_set():
return
self._read()

def _read(self):
from_thread.run(self._handle_events)
Expand Down

0 comments on commit 3dab269

Please sign in to comment.