diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 70b3bf40..6bde2ef7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,30 +32,45 @@ jobs: python-version: "3.11" - os: ubuntu-latest python-version: "3.12" + - os: windows-latest + python-version: "3.10" + - os: windows-latest + python-version: "3.11" + - os: windows-latest + python-version: "3.12" steps: - name: Checkout uses: actions/checkout@v4 - - name: Base Setup - uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install hatch + run: | + python --version + python -m pip install hatch - name: Run the tests timeout-minutes: 15 if: ${{ !startsWith( matrix.python-version, 'pypy' ) && !startsWith(matrix.os, 'windows') }} run: | - hatch run cov:test --cov-fail-under 50 || hatch run test:test --lf + PYTHONTRACEMALLOC=20 hatch run cov:test --cov-fail-under 50 || PYTHONTRACEMALLOC=20 hatch run test:test --lf - name: Run the tests on pypy timeout-minutes: 15 if: ${{ startsWith( matrix.python-version, 'pypy' ) }} run: | - hatch run test:nowarn || hatch run test:nowarn --lf + PYTHONTRACEMALLOC=20 hatch run test:nowarn || PYTHONTRACEMALLOC=20 hatch run test:nowarn --lf - name: Run the tests on Windows timeout-minutes: 15 if: ${{ startsWith(matrix.os, 'windows') }} run: | - hatch run cov:nowarn || hatch run test:nowarn --lf + hatch run test:pip list + hatch run test:python --version + #hatch run test:python -m pip install git+https://github.com/ipython/ipython@92dd9e47fe8862ee38770744c165b680cb5241b1 + PYTHONTRACEMALLOC=20 hatch run test:pytest -v - name: Check Launcher run: | @@ -138,7 +153,7 @@ jobs: - name: Run the tests timeout-minutes: 15 - run: pytest -W default -vv || pytest --vv -W default --lf + run: PYTHONTRACEMALLOC=20 pytest -W default -vv || PYTHONTRACEMALLOC=20 pytest --vv -W default --lf test_miniumum_versions: name: Test Minimum Versions diff --git a/ipykernel/debugger.py b/ipykernel/debugger.py index 780d1801..36aced05 100644 --- a/ipykernel/debugger.py +++ b/ipykernel/debugger.py @@ -241,7 +241,7 @@ async def _send_request(self, msg): self.log.debug("DEBUGPYCLIENT:") self.log.debug(self.routing_id) self.log.debug(buf) - await self.debugpy_socket.send_multipart((self.routing_id, buf)) + await self.debugpy_socket.asend_multipart((self.routing_id, buf)) async def _wait_for_response(self): # Since events are never pushed to the message_queue @@ -437,7 +437,7 @@ async def start(self): (self.shell_socket.getsockopt(ROUTING_ID)), ) - msg = await self.shell_socket.recv_multipart() + msg = await self.shell_socket.arecv_multipart() ident, msg = self.session.feed_identities(msg, copy=True) try: msg = self.session.deserialize(msg, content=True, copy=True) diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index c6f8c612..5abb691c 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -54,7 +54,7 @@ class InProcessKernel(IPythonKernel): _underlying_iopub_socket = Instance(DummySocket, (False,)) iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment] - shell_socket = Instance(DummySocket, (True,)) # type:ignore[arg-type] + shell_socket = Instance(DummySocket, (True,)) @default("iopub_thread") def _default_iopub_thread(self): diff --git a/ipykernel/inprocess/session.py b/ipykernel/inprocess/session.py index 0eaed2c6..70b13574 100644 --- a/ipykernel/inprocess/session.py +++ b/ipykernel/inprocess/session.py @@ -3,7 +3,7 @@ class Session(_Session): async def recv(self, socket, copy=True): - return await socket.recv_multipart() + return await socket.arecv_multipart() def send( self, diff --git a/ipykernel/inprocess/socket.py b/ipykernel/inprocess/socket.py index 5a2e0008..d14d0850 100644 --- a/ipykernel/inprocess/socket.py +++ b/ipykernel/inprocess/socket.py @@ -65,4 +65,8 @@ async def poll(self, timeout=0): return statistics.current_buffer_used != 0 def close(self): - pass + if self.is_shell: + self.in_send_stream.close() + self.in_receive_stream.close() + self.out_send_stream.close() + self.out_receive_stream.close() diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index d8171017..02a0e22a 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -16,13 +16,16 @@ from binascii import b2a_hex from collections import defaultdict, deque from io import StringIO, TextIOBase -from threading import Event, Thread, local +from threading import local from typing import Any, Callable import zmq -from anyio import create_task_group, run, sleep, to_thread +import zmq_anyio +from anyio import sleep from jupyter_client.session import extract_header +from .thread import BaseThread + # ----------------------------------------------------------------------------- # Globals # ----------------------------------------------------------------------------- @@ -37,38 +40,6 @@ # ----------------------------------------------------------------------------- -class _IOPubThread(Thread): - """A thread for a IOPub.""" - - def __init__(self, tasks, **kwargs): - """Initialize the thread.""" - super().__init__(name="IOPub", **kwargs) - self._tasks = tasks - self.pydev_do_not_trace = True - self.is_pydev_daemon_thread = True - self.daemon = True - self.__stop = Event() - - def run(self): - """Run the thread.""" - self.name = "IOPub" - run(self._main) - - async def _main(self): - async with create_task_group() as tg: - for task in self._tasks: - tg.start_soon(task) - await to_thread.run_sync(self.__stop.wait) - tg.cancel_scope.cancel() - - def stop(self): - """Stop the thread. - - This method is threadsafe. - """ - self.__stop.set() - - class IOPubThread: """An object for sending IOPub messages in a background thread @@ -78,7 +49,7 @@ class IOPubThread: whose IO is always run in a thread. """ - def __init__(self, socket, pipe=False): + def __init__(self, socket: zmq_anyio.Socket, pipe=False): """Create IOPub thread Parameters @@ -91,10 +62,7 @@ def __init__(self, socket, pipe=False): """ # ensure all of our sockets as sync zmq.Sockets # don't create async wrappers until we are within the appropriate coroutines - self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket) - if self.socket.context is None: - # bug in pyzmq, shadow socket doesn't always inherit context attribute - self.socket.context = socket.context # type:ignore[unreachable] + self.socket: zmq_anyio.Socket = socket self._context = socket.context self.background_socket = BackgroundSocket(self) @@ -108,14 +76,16 @@ def __init__(self, socket, pipe=False): self._event_pipe_gc_lock: threading.Lock = threading.Lock() self._event_pipe_gc_seconds: float = 10 self._setup_event_pipe() - tasks = [self._handle_event, self._run_event_pipe_gc] + tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start] if pipe: tasks.append(self._handle_pipe_msgs) - self.thread = _IOPubThread(tasks) + self.thread = BaseThread(name="IOPub", daemon=True) + for task in tasks: + self.thread.start_soon(task) def _setup_event_pipe(self): """Create the PULL socket listening for events that should fire in this thread.""" - self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket) + self._pipe_in0 = self._context.socket(zmq.PULL) self._pipe_in0.linger = 0 _uuid = b2a_hex(os.urandom(16)).decode("ascii") @@ -150,7 +120,7 @@ def _event_pipe(self): except AttributeError: # new thread, new event pipe # create sync base socket - event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket) + event_pipe = self._context.socket(zmq.PUSH) event_pipe.linger = 0 event_pipe.connect(self._event_interface) self._local.event_pipe = event_pipe @@ -169,30 +139,28 @@ async def _handle_event(self): Whenever *an* event arrives on the event stream, *all* waiting events are processed in order. """ - # create async wrapper within coroutine - pipe_in = zmq.asyncio.Socket(self._pipe_in0) - try: - while True: - await pipe_in.recv() - # freeze event count so new writes don't extend the queue - # while we are processing - n_events = len(self._events) - for _ in range(n_events): - event_f = self._events.popleft() - event_f() - except Exception: - if self.thread.__stop.is_set(): - return - raise + pipe_in = zmq_anyio.Socket(self._pipe_in0) + async with pipe_in: + try: + while True: + await pipe_in.arecv() + # freeze event count so new writes don't extend the queue + # while we are processing + n_events = len(self._events) + for _ in range(n_events): + event_f = self._events.popleft() + event_f() + except Exception: + if self.thread.stopped.is_set(): + return + raise def _setup_pipe_in(self): """setup listening pipe for IOPub from forked subprocesses""" - ctx = self._context - # use UUID to authenticate pipe messages self._pipe_uuid = os.urandom(16) - self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket) + self._pipe_in1 = zmq_anyio.Socket(self._context.socket(zmq.PULL)) self._pipe_in1.linger = 0 try: @@ -210,18 +178,18 @@ def _setup_pipe_in(self): async def _handle_pipe_msgs(self): """handle pipe messages from a subprocess""" # create async wrapper within coroutine - self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1) - try: - while True: - await self._handle_pipe_msg() - except Exception: - if self.thread.__stop.is_set(): - return - raise + async with self._pipe_in1: + try: + while True: + await self._handle_pipe_msg() + except Exception: + if self.thread.stopped.is_set(): + return + raise async def _handle_pipe_msg(self, msg=None): """handle a pipe message from a subprocess""" - msg = msg or await self._async_pipe_in1.recv_multipart() + msg = msg or await self._pipe_in1.arecv_multipart() if not self._pipe_flag or not self._is_main_process(): return if msg[0] != self._pipe_uuid: diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index 48efa6cd..d8d2ba5d 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -12,7 +12,7 @@ from dataclasses import dataclass import comm -import zmq.asyncio +import zmq_anyio from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread from anyio.abc import TaskStatus from IPython.core import release @@ -76,7 +76,7 @@ class IPythonKernel(KernelBase): help="Set this flag to False to deactivate the use of experimental IPython completion APIs.", ).tag(config=True) - debugpy_socket = Instance(zmq.asyncio.Socket, allow_none=True) + debugpy_socket = Instance(zmq_anyio.Socket, allow_none=True) user_module = Any() @@ -212,7 +212,8 @@ def __init__(self, **kwargs): } async def process_debugpy(self): - async with create_task_group() as tg: + assert self.debugpy_socket is not None + async with self.debug_shell_socket, self.debugpy_socket, create_task_group() as tg: tg.start_soon(self.receive_debugpy_messages) tg.start_soon(self.poll_stopped_queue) await to_thread.run_sync(self.debugpy_stop.wait) @@ -235,7 +236,7 @@ async def receive_debugpy_message(self, msg=None): if msg is None: assert self.debugpy_socket is not None - msg = await self.debugpy_socket.recv_multipart() + msg = await self.debugpy_socket.arecv_multipart() # The first frame is the socket id, we can drop it frame = msg[1].decode("utf-8") self.log.debug("Debugpy received: %s", frame) diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index 55efaa8e..1cf5697b 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -18,7 +18,7 @@ from pathlib import Path import zmq -import zmq.asyncio +import zmq_anyio from anyio import create_task_group, run from IPython.core.application import ( # type:ignore[attr-defined] BaseIPythonApplication, @@ -325,15 +325,15 @@ def init_sockets(self): """Create a context, a session, and the kernel sockets.""" self.log.info("Starting the kernel at pid: %i", os.getpid()) assert self.context is None, "init_sockets cannot be called twice!" - self.context = context = zmq.asyncio.Context() + self.context = context = zmq.Context() atexit.register(self.close) - self.shell_socket = context.socket(zmq.ROUTER) + self.shell_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.shell_socket.linger = 1000 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port) self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port) - self.stdin_socket = zmq.Context(context).socket(zmq.ROUTER) + self.stdin_socket = context.socket(zmq.ROUTER) self.stdin_socket.linger = 1000 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port) self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port) @@ -349,18 +349,19 @@ def init_sockets(self): def init_control(self, context): """Initialize the control channel.""" - self.control_socket = context.socket(zmq.ROUTER) + self.control_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.control_socket.linger = 1000 self.control_port = self._bind_socket(self.control_socket, self.control_port) self.log.debug("control ROUTER Channel on port: %i" % self.control_port) - self.debugpy_socket = context.socket(zmq.STREAM) + self.debugpy_socket = zmq_anyio.Socket(context, zmq.STREAM) self.debugpy_socket.linger = 1000 - self.debug_shell_socket = context.socket(zmq.DEALER) + self.debug_shell_socket = zmq_anyio.Socket(context.socket(zmq.DEALER)) self.debug_shell_socket.linger = 1000 - if self.shell_socket.getsockopt(zmq.LAST_ENDPOINT): - self.debug_shell_socket.connect(self.shell_socket.getsockopt(zmq.LAST_ENDPOINT)) + last_endpoint = self.shell_socket.getsockopt(zmq.LAST_ENDPOINT) + if last_endpoint: + self.debug_shell_socket.connect(last_endpoint) if hasattr(zmq, "ROUTER_HANDOVER"): # set router-handover to workaround zeromq reconnect problems @@ -373,7 +374,7 @@ def init_control(self, context): def init_iopub(self, context): """Initialize the iopub channel.""" - self.iopub_socket = context.socket(zmq.PUB) + self.iopub_socket = zmq_anyio.Socket(context.socket(zmq.PUB)) self.iopub_socket.linger = 1000 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port) self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port) @@ -634,43 +635,6 @@ def configure_tornado_logger(self): handler.setFormatter(formatter) logger.addHandler(handler) - def _init_asyncio_patch(self): - """set default asyncio policy to be compatible with tornado - - Tornado 6 (at least) is not compatible with the default - asyncio implementation on Windows - - Pick the older SelectorEventLoopPolicy on Windows - if the known-incompatible default policy is in use. - - Support for Proactor via a background thread is available in tornado 6.1, - but it is still preferable to run the Selector in the main thread - instead of the background. - - do this as early as possible to make it a low priority and overridable - - ref: https://github.com/tornadoweb/tornado/issues/2608 - - FIXME: if/when tornado supports the defaults in asyncio without threads, - remove and bump tornado requirement for py38. - Most likely, this will mean a new Python version - where asyncio.ProactorEventLoop supports add_reader and friends. - - """ - if sys.platform.startswith("win"): - import asyncio - - try: - from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy - except ImportError: - pass - # not affected - else: - if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy: - # WindowsProactorEventLoopPolicy is not compatible with tornado 6 - # fallback to the pre-3.8 default of Selector - asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy()) - def init_pdb(self): """Replace pdb with IPython's version that is interruptible. @@ -690,7 +654,6 @@ def init_pdb(self): @catch_config_error def initialize(self, argv=None): """Initialize the application.""" - self._init_asyncio_patch() super().initialize(argv) if self.subapp is not None: return diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index d496e0c9..3ac67130 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -16,6 +16,7 @@ import uuid import warnings from datetime import datetime +from functools import partial from signal import SIGINT, SIGTERM, Signals from .thread import CONTROL_THREAD_NAME @@ -35,6 +36,7 @@ import psutil import zmq +import zmq_anyio from anyio import TASK_STATUS_IGNORED, Event, create_task_group, sleep, to_thread from anyio.abc import TaskStatus from IPython.core.error import StdinNotImplementedError @@ -88,7 +90,7 @@ class Kernel(SingletonConfigurable): session = Instance(Session, allow_none=True) profile_dir = Instance("IPython.core.profiledir.ProfileDir", allow_none=True) - shell_socket = Instance(zmq.asyncio.Socket, allow_none=True) + shell_socket = Instance(zmq_anyio.Socket, allow_none=True) implementation: str implementation_version: str @@ -96,7 +98,7 @@ class Kernel(SingletonConfigurable): _is_test = Bool(False) - control_socket = Instance(zmq.asyncio.Socket, allow_none=True) + control_socket = Instance(zmq_anyio.Socket, allow_none=True) control_tasks: t.Any = List() debug_shell_socket = Any() @@ -267,7 +269,7 @@ async def process_control_message(self, msg=None): assert self.session is not None assert self.control_thread is None or threading.current_thread() == self.control_thread - msg = msg or await self.control_socket.recv_multipart() + msg = msg or await self.control_socket.arecv_multipart() idents, msg = self.session.feed_identities(msg) try: msg = self.session.deserialize(msg, content=True) @@ -364,26 +366,31 @@ async def shell_channel_thread_main(self): assert self.shell_channel_thread is not None assert threading.current_thread() == self.shell_channel_thread - try: - while True: - msg = await self.shell_socket.recv_multipart(copy=False) - # deserialize only the header to get subshell_id - # Keep original message to send to subshell_id unmodified. - _, msg2 = self.session.feed_identities(msg, copy=False) - try: - msg3 = self.session.deserialize(msg2, content=False, copy=False) - subshell_id = msg3["header"].get("subshell_id") - - # Find inproc pair socket to use to send message to correct subshell. - socket = self.shell_channel_thread.manager.get_shell_channel_socket(subshell_id) - assert socket is not None - socket.send_multipart(msg, copy=False) - except Exception: - self.log.error("Invalid message", exc_info=True) # noqa: G201 - except BaseException: - if self.shell_stop.is_set(): - return - raise + async with self.shell_socket, create_task_group() as tg: + try: + while True: + msg = await self.shell_socket.arecv_multipart(copy=False) + # deserialize only the header to get subshell_id + # Keep original message to send to subshell_id unmodified. + _, msg2 = self.session.feed_identities(msg, copy=False) + try: + msg3 = self.session.deserialize(msg2, content=False, copy=False) + subshell_id = msg3["header"].get("subshell_id") + + # Find inproc pair socket to use to send message to correct subshell. + socket = self.shell_channel_thread.manager.get_shell_channel_socket( + subshell_id + ) + assert socket is not None + if not socket.started.is_set(): + await tg.start(socket.start) + await socket.asend_multipart(msg, copy=False) + except Exception: + self.log.error("Invalid message", exc_info=True) # noqa: G201 + except BaseException: + if self.shell_stop.is_set(): + return + raise async def shell_main(self, subshell_id: str | None): """Main loop for a single subshell.""" @@ -411,13 +418,15 @@ async def shell_main(self, subshell_id: str | None): async def process_shell(self, socket=None): # socket=None is valid if kernel subshells are not supported. - try: - while True: - await self.process_shell_message(socket=socket) - except BaseException: - if self.shell_stop.is_set(): - return - raise + _socket = self.shell_socket if socket is None else socket + async with _socket: + try: + while True: + await self.process_shell_message(socket=socket) + except BaseException: + if self.shell_stop.is_set(): + return + raise async def process_shell_message(self, msg=None, socket=None): # If socket is None kernel subshells are not supported so use socket=shell_socket. @@ -435,8 +444,8 @@ async def process_shell_message(self, msg=None, socket=None): assert socket is None socket = self.shell_socket - no_msg = msg is None if self._is_test else not await socket.poll(0) - msg = msg or await socket.recv_multipart(copy=False) + no_msg = msg is None if self._is_test else not await socket.apoll(0) + msg = msg or await socket.arecv_multipart(copy=False) received_time = time.monotonic() copy = not isinstance(msg[0], zmq.Message) @@ -490,7 +499,7 @@ async def process_shell_message(self, msg=None, socket=None): try: result = handler(socket, idents, msg) if inspect.isawaitable(result): - await result + result = await result except Exception: self.log.error("Exception in message handler:", exc_info=True) # noqa: G201 except KeyboardInterrupt: @@ -509,7 +518,8 @@ async def process_shell_message(self, msg=None, socket=None): self._publish_status("idle", "shell") async def control_main(self): - async with create_task_group() as tg: + assert self.control_socket is not None + async with self.control_socket, create_task_group() as tg: for task in self.control_tasks: tg.start_soon(task) tg.start_soon(self.process_control) @@ -529,7 +539,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: self.control_stop = threading.Event() if not self._is_test and self.control_socket is not None: if self.control_thread: - self.control_thread.add_task(self.control_main) + self.control_thread.start_soon(self.control_main) self.control_thread.start() else: tg.start_soon(self.control_main) @@ -544,9 +554,11 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: # Assign tasks to and start shell channel thread. manager = self.shell_channel_thread.manager - self.shell_channel_thread.add_task(self.shell_channel_thread_main) - self.shell_channel_thread.add_task(manager.listen_from_control, self.shell_main) - self.shell_channel_thread.add_task(manager.listen_from_subshells) + self.shell_channel_thread.start_soon(self.shell_channel_thread_main) + self.shell_channel_thread.start_soon( + partial(manager.listen_from_control, self.shell_main) + ) + self.shell_channel_thread.start_soon(manager.listen_from_subshells) self.shell_channel_thread.start() else: if not self._is_test and self.shell_socket is not None: @@ -1075,9 +1087,11 @@ async def create_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. - other_socket = self.shell_channel_thread.manager.get_control_other_socket() - await other_socket.send_json({"type": "create"}) - reply = await other_socket.recv_json() + other_socket = await self.shell_channel_thread.manager.get_control_other_socket( + self.control_thread + ) + await other_socket.asend_json({"type": "create"}) + reply = await other_socket.arecv_json() self.session.send(socket, "create_subshell_reply", reply, parent, ident) @@ -1097,9 +1111,11 @@ async def delete_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. - other_socket = self.shell_channel_thread.manager.get_control_other_socket() - await other_socket.send_json({"type": "delete", "subshell_id": subshell_id}) - reply = await other_socket.recv_json() + other_socket = await self.shell_channel_thread.manager.get_control_other_socket( + self.control_thread + ) + await other_socket.asend_json({"type": "delete", "subshell_id": subshell_id}) + reply = await other_socket.arecv_json() self.session.send(socket, "delete_subshell_reply", reply, parent, ident) @@ -1112,9 +1128,11 @@ async def list_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. - other_socket = self.shell_channel_thread.manager.get_control_other_socket() - await other_socket.send_json({"type": "list"}) - reply = await other_socket.recv_json() + other_socket = await self.shell_channel_thread.manager.get_control_other_socket( + self.control_thread + ) + await other_socket.asend_json({"type": "list"}) + reply = await other_socket.arecv_json() self.session.send(socket, "list_subshell_reply", reply, parent, ident) diff --git a/ipykernel/shellchannel.py b/ipykernel/shellchannel.py index bc0459c4..819a0aec 100644 --- a/ipykernel/shellchannel.py +++ b/ipykernel/shellchannel.py @@ -1,5 +1,6 @@ """A thread for a shell channel.""" -import zmq.asyncio +import zmq +import zmq_anyio from .subshell_manager import SubshellManager from .thread import SHELL_CHANNEL_THREAD_NAME, BaseThread @@ -11,7 +12,12 @@ class ShellChannelThread(BaseThread): Communicates with shell/subshell threads via pairs of ZMQ inproc sockets. """ - def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socket, **kwargs): + def __init__( + self, + context: zmq.Context, # type: ignore[type-arg] + shell_socket: zmq_anyio.Socket, + **kwargs, + ): """Initialize the thread.""" super().__init__(name=SHELL_CHANNEL_THREAD_NAME, **kwargs) self._manager: SubshellManager | None = None diff --git a/ipykernel/subshell.py b/ipykernel/subshell.py index 18e15ab3..180e9ecb 100644 --- a/ipykernel/subshell.py +++ b/ipykernel/subshell.py @@ -2,7 +2,8 @@ from threading import current_thread -import zmq.asyncio +import zmq +import zmq_anyio from .thread import BaseThread @@ -15,17 +16,22 @@ def __init__(self, subshell_id: str, **kwargs): super().__init__(name=f"subshell-{subshell_id}", **kwargs) # Inproc PAIR socket, for communication with shell channel thread. - self._pair_socket: zmq.asyncio.Socket | None = None + self._pair_socket: zmq_anyio.Socket | None = None - async def create_pair_socket(self, context: zmq.asyncio.Context, address: str) -> None: + async def create_pair_socket( + self, + context: zmq.Context, # type: ignore[type-arg] + address: str, + ) -> None: """Create inproc PAIR socket, for communication with shell channel thread. - Should be called from this thread, so usually via add_task before the + Should be called from this thread, so usually via start_soon before the thread is started. """ assert current_thread() == self - self._pair_socket = context.socket(zmq.PAIR) + self._pair_socket = zmq_anyio.Socket(context, zmq.PAIR) self._pair_socket.connect(address) + self.start_soon(self._pair_socket.start) def run(self) -> None: try: diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index 805d6f81..2120abe1 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -7,20 +7,22 @@ import typing as t import uuid from dataclasses import dataclass +from functools import partial from threading import Lock, current_thread, main_thread import zmq -import zmq.asyncio +import zmq_anyio from anyio import create_memory_object_stream, create_task_group +from anyio.abc import TaskGroup from .subshell import SubshellThread -from .thread import SHELL_CHANNEL_THREAD_NAME +from .thread import SHELL_CHANNEL_THREAD_NAME, BaseThread @dataclass class Subshell: thread: SubshellThread - shell_channel_socket: zmq.asyncio.Socket + shell_channel_socket: zmq_anyio.Socket class SubshellManager: @@ -38,10 +40,14 @@ class SubshellManager: against multiple subshells attempting to send at the same time. """ - def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socket): + def __init__( + self, + context: zmq.Context, # type: ignore[type-arg] + shell_socket: zmq_anyio.Socket, + ): assert current_thread() == main_thread() - self._context: zmq.asyncio.Context = context + self._context: zmq.Context = context # type: ignore[type-arg] self._shell_socket = shell_socket self._cache: dict[str, Subshell] = {} self._lock_cache = Lock() @@ -50,15 +56,39 @@ def __init__(self, context: zmq.asyncio.Context, shell_socket: zmq.asyncio.Socke # Inproc pair sockets for control channel and main shell (parent subshell). # Each inproc pair has a "shell_channel" socket used in the shell channel # thread, and an "other" socket used in the other thread. - self._control_shell_channel_socket = self._create_inproc_pair_socket("control", True) - self._control_other_socket = self._create_inproc_pair_socket("control", False) - self._parent_shell_channel_socket = self._create_inproc_pair_socket(None, True) - self._parent_other_socket = self._create_inproc_pair_socket(None, False) + self.__control_shell_channel_socket: zmq_anyio.Socket | None = None + self.__control_other_socket: zmq_anyio.Socket | None = None + self.__parent_shell_channel_socket: zmq_anyio.Socket | None = None + self.__parent_other_socket: zmq_anyio.Socket | None = None # anyio memory object stream for async queue-like communication between tasks. # Used by _create_subshell to tell listen_from_subshells to spawn a new task. self._send_stream, self._receive_stream = create_memory_object_stream[str]() + @property + def _control_shell_channel_socket(self) -> zmq_anyio.Socket: + if self.__control_shell_channel_socket is None: + self.__control_shell_channel_socket = self._create_inproc_pair_socket("control", True) + return self.__control_shell_channel_socket + + @property + def _control_other_socket(self) -> zmq_anyio.Socket: + if self.__control_other_socket is None: + self.__control_other_socket = self._create_inproc_pair_socket("control", False) + return self.__control_other_socket + + @property + def _parent_shell_channel_socket(self) -> zmq_anyio.Socket: + if self.__parent_shell_channel_socket is None: + self.__parent_shell_channel_socket = self._create_inproc_pair_socket(None, True) + return self.__parent_shell_channel_socket + + @property + def _parent_other_socket(self) -> zmq_anyio.Socket: + if self.__parent_other_socket is None: + self.__parent_other_socket = self._create_inproc_pair_socket(None, False) + return self.__parent_other_socket + def close(self) -> None: """Stop all subshells and close all resources.""" assert current_thread().name == SHELL_CHANNEL_THREAD_NAME @@ -67,10 +97,10 @@ def close(self) -> None: self._receive_stream.close() for socket in ( - self._control_shell_channel_socket, - self._control_other_socket, - self._parent_shell_channel_socket, - self._parent_other_socket, + self.__control_shell_channel_socket, + self.__control_other_socket, + self.__parent_shell_channel_socket, + self.__parent_other_socket, ): if socket is not None: socket.close() @@ -83,10 +113,13 @@ def close(self) -> None: break self._stop_subshell(subshell) - def get_control_other_socket(self) -> zmq.asyncio.Socket: + async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket: + if not self._control_other_socket.started.is_set(): + thread.task_group.start_soon(self._control_other_socket.start) + await self._control_other_socket.started.wait() return self._control_other_socket - def get_other_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + def get_other_socket(self, subshell_id: str | None) -> zmq_anyio.Socket: """Return the other inproc pair socket for a subshell. This socket is accessed from the subshell thread. @@ -98,7 +131,7 @@ def get_other_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: assert socket is not None return socket - def get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + def get_shell_channel_socket(self, subshell_id: str | None) -> zmq_anyio.Socket: """Return the shell channel inproc pair socket for a subshell. This socket is accessed from the shell channel thread. @@ -123,10 +156,11 @@ async def listen_from_control(self, subshell_task: t.Any) -> None: assert current_thread().name == SHELL_CHANNEL_THREAD_NAME socket = self._control_shell_channel_socket - while True: - request = await socket.recv_json() # type: ignore[misc] - reply = await self._process_control_request(request, subshell_task) - await socket.send_json(reply) # type: ignore[func-returns-value] + async with socket: + while True: + request = await socket.arecv_json() + reply = await self._process_control_request(request, subshell_task) + await socket.asend_json(reply) async def listen_from_subshells(self) -> None: """Listen for reply messages on inproc sockets of all subshells and resend @@ -137,9 +171,9 @@ async def listen_from_subshells(self) -> None: assert current_thread().name == SHELL_CHANNEL_THREAD_NAME async with create_task_group() as tg: - tg.start_soon(self._listen_for_subshell_reply, None) + tg.start_soon(self._listen_for_subshell_reply, None, tg) async for subshell_id in self._receive_stream: - tg.start_soon(self._listen_for_subshell_reply, subshell_id) + tg.start_soon(self._listen_for_subshell_reply, subshell_id, tg) def subshell_id_from_thread_id(self, thread_id: int) -> str | None: """Return subshell_id of the specified thread_id. @@ -159,10 +193,10 @@ def subshell_id_from_thread_id(self, thread_id: int) -> str | None: def _create_inproc_pair_socket( self, name: str | None, shell_channel_end: bool - ) -> zmq.asyncio.Socket: + ) -> zmq_anyio.Socket: """Create and return a single ZMQ inproc pair socket.""" address = self._get_inproc_socket_address(name) - socket = self._context.socket(zmq.PAIR) + socket = zmq_anyio.Socket(self._context, zmq.PAIR) if shell_channel_end: socket.bind(address) else: @@ -186,8 +220,8 @@ async def _create_subshell(self, subshell_task: t.Any) -> str: await self._send_stream.send(subshell_id) address = self._get_inproc_socket_address(subshell_id) - thread.add_task(thread.create_pair_socket, self._context, address) - thread.add_task(subshell_task, subshell_id) + thread.start_soon(partial(thread.create_pair_socket, self._context, address)) + thread.start_soon(partial(subshell_task, subshell_id)) thread.start() return subshell_id @@ -208,7 +242,7 @@ def _get_inproc_socket_address(self, name: str | None) -> str: full_name = f"subshell-{name}" if name else "subshell" return f"inproc://{full_name}" - def _get_shell_channel_socket(self, subshell_id: str | None) -> zmq.asyncio.Socket: + def _get_shell_channel_socket(self, subshell_id: str | None) -> zmq_anyio.Socket: if subshell_id is None: return self._parent_shell_channel_socket with self._lock_cache: @@ -220,7 +254,9 @@ def _is_subshell(self, subshell_id: str | None) -> bool: with self._lock_cache: return subshell_id in self._cache - async def _listen_for_subshell_reply(self, subshell_id: str | None) -> None: + async def _listen_for_subshell_reply( + self, subshell_id: str | None, task_group: TaskGroup + ) -> None: """Listen for reply messages on specified subshell inproc socket and resend to the client via the shell_socket. @@ -230,11 +266,13 @@ async def _listen_for_subshell_reply(self, subshell_id: str | None) -> None: shell_channel_socket = self._get_shell_channel_socket(subshell_id) + task_group.start_soon(shell_channel_socket.start) + await shell_channel_socket.started.wait() try: while True: - msg = await shell_channel_socket.recv_multipart(copy=False) + msg = await shell_channel_socket.arecv_multipart(copy=False) with self._lock_shell_socket: - await self._shell_socket.send_multipart(msg) + await self._shell_socket.asend_multipart(msg) except BaseException: if not self._is_subshell(subshell_id): # Subshell no longer exists so exit gracefully diff --git a/ipykernel/thread.py b/ipykernel/thread.py index 40509ece..dc68bb3b 100644 --- a/ipykernel/thread.py +++ b/ipykernel/thread.py @@ -1,8 +1,13 @@ """Base class for threads.""" -import typing as t +from __future__ import annotations + +from collections.abc import Awaitable +from queue import Queue from threading import Event, Thread +from typing import Any, Callable from anyio import create_task_group, run, to_thread +from anyio.abc import TaskGroup CONTROL_THREAD_NAME = "Control" SHELL_CHANNEL_THREAD_NAME = "Shell channel" @@ -14,24 +19,53 @@ class BaseThread(Thread): def __init__(self, **kwargs): """Initialize the thread.""" super().__init__(**kwargs) + self.started = Event() + self.stopped = Event() self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True - self.__stop = Event() - self._tasks_and_args: list[tuple[t.Any, t.Any]] = [] + self._tasks: Queue[tuple[str, Callable[[], Awaitable[Any]]] | None] = Queue() + self._result: Queue[Any] = Queue() + + @property + def task_group(self) -> TaskGroup: + return self._task_group + + def start_soon(self, coro: Callable[[], Awaitable[Any]]) -> None: + self._tasks.put(("start_soon", coro)) - def add_task(self, task: t.Any, *args: t.Any) -> None: - # May only add tasks before the thread is started. - self._tasks_and_args.append((task, args)) + def run_async(self, coro: Callable[[], Awaitable[Any]]) -> Any: + self._tasks.put(("run_async", coro)) + return self._result.get() - def run(self) -> t.Any: + def run_sync(self, func: Callable[..., Any]) -> Any: + self._tasks.put(("run_sync", func)) + return self._result.get() + + def run(self) -> None: """Run the thread.""" - return run(self._main) + try: + run(self._main) + except Exception: + pass async def _main(self) -> None: async with create_task_group() as tg: - for task, args in self._tasks_and_args: - tg.start_soon(task, *args) - await to_thread.run_sync(self.__stop.wait) + self._task_group = tg + self.started.set() + while True: + task = await to_thread.run_sync(self._tasks.get) + if task is None: + break + func, arg = task + if func == "start_soon": + tg.start_soon(arg) + elif func == "run_async": + res = await arg + self._result.put(res) + else: # func == "run_sync" + res = arg() + self._result.put(res) + tg.cancel_scope.cancel() def stop(self) -> None: @@ -39,4 +73,5 @@ def stop(self) -> None: This method is threadsafe. """ - self.__stop.set() + self._tasks.put(None) + self.stopped.set() diff --git a/pyproject.toml b/pyproject.toml index 675d9d87..fb4eff50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,10 +30,11 @@ dependencies = [ "nest_asyncio>=1.4", "matplotlib-inline>=0.1", 'appnope>=0.1.2;platform_system=="Darwin"', - "pyzmq>=25.0", + "pyzmq>=26.0", "psutil>=5.7", "packaging>=22", "anyio>=4.2.0", + "zmq-anyio >=0.2.4", ] [project.urls] @@ -62,7 +63,6 @@ test = [ "pre-commit", "pytest-timeout", "trio", - "pytest-asyncio>=0.23.5", ] cov = [ "coverage[toml]", @@ -191,6 +191,10 @@ filterwarnings= [ # ignore unclosed sqlite in traits "ignore:unclosed database in .trigger_timeout' was never awaited", + "ignore: Unclosed socket" ] [tool.coverage.report] @@ -315,3 +319,6 @@ ignore = ["W002"] [tool.repo-review] ignore = ["PY007", "PP308", "GH102", "MY101"] + +[tool.hatch.metadata] +allow-direct-references = true diff --git a/tests/conftest.py b/tests/conftest.py index 2c266555..db992b74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,13 @@ -import asyncio import logging -import os from math import inf +from threading import Event from typing import Any, Callable, no_type_check from unittest.mock import MagicMock import pytest import zmq -import zmq.asyncio -from anyio import create_memory_object_stream, create_task_group +import zmq_anyio +from anyio import create_memory_object_stream, create_task_group, sleep from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from jupyter_client.session import Session @@ -23,11 +22,6 @@ resource = None # type:ignore -@pytest.fixture() -def anyio_backend(): - return "asyncio" - - pytestmark = pytest.mark.anyio @@ -46,11 +40,6 @@ def anyio_backend(): resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard)) -# Enforce selector event loop on Windows. -if os.name == "nt": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type:ignore - - class TestSession(Session): """A session that copies sent messages to an internal stream, so that they can be accessed later. @@ -77,21 +66,21 @@ def send(self, socket, *args, **kwargs): class KernelMixin: - shell_socket: zmq.asyncio.Socket - control_socket: zmq.asyncio.Socket + shell_socket: zmq_anyio.Socket + control_socket: zmq_anyio.Socket stop: Callable[[], None] log = logging.getLogger() def _initialize(self): self._is_test = True - self.context = context = zmq.asyncio.Context() - self.iopub_socket = context.socket(zmq.PUB) - self.stdin_socket = context.socket(zmq.ROUTER) + self.context = context = zmq.Context() + self.iopub_socket = zmq_anyio.Socket(context.socket(zmq.PUB)) + self.stdin_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.test_sockets = [self.iopub_socket] for name in ["shell", "control"]: - socket = context.socket(zmq.ROUTER) + socket = zmq_anyio.Socket(context.socket(zmq.ROUTER)) self.test_sockets.append(socket) setattr(self, f"{name}_socket", socket) @@ -142,7 +131,7 @@ def _prep_msg(self, *args, **kwargs): async def _wait_for_msg(self): while not self._reply: - await asyncio.sleep(0.1) + await sleep(0.1) _, msg = self.session.feed_identities(self._reply) return self.session.deserialize(msg) @@ -166,6 +155,8 @@ class MockKernel(KernelMixin, Kernel): # type:ignore def __init__(self, *args, **kwargs): self._initialize() self.shell = MagicMock() + self.shell_stop = Event() + self.control_stop = Event() super().__init__(*args, **kwargs) def do_execute( @@ -187,6 +178,8 @@ def do_execute( class MockIPyKernel(KernelMixin, IPythonKernel): # type:ignore def __init__(self, *args, **kwargs): self._initialize() + self.shell_stop = Event() + self.control_stop = Event() super().__init__(*args, **kwargs) diff --git a/tests/test_async.py b/tests/test_async.py index a40db4a0..c2dd980b 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -30,24 +30,23 @@ def test_async_await(): assert content["status"] == "ok", content -# FIXME: @pytest.mark.parametrize("asynclib", ["asyncio", "trio", "curio"]) @pytest.mark.skipif(os.name == "nt", reason="Cannot interrupt on Windows") -@pytest.mark.parametrize("asynclib", ["asyncio"]) -def test_async_interrupt(asynclib, request): +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) # FIXME: %autoawait trio +def test_async_interrupt(anyio_backend, request): assert KC is not None assert KM is not None try: - __import__(asynclib) + __import__(anyio_backend) except ImportError: - pytest.skip("Requires %s" % asynclib) - request.addfinalizer(lambda: execute("%autoawait asyncio", KC)) + pytest.skip("Requires %s" % anyio_backend) + request.addfinalizer(lambda: execute(f"%autoawait {anyio_backend}", KC)) flush_channels(KC) - msg_id, content = execute("%autoawait " + asynclib, KC) + msg_id, content = execute(f"%autoawait {anyio_backend}", KC) assert content["status"] == "ok", content flush_channels(KC) - msg_id = KC.execute(f"print('begin'); import {asynclib}; await {asynclib}.sleep(5)") + msg_id = KC.execute(f"print('begin'); import {anyio_backend}; await {anyio_backend}.sleep(5)") busy = KC.get_iopub_msg(timeout=TIMEOUT) validate_message(busy, "status", msg_id) assert busy["content"]["execution_state"] == "busy" diff --git a/tests/test_eventloop.py b/tests/test_eventloop.py index 62a7f8ba..fcaa2bde 100644 --- a/tests/test_eventloop.py +++ b/tests/test_eventloop.py @@ -85,6 +85,7 @@ def do_thing(): @windows_skip +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) def test_asyncio_loop(kernel): def do_thing(): loop.call_later(0.01, loop.stop) diff --git a/tests/test_io.py b/tests/test_io.py index e3ff2815..aca2694e 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -12,31 +12,31 @@ import pytest import zmq -import zmq.asyncio +import zmq_anyio from jupyter_client.session import Session from ipykernel.iostream import _PARENT, BackgroundSocket, IOPubThread, OutStream +pytestmark = pytest.mark.anyio + @pytest.fixture() def ctx(): - ctx = zmq.asyncio.Context() + ctx = zmq.Context() yield ctx ctx.destroy() @pytest.fixture() -def iopub_thread(ctx): - with ctx.socket(zmq.PUB) as pub: +async def iopub_thread(ctx): + async with zmq_anyio.Socket(ctx.socket(zmq.PUB)) as pub: thread = IOPubThread(pub) thread.start() yield thread - thread.stop() - thread.close() -def test_io_api(iopub_thread): +async def test_io_api(iopub_thread): """Test that wrapped stdout has the same API as a normal TextIO object""" session = Session() stream = OutStream(session, iopub_thread, "stdout") @@ -59,13 +59,13 @@ def test_io_api(iopub_thread): stream.write(b"") # type:ignore -def test_io_isatty(iopub_thread): +async def test_io_isatty(iopub_thread): session = Session() stream = OutStream(session, iopub_thread, "stdout", isatty=True) assert stream.isatty() -async def test_io_thread(anyio_backend, iopub_thread): +async def test_io_thread(iopub_thread): thread = iopub_thread thread._setup_pipe_in() msg = [thread._pipe_uuid, b"a"] @@ -77,11 +77,9 @@ async def test_io_thread(anyio_backend, iopub_thread): thread._really_send([b"hi"]) ctx1.destroy() thread.stop() - thread.close() - thread._really_send(None) -async def test_background_socket(anyio_backend, iopub_thread): +async def test_background_socket(iopub_thread): sock = BackgroundSocket(iopub_thread) assert sock.__class__ == BackgroundSocket with warnings.catch_warnings(): @@ -92,7 +90,7 @@ async def test_background_socket(anyio_backend, iopub_thread): sock.send(b"hi") -async def test_outstream(anyio_backend, iopub_thread): +async def test_outstream(iopub_thread): session = Session() pub = iopub_thread.socket @@ -118,7 +116,6 @@ async def test_outstream(anyio_backend, iopub_thread): assert stream.writable() -@pytest.mark.anyio() async def test_event_pipe_gc(iopub_thread): session = Session(key=b"abc") stream = OutStream( @@ -150,61 +147,61 @@ async def test_event_pipe_gc(iopub_thread): # assert iopub_thread._event_pipes == {} -def subprocess_test_echo_watch(): +async def subprocess_test_echo_watch(): # handshake Pub subscription session = Session(key=b"abc") # use PUSH socket to avoid subscription issues - with zmq.asyncio.Context() as ctx, ctx.socket(zmq.PUSH) as pub: - pub.connect(os.environ["IOPUB_URL"]) - iopub_thread = IOPubThread(pub) - iopub_thread.start() - stdout_fd = sys.stdout.fileno() - sys.stdout.flush() - stream = OutStream( - session, - iopub_thread, - "stdout", - isatty=True, - echo=sys.stdout, - watchfd="force", - ) - save_stdout = sys.stdout - with stream, mock.patch.object(sys, "stdout", stream): - # write to low-level FD - os.write(stdout_fd, b"fd\n") - # print (writes to stream) - print("print\n", end="") + with zmq.Context() as ctx: + async with zmq_anyio.Socket(ctx.socket(zmq.PUSH)) as pub: + pub.connect(os.environ["IOPUB_URL"]) + iopub_thread = IOPubThread(pub) + iopub_thread.start() + stdout_fd = sys.stdout.fileno() sys.stdout.flush() - # write to unwrapped __stdout__ (should also go to original FD) - sys.__stdout__.write("__stdout__\n") - sys.__stdout__.flush() - # write to original sys.stdout (should be the same as __stdout__) - save_stdout.write("stdout\n") - save_stdout.flush() - # is there another way to flush on the FD? - fd_file = os.fdopen(stdout_fd, "w") - fd_file.flush() - # we don't have a sync flush on _reading_ from the watched pipe - time.sleep(1) - stream.flush() - iopub_thread.stop() - iopub_thread.close() - - -@pytest.mark.anyio() + stream = OutStream( + session, + iopub_thread, + "stdout", + isatty=True, + echo=sys.stdout, + watchfd="force", + ) + save_stdout = sys.stdout + with stream, mock.patch.object(sys, "stdout", stream): + # write to low-level FD + os.write(stdout_fd, b"fd\n") + # print (writes to stream) + print("print\n", end="") + sys.stdout.flush() + # write to unwrapped __stdout__ (should also go to original FD) + sys.__stdout__.write("__stdout__\n") + sys.__stdout__.flush() + # write to original sys.stdout (should be the same as __stdout__) + save_stdout.write("stdout\n") + save_stdout.flush() + # is there another way to flush on the FD? + fd_file = os.fdopen(stdout_fd, "w") + fd_file.flush() + # we don't have a sync flush on _reading_ from the watched pipe + time.sleep(1) + stream.flush() + iopub_thread.stop() + iopub_thread.close() + + @pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows") async def test_echo_watch(ctx): """Test echo on underlying FD while capturing the same FD Test runs in a subprocess to avoid messing with pytest output capturing. """ - s = ctx.socket(zmq.PULL) + s = zmq_anyio.Socket(ctx.socket(zmq.PULL)) port = s.bind_to_random_port("tcp://127.0.0.1") url = f"tcp://127.0.0.1:{port}" session = Session(key=b"abc") stdout_chunks = [] - with s: + async with s: env = dict(os.environ) env["IOPUB_URL"] = url env["PYTHONUNBUFFERED"] = "1" @@ -213,7 +210,7 @@ async def test_echo_watch(ctx): [ sys.executable, "-c", - f"import {__name__}; {__name__}.subprocess_test_echo_watch()", + f"import {__name__}, anyio; anyio.run({__name__}.subprocess_test_echo_watch)", ], env=env, capture_output=True, @@ -224,8 +221,8 @@ async def test_echo_watch(ctx): print(f"{p.stdout=}") print(f"{p.stderr}=", file=sys.stderr) assert p.returncode == 0 - while await s.poll(timeout=100): - msg = await s.recv_multipart() + while await s.apoll(timeout=100): + msg = await s.arecv_multipart() ident, msg = session.feed_identities(msg, copy=True) msg = session.deserialize(msg, content=True, copy=True) assert msg is not None # for type narrowing diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 8efc3dcc..e27cb0d7 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -62,7 +62,7 @@ def test_simple_print(): def test_print_to_correct_cell_from_thread(): """should print to the cell that spawned the thread, not a subsequently run cell""" iterations = 5 - interval = 0.25 + interval = 1 code = f"""\ from threading import Thread from time import sleep @@ -83,6 +83,8 @@ def thread_target(): msg = kc.get_iopub_msg(timeout=interval * 2) if msg["msg_type"] != "stream": continue + print(f"{thread_msg_id=}") + print(f"{msg=}") content = msg["content"] assert content["name"] == "stdout" assert content["text"] == str(received) @@ -94,7 +96,7 @@ def thread_target(): def test_print_to_correct_cell_from_child_thread(): """should print to the cell that spawned the thread, not a subsequently run cell""" iterations = 5 - interval = 0.25 + interval = 1 code = f"""\ from threading import Thread from time import sleep @@ -105,8 +107,8 @@ def child_target(): sleep({interval}) def parent_target(): - sleep({interval}) Thread(target=child_target).start() + sleep({interval * iterations}) Thread(target=parent_target).start() """ @@ -119,6 +121,8 @@ def parent_target(): msg = kc.get_iopub_msg(timeout=interval * 2) if msg["msg_type"] != "stream": continue + print(f"{thread_msg_id=}") + print(f"{msg=}") content = msg["content"] assert content["name"] == "stdout" assert content["text"] == str(received) @@ -130,7 +134,7 @@ def parent_target(): def test_print_to_correct_cell_from_asyncio(): """should print to the cell that scheduled the task, not a subsequently run cell""" iterations = 5 - interval = 0.25 + interval = 1 code = f"""\ import asyncio @@ -151,6 +155,8 @@ async def async_task(): msg = kc.get_iopub_msg(timeout=interval * 2) if msg["msg_type"] != "stream": continue + print(f"{thread_msg_id=}") + print(f"{msg=}") content = msg["content"] assert content["name"] == "stdout" assert content["text"] == str(received)