Skip to content

Commit

Permalink
Check if socket started when calling async methods
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 14, 2024
1 parent 5f1ca66 commit 983f4dc
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,24 @@ def __init__(
self.started = Event()

def close(self, linger: int | None = None) -> None:
if not self.closed and self._fd is not None:
event_list: list[_FutureEvent] = list(
chain(self._recv_futures or [], self._send_futures or [])
)
for event in event_list:
if not event.future.done():
try:
event.future.cancel(raise_exception=False)
except RuntimeError:
# RuntimeError may be called during teardown
pass
super().close(linger=linger)
try:
if not self.closed and self._fd is not None:
event_list: list[_FutureEvent] = list(
chain(self._recv_futures or [], self._send_futures or [])
)
for event in event_list:
if not event.future.done():
try:
event.future.cancel(raise_exception=False)
except RuntimeError:
# RuntimeError may be called during teardown
pass
super().close(linger=linger)
except BaseException:
pass

if self._task_group is not None:
self._task_group.cancel_scope.cancel()

close.__doc__ = zmq.Socket.close.__doc__

Expand All @@ -224,6 +230,7 @@ async def arecv(
copy: bool = True,
track: bool = False,
) -> bytes | zmq.Frame:
self._check_started()
return await self._add_recv_event(
"recv", dict(flags=flags, copy=copy, track=track)
)
Expand Down Expand Up @@ -315,6 +322,7 @@ async def arecv_multipart(
copy: bool = True,
track: bool = False,
) -> list[bytes] | list[zmq.Frame]:
self._check_started()
return await self._add_recv_event(
"recv_multipart", dict(flags=flags, copy=copy, track=track)
)
Expand All @@ -339,6 +347,7 @@ async def asend(
track: bool = False,
**kwargs: Any,
) -> zmq.MessageTracker | None:
self._check_started()
kwargs["flags"] = flags
kwargs["copy"] = copy
kwargs["track"] = track
Expand Down Expand Up @@ -431,6 +440,7 @@ async def asend_multipart(
track: bool = False,
**kwargs,
) -> zmq.MessageTracker | None:
self._check_started()
kwargs["flags"] = flags
kwargs["copy"] = copy
kwargs["track"] = track
Expand All @@ -447,7 +457,7 @@ async def apoll(self, timeout=None, flags=zmq.POLLIN) -> int: # type: ignore
returns a Future for the poll results.
"""

self._check_started()
if self.closed:
raise zmq.ZMQError(zmq.ENOTSUP)

Expand Down Expand Up @@ -783,10 +793,6 @@ async def __aenter__(self) -> Socket:
return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
try:
self.close()
except BaseException:
pass
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

Expand All @@ -796,29 +802,30 @@ async def start(
if self._task_group is None:
async with create_task_group() as self._task_group:
await self._task_group.start(self._start)
task_status.started()
else:
await self._task_group.start(self._start)
task_status.started()
task_status.started()

async def stop(self):
if self._task_group is None:
return

self._task_group.cancel_scope.cancel()
self.close()

async def _start(self, *, task_status: TaskStatus[None]):
_set_selector_windows()
assert self._task_group is not None
assert self.started is not None
task_status.started()
if self.started.is_set():
task_status.started()
return

self.started.set()
task_status.started()
try:
while True:
await wait_socket_readable(self._shadow_sock.FD) # type: ignore[arg-type]
await self._handle_events()
except Exception:
pass

def _check_started(self):
if self._task_group is None:
raise RuntimeError("Socket must be used with async context manager (or `await sock.start()`)")

0 comments on commit 983f4dc

Please sign in to comment.