Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 7, 2024
1 parent bca1a71 commit 1bb478c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[![Build Status](https://github.com/davidbrochart/zmq-anyio/actions/workflows/test.yml/badge.svg?query=branch%3Amain++)](https://github.com/davidbrochart/zmq-anyio/actions/workflows/test.yml/badge.svg?query=branch%3Amain++)
[![Code Coverage](https://img.shields.io/badge/coverage-100%25-green)](https://img.shields.io/badge/coverage-100%25-green)

# zmq-anyio

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
requires-python = ">= 3.9"
dependencies = [
"anyio",
"anyioutils",
"anyioutils >=0.4.6",
"pyzmq >=26.0.0,<27.0.0",
]

Expand Down
13 changes: 8 additions & 5 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class _FutureEvent(NamedTuple):
class _AsyncPoller(zmq.Poller):
"""Poller that returns a Future on poll, instead of blocking."""

_socket_class: type[_AsyncSocket]
_socket_class: type[Socket]
raw_sockets: list[Any]

def _watch_raw_socket(self, socket: Any, evt: int, f: Callable) -> None:
Expand Down Expand Up @@ -71,7 +71,7 @@ def wake_raw(*args):

watcher.add_done_callback(lambda f: self._unwatch_raw_sockets(*raw_sockets))

wrapped_sockets: list[_AsyncSocket] = []
wrapped_sockets: list[Socket] = []

def _clear_wrapper_io(f):
for s in wrapped_sockets:
Expand All @@ -81,7 +81,7 @@ def _clear_wrapper_io(f):
if isinstance(socket, zmq.Socket):
if not isinstance(socket, self._socket_class):
# it's a blocking zmq.Socket, wrap it in async
socket = self._socket_class.from_socket(socket)
socket = self._socket_class(socket)
wrapped_sockets.append(socket)
if mask & zmq.POLLIN:
create_task(
Expand Down Expand Up @@ -190,6 +190,7 @@ def __init__(
self._fd = self._shadow_sock.FD

def close(self, linger: int | None = None) -> None:
assert self._selector is not None
self._selector.unregister(self._shadow_sock)
if not self.closed and self._fd is not None:
event_list: list[_FutureEvent] = list(
Expand Down Expand Up @@ -265,7 +266,7 @@ async def asend_json(
obj: Any,
flags: int = 0,
**kwargs,
) -> None:
):
send_kwargs = {}
for key in ("routing_id", "group"):
if key in kwargs:
Expand All @@ -275,7 +276,7 @@ async def asend_json(

async def asend_multipart(
self,
msg_parts: Sequence[bytes],
msg_parts: list[bytes],
flags: int = 0,
copy: bool = True,
track: bool = False,
Expand Down Expand Up @@ -668,6 +669,8 @@ async def __aenter__(self) -> Socket:
self._exit_stack = exit_stack.pop_all()
self._task_group.start_soon(self.start)

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
self.close()
Expand Down

0 comments on commit 1bb478c

Please sign in to comment.