diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 0daa65f..bfdb260 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -195,18 +195,24 @@ def __init__( self.started = Event() def close(self, linger: int | None = None) -> None: - if not self.closed and self._fd is not None: - event_list: list[_FutureEvent] = list( - chain(self._recv_futures or [], self._send_futures or []) - ) - for event in event_list: - if not event.future.done(): - try: - event.future.cancel(raise_exception=False) - except RuntimeError: - # RuntimeError may be called during teardown - pass - super().close(linger=linger) + try: + if not self.closed and self._fd is not None: + event_list: list[_FutureEvent] = list( + chain(self._recv_futures or [], self._send_futures or []) + ) + for event in event_list: + if not event.future.done(): + try: + event.future.cancel(raise_exception=False) + except RuntimeError: + # RuntimeError may be called during teardown + pass + super().close(linger=linger) + except BaseException: + pass + + if self._task_group is not None: + self._task_group.cancel_scope.cancel() close.__doc__ = zmq.Socket.close.__doc__ @@ -224,6 +230,7 @@ async def arecv( copy: bool = True, track: bool = False, ) -> bytes | zmq.Frame: + self._check_started() return await self._add_recv_event( "recv", dict(flags=flags, copy=copy, track=track) ) @@ -315,6 +322,7 @@ async def arecv_multipart( copy: bool = True, track: bool = False, ) -> list[bytes] | list[zmq.Frame]: + self._check_started() return await self._add_recv_event( "recv_multipart", dict(flags=flags, copy=copy, track=track) ) @@ -339,6 +347,7 @@ async def asend( track: bool = False, **kwargs: Any, ) -> zmq.MessageTracker | None: + self._check_started() kwargs["flags"] = flags kwargs["copy"] = copy kwargs["track"] = track @@ -431,6 +440,7 @@ async def asend_multipart( track: bool = False, **kwargs, ) -> zmq.MessageTracker | None: + self._check_started() kwargs["flags"] = flags kwargs["copy"] = copy kwargs["track"] = track @@ -447,7 +457,7 @@ async def apoll(self, timeout=None, flags=zmq.POLLIN) -> int: # type: ignore returns a Future for the poll results. """ - + self._check_started() if self.closed: raise zmq.ZMQError(zmq.ENOTSUP) @@ -783,10 +793,6 @@ async def __aenter__(self) -> Socket: return self async def __aexit__(self, exc_type, exc_value, exc_tb): - try: - self.close() - except BaseException: - pass await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) @@ -796,29 +802,30 @@ async def start( if self._task_group is None: async with create_task_group() as self._task_group: await self._task_group.start(self._start) + task_status.started() else: await self._task_group.start(self._start) - task_status.started() + task_status.started() async def stop(self): - if self._task_group is None: - return - - self._task_group.cancel_scope.cancel() + self.close() async def _start(self, *, task_status: TaskStatus[None]): _set_selector_windows() assert self._task_group is not None assert self.started is not None + task_status.started() if self.started.is_set(): - task_status.started() return self.started.set() - task_status.started() try: while True: await wait_socket_readable(self._shadow_sock.FD) # type: ignore[arg-type] await self._handle_events() except Exception: pass + + def _check_started(self): + if self._task_group is None: + raise RuntimeError("Socket must be used with async context manager (or `await sock.start()`)")