Skip to content

Commit

Permalink
Add test for starting sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 8, 2024
1 parent 3167286 commit 2ea2d53
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
33 changes: 20 additions & 13 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import select
import selectors
import threading
import warnings
from collections import deque
from contextlib import AsyncExitStack
from functools import partial
from itertools import chain
from socket import socketpair
from threading import Event
from typing import (
Any,
Awaitable,
Expand All @@ -18,8 +18,8 @@
cast,
)

from anyio import create_task_group, from_thread, sleep, to_thread, wait_socket_readable
from anyio.abc import TaskGroup
from anyio import Event, TASK_STATUS_IGNORED, create_task_group, from_thread, sleep, to_thread, wait_socket_readable
from anyio.abc import TaskGroup, TaskStatus
from anyioutils import Future, Task, create_task

import zmq
Expand Down Expand Up @@ -166,9 +166,10 @@ class Socket(zmq.Socket):
_fd = None
_exit_stack = None
_task_group = None
_stop_event = None
_select_socket_r = None
_select_socket_w = None
_stopped = None
_started = None

def __init__(
self,
Expand All @@ -193,15 +194,16 @@ def __init__(
self._send_futures = deque()
self._state = 0
self._fd = self._shadow_sock.FD
self._stop_event = Event()
self._select_socket_r, self._select_socket_w = socketpair()
self._select_socket_r.setblocking(False)
self._select_socket_w.setblocking(False)
self._started = Event()
self._stopped = threading.Event()

def close(self, linger: int | None = None) -> None:
assert self._stop_event is not None
assert self._stopped is not None
assert self._select_socket_w is not None
self._stop_event.set()
self._stopped.set()
self._select_socket_w.send(b"a")
if not self.closed and self._fd is not None:
event_list: list[_FutureEvent] = list(
Expand Down Expand Up @@ -678,26 +680,31 @@ async def __aenter__(self) -> Socket:
async with AsyncExitStack() as exit_stack:
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
self._task_group.start_soon(self.start)
await self._task_group.start(self.start)

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
try:
self.close()
except BaseException:
pass
self._task_group.cancel_scope.cancel()
self.close()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def start(self):
await to_thread.run_sync(self._reader, abandon_on_cancel=True)
#create_task(self._handle_events(task_group), task_group)
async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
self._task_group.start_soon(partial(to_thread.run_sync, self._reader, abandon_on_cancel=True))
await self._started.wait()
task_status.started()

def _reader(self):
from_thread.run_sync(self._started.set)
while True:
try:
rs, ws, xs = select.select([self._shadow_sock, self._select_socket_r.fileno()], [], [self._shadow_sock, self._select_socket_r.fileno()])
except OSError as e:
return
if self._stop_event.is_set():
if self._stopped.is_set():
return
self._read()

Expand Down
27 changes: 26 additions & 1 deletion tests/test_socket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import zmq
from anyio import create_task_group, sleep
from anyio import create_task_group, move_on_after, sleep, to_thread
from zmq_anyio import Socket

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -81,3 +81,28 @@ async def recv():
tg.start_soon(recv)
await sleep(0.1)
a.send(b"hi")


@pytest.mark.parametrize("total_threads", [1, 2])
async def test_start_socket(total_threads, create_bound_pair):
to_thread.current_default_thread_limiter().total_tokens = total_threads

a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP))
a_started = False
b_started = False

with pytest.raises(BaseException):
async with b:
b_started = True
with move_on_after(0.1):
async with a:
a_started = True
raise RuntimeError

assert b_started
if total_threads == 1:
assert not a_started
else:
assert a_started

to_thread.current_default_thread_limiter().total_tokens = 40

0 comments on commit 2ea2d53

Please sign in to comment.