From 130387c2a213aa86880d01788a91011265fb7de6 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 15 Nov 2024 09:57:18 +0100 Subject: [PATCH] Replace thread add_task with start_soon --- ipykernel/kernelbase.py | 17 +++++++++-------- ipykernel/shellchannel.py | 2 +- ipykernel/subshell.py | 4 ++-- ipykernel/subshell_manager.py | 14 ++++++-------- ipykernel/thread.py | 35 +++++++++++++++++------------------ 5 files changed, 35 insertions(+), 37 deletions(-) diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index 3c7324d2..fcceef26 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -16,6 +16,7 @@ import uuid import warnings from datetime import datetime +from functools import partial from signal import SIGINT, SIGTERM, Signals from .thread import CONTROL_THREAD_NAME @@ -536,7 +537,7 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: self.control_stop = threading.Event() if not self._is_test and self.control_socket is not None: if self.control_thread: - self.control_thread.add_task(self.control_main) + self.control_thread.start_soon(self.control_main) self.control_thread.start() else: tg.start_soon(self.control_main) @@ -551,11 +552,11 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: # Assign tasks to and start shell channel thread. manager = self.shell_channel_thread.manager - self.shell_channel_thread.add_task(self.shell_channel_thread_main) - self.shell_channel_thread.add_task( - manager.listen_from_control, self.shell_main, self.shell_channel_thread + self.shell_channel_thread.start_soon(self.shell_channel_thread_main) + self.shell_channel_thread.start_soon( + partial(manager.listen_from_control, self.shell_main, self.shell_channel_thread) ) - self.shell_channel_thread.add_task(manager.listen_from_subshells) + self.shell_channel_thread.start_soon(manager.listen_from_subshells) self.shell_channel_thread.start() else: if not self._is_test and self.shell_socket is not None: @@ -1085,7 +1086,7 @@ async def create_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. other_socket = await self.shell_channel_thread.manager.get_control_other_socket( - self.control_thread.get_task_group() + self.control_thread ) await other_socket.asend_json({"type": "create"}) reply = await other_socket.arecv_json() @@ -1109,7 +1110,7 @@ async def delete_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. other_socket = await self.shell_channel_thread.manager.get_control_other_socket( - self.control_thread.get_task_group() + self.control_thread ) await other_socket.asend_json({"type": "delete", "subshell_id": subshell_id}) reply = await other_socket.arecv_json() @@ -1126,7 +1127,7 @@ async def list_subshell_request(self, socket, ident, parent) -> None: # This should only be called in the control thread if it exists. # Request is passed to shell channel thread to process. other_socket = await self.shell_channel_thread.manager.get_control_other_socket( - self.control_thread.get_task_group() + self.control_thread ) await other_socket.asend_json({"type": "list"}) reply = await other_socket.arecv_json() diff --git a/ipykernel/shellchannel.py b/ipykernel/shellchannel.py index 789a8875..819a0aec 100644 --- a/ipykernel/shellchannel.py +++ b/ipykernel/shellchannel.py @@ -28,7 +28,7 @@ def __init__( def manager(self) -> SubshellManager: # Lazy initialisation. if self._manager is None: - self._manager = SubshellManager(self._context, self._shell_socket, self.get_task_group) + self._manager = SubshellManager(self._context, self._shell_socket) return self._manager def run(self) -> None: diff --git a/ipykernel/subshell.py b/ipykernel/subshell.py index e84f5498..180e9ecb 100644 --- a/ipykernel/subshell.py +++ b/ipykernel/subshell.py @@ -25,13 +25,13 @@ async def create_pair_socket( ) -> None: """Create inproc PAIR socket, for communication with shell channel thread. - Should be called from this thread, so usually via add_task before the + Should be called from this thread, so usually via start_soon before the thread is started. """ assert current_thread() == self self._pair_socket = zmq_anyio.Socket(context, zmq.PAIR) self._pair_socket.connect(address) - self.add_task(self._pair_socket.start) + self.start_soon(self._pair_socket.start) def run(self) -> None: try: diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index dbd3da76..505c2f40 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -7,8 +7,8 @@ import typing as t import uuid from dataclasses import dataclass +from functools import partial from threading import Lock, current_thread, main_thread -from typing import Callable import zmq import zmq_anyio @@ -44,13 +44,11 @@ def __init__( self, context: zmq.Context, # type: ignore[type-arg] shell_socket: zmq_anyio.Socket, - get_task_group: Callable[[], TaskGroup], ): assert current_thread() == main_thread() self._context: zmq.Context = context # type: ignore[type-arg] self._shell_socket = shell_socket - self._get_task_group = get_task_group self._cache: dict[str, Subshell] = {} self._lock_cache = Lock() self._lock_shell_socket = Lock() @@ -91,9 +89,9 @@ def close(self) -> None: break self._stop_subshell(subshell) - async def get_control_other_socket(self, task_group: TaskGroup) -> zmq_anyio.Socket: + async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket: if not self._control_other_socket.started.is_set(): - task_group.start_soon(self._control_other_socket.start) + thread.start_soon(self._control_other_socket.start) await self._control_other_socket.started.wait() return self._control_other_socket @@ -134,7 +132,7 @@ async def listen_from_control(self, subshell_task: t.Any, thread: BaseThread) -> assert current_thread().name == SHELL_CHANNEL_THREAD_NAME if not self._control_shell_channel_socket.started.is_set(): - thread.get_task_group().start_soon(self._control_shell_channel_socket.start) + thread.start_soon(self._control_shell_channel_socket.start) await self._control_shell_channel_socket.started.wait() socket = self._control_shell_channel_socket while True: @@ -200,8 +198,8 @@ async def _create_subshell(self, subshell_task: t.Any) -> str: await self._send_stream.send(subshell_id) address = self._get_inproc_socket_address(subshell_id) - thread.add_task(thread.create_pair_socket, self._context, address) - thread.add_task(subshell_task, subshell_id) + thread.start_soon(partial(thread.create_pair_socket, self._context, address)) + thread.start_soon(partial(subshell_task, subshell_id)) thread.start() return subshell_id diff --git a/ipykernel/thread.py b/ipykernel/thread.py index f55ee7c7..24e98b36 100644 --- a/ipykernel/thread.py +++ b/ipykernel/thread.py @@ -1,9 +1,12 @@ """Base class for threads.""" -import typing as t -from threading import Event, Thread +from __future__ import annotations + +from queue import Queue +from collections.abc import Awaitable +from threading import Thread +from typing import Callable from anyio import create_task_group, run, to_thread -from anyio.abc import TaskGroup CONTROL_THREAD_NAME = "Control" SHELL_CHANNEL_THREAD_NAME = "Shell channel" @@ -17,26 +20,22 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.pydev_do_not_trace = True self.is_pydev_daemon_thread = True - self.__stop = Event() - self._tasks_and_args: list[tuple[t.Any, t.Any]] = [] - - def get_task_group(self) -> TaskGroup: - return self._task_group + self._tasks: Queue[Callable[[], Awaitable[None]] | None] = Queue() - def add_task(self, task: t.Any, *args: t.Any) -> None: - # May only add tasks before the thread is started. - self._tasks_and_args.append((task, args)) + def start_soon(self, task: Callable[[], Awaitable[None]] | None) -> None: + self._tasks.put(task) - def run(self) -> t.Any: + def run(self) -> None: """Run the thread.""" - return run(self._main) + run(self._main) async def _main(self) -> None: async with create_task_group() as tg: - self._task_group = tg - for task, args in self._tasks_and_args: - tg.start_soon(task, *args) - await to_thread.run_sync(self.__stop.wait) + while True: + task = await to_thread.run_sync(self._tasks.get) + if task is None: + break + tg.start_soon(task) tg.cancel_scope.cancel() def stop(self) -> None: @@ -44,4 +43,4 @@ def stop(self) -> None: This method is threadsafe. """ - self.__stop.set() + self._tasks.put(None)