diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index 2636d157..b9dea456 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -91,7 +91,7 @@ def close(self) -> None: async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket: if not self._control_other_socket.started.is_set(): - thread.start_soon(self._control_other_socket.start) + thread.task_group.start_soon(self._control_other_socket.start) await self._control_other_socket.started.wait() return self._control_other_socket diff --git a/ipykernel/thread.py b/ipykernel/thread.py index d853a2ad..dc68bb3b 100644 --- a/ipykernel/thread.py +++ b/ipykernel/thread.py @@ -4,9 +4,10 @@ from collections.abc import Awaitable from queue import Queue from threading import Event, Thread -from typing import Callable +from typing import Any, Callable from anyio import create_task_group, run, to_thread +from anyio.abc import TaskGroup CONTROL_THREAD_NAME = "Control" SHELL_CHANNEL_THREAD_NAME = "Shell channel" @@ -22,10 +23,23 @@ def __init__(self, **kwargs): self.stopped = Event() self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True - self._tasks: Queue[Callable[[], Awaitable[None]] | None] = Queue() + self._tasks: Queue[tuple[str, Callable[[], Awaitable[Any]]] | None] = Queue() + self._result: Queue[Any] = Queue() - def start_soon(self, task: Callable[[], Awaitable[None]] | None) -> None: - self._tasks.put(task) + @property + def task_group(self) -> TaskGroup: + return self._task_group + + def start_soon(self, coro: Callable[[], Awaitable[Any]]) -> None: + self._tasks.put(("start_soon", coro)) + + def run_async(self, coro: Callable[[], Awaitable[Any]]) -> Any: + self._tasks.put(("run_async", coro)) + return self._result.get() + + def run_sync(self, func: Callable[..., Any]) -> Any: + self._tasks.put(("run_sync", func)) + return self._result.get() def run(self) -> None: """Run the thread.""" @@ -36,12 +50,22 @@ def run(self) -> None: async def _main(self) -> None: async with create_task_group() as tg: + self._task_group = tg self.started.set() while True: task = await to_thread.run_sync(self._tasks.get) if task is None: break - tg.start_soon(task) + func, arg = task + if func == "start_soon": + tg.start_soon(arg) + elif func == "run_async": + res = await arg + self._result.put(res) + else: # func == "run_sync" + res = arg() + self._result.put(res) + tg.cancel_scope.cancel() def stop(self) -> None: