diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py
index cc6ec8edc6..cd8a9de71f 100644
--- a/jupyter_server/services/kernels/kernelmanager.py
+++ b/jupyter_server/services/kernels/kernelmanager.py
@@ -232,11 +232,7 @@ async def _async_start_kernel(  # type:ignore[override]
                 kwargs["kernel_id"] = kernel_id
             kernel_id = await self.pinned_superclass._async_start_kernel(self, **kwargs)
             self._kernel_connections[kernel_id] = 0
-            task = asyncio.create_task(self._finish_kernel_start(kernel_id))
-            if not getattr(self, "use_pending_kernels", None):
-                await task
-            else:
-                self._pending_kernel_tasks[kernel_id] = task
+
             # add busy/activity markers:
             kernel = self.get_kernel(kernel_id)
             kernel.execution_state = "starting"  # type:ignore[attr-defined]
@@ -250,6 +246,12 @@ async def _async_start_kernel(  # type:ignore[override]
         if env and isinstance(env, dict):  # type:ignore[unreachable]
             self.log.debug("Kernel argument 'env' passed with: %r", list(env.keys()))  # type:ignore[unreachable]
 
+        task = asyncio.create_task(self._finish_kernel_start(kernel_id))
+        if not getattr(self, "use_pending_kernels", None):
+            await task
+        else:
+            self._pending_kernel_tasks[kernel_id] = task
+
         # Increase the metric of number of kernels running
         # for the relevant kernel type by 1
         KERNEL_CURRENTLY_RUNNING_TOTAL.labels(type=self._kernels[kernel_id].kernel_name).inc()
@@ -537,6 +539,40 @@ def _check_kernel_id(self, kernel_id):
             raise web.HTTPError(404, "Kernel does not exist: %s" % kernel_id)
 
     # monitoring activity:
+    untracked_message_types = List(
+        trait=Unicode(),
+        config=True,
+        default_value=[
+            "comm_info_request",
+            "comm_info_reply",
+            "kernel_info_request",
+            "kernel_info_reply",
+            "shutdown_request",
+            "shutdown_reply",
+            "interrupt_request",
+            "interrupt_reply",
+            "debug_request",
+            "debug_reply",
+            "stream",
+            "display_data",
+            "update_display_data",
+            "execute_input",
+            "execute_result",
+            "error",
+            "status",
+            "clear_output",
+            "debug_event",
+            "input_request",
+            "input_reply",
+        ],
+        help="""List of kernel message types excluded from user activity tracking.
+
+        This should be a superset of the message types sent on any channel other
+        than the shell channel.""",
+    )
+
+    def track_message_type(self, message_type):
+        return message_type not in self.untracked_message_types
 
     def start_watching_activity(self, kernel_id):
         """Start watching IOPub messages on a kernel for activity.
@@ -557,15 +593,27 @@ def start_watching_activity(self, kernel_id):
 
         def record_activity(msg_list):
             """Record an IOPub message arriving from a kernel"""
-            self.last_kernel_activity = kernel.last_activity = utcnow()
-
             idents, fed_msg_list = session.feed_identities(msg_list)
             msg = session.deserialize(fed_msg_list, content=False)
 
             msg_type = msg["header"]["msg_type"]
+            parent_msg_type = msg.get("parent_header", {}).get("msg_type", None)
+            if (
+                self.track_message_type(msg_type)
+                or self.track_message_type(parent_msg_type)
+                or kernel.execution_state == "busy"
+            ):
+                self.last_kernel_activity = kernel.last_activity = utcnow()
             if msg_type == "status":
                 msg = session.deserialize(fed_msg_list)
-                kernel.execution_state = msg["content"]["execution_state"]
+                execution_state = msg["content"]["execution_state"]
+                if self.track_message_type(parent_msg_type):
+                    kernel.execution_state = execution_state
+                elif kernel.execution_state == "starting" and execution_state != "starting":
+                    # We always normalize post-starting execution state to "idle"
+                    # unless we know that the status is in response to one of our
+                    # tracked message types.
+                    kernel.execution_state = "idle"
                 self.log.debug(
                     "activity on %s: %s (%s)",
                     kernel_id,
diff --git a/tests/services/kernels/test_cull.py b/tests/services/kernels/test_cull.py
index 50ecbf2b96..5b0b8fd9a0 100644
--- a/tests/services/kernels/test_cull.py
+++ b/tests/services/kernels/test_cull.py
@@ -1,7 +1,9 @@
 import asyncio
+import datetime
 import json
 import os
 import platform
+import uuid
 import warnings
 
 import jupyter_client
@@ -94,6 +96,83 @@ async def test_cull_idle(jp_fetch, jp_ws_fetch):
     assert culled
 
 
+@pytest.mark.parametrize(
+    "jp_server_config",
+    [
+        # Test the synchronous case
+        Config(
+            {
+                "ServerApp": {
+                    "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.MappingKernelManager",
+                    "MappingKernelManager": {
+                        "cull_idle_timeout": CULL_TIMEOUT,
+                        "cull_interval": CULL_INTERVAL,
+                        "cull_connected": True,
+                    },
+                }
+            }
+        ),
+        # Test the async case
+        Config(
+            {
+                "ServerApp": {
+                    "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.AsyncMappingKernelManager",
+                    "AsyncMappingKernelManager": {
+                        "cull_idle_timeout": CULL_TIMEOUT,
+                        "cull_interval": CULL_INTERVAL,
+                        "cull_connected": True,
+                    },
+                }
+            }
+        ),
+    ],
+)
+async def test_cull_connected(jp_fetch, jp_ws_fetch):
+    r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
+    kernel = json.loads(r.body.decode())
+    kid = kernel["id"]
+
+    # Open a websocket connection.
+    ws = await jp_ws_fetch("api", "kernels", kid, "channels")
+    session_id = uuid.uuid1().hex
+    message_id = uuid.uuid1().hex
+    await ws.write_message(
+        json.dumps(
+            {
+                "channel": "shell",
+                "header": {
+                    "date": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
+                    "session": session_id,
+                    "msg_id": message_id,
+                    "msg_type": "execute_request",
+                    "username": "",
+                    "version": "5.2",
+                },
+                "parent_header": {},
+                "metadata": {},
+                "content": {
+                    "code": f"import time\ntime.sleep({CULL_TIMEOUT - 1})",
+                    "silent": False,
+                    "allow_stdin": False,
+                    "stop_on_error": True,
+                },
+                "buffers": [],
+            }
+        )
+    )
+
+    r = await jp_fetch("api", "kernels", kid, method="GET")
+    model = json.loads(r.body.decode())
+    assert model["connections"] == 1
+    culled = await get_cull_status(
+        kid, jp_fetch
+    )  # connected, but code cell still running. Should not be culled
+    assert not culled
+    culled = await get_cull_status(kid, jp_fetch)  # still connected, but idle... should be culled
+    assert culled
+    ws.close()
+
+
 async def test_cull_idle_disable(jp_fetch, jp_ws_fetch, jp_kernelspec_with_metadata):
     r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
     kernel = json.loads(r.body.decode())
diff --git a/tests/services/kernels/test_execution_state.py b/tests/services/kernels/test_execution_state.py
new file mode 100644
index 0000000000..50155ec76f
--- /dev/null
+++ b/tests/services/kernels/test_execution_state.py
@@ -0,0 +1,146 @@
+import asyncio
+import datetime
+import json
+import os
+import platform
+import time
+import uuid
+import warnings
+
+import jupyter_client
+import pytest
+from flaky import flaky
+from tornado.httpclient import HTTPClientError
+from traitlets.config import Config
+
+MAX_POLL_ATTEMPTS = 10
+POLL_INTERVAL = 1
+MINIMUM_CONSISTENT_COUNT = 4
+
+
+@flaky
+async def test_execution_state(jp_fetch, jp_ws_fetch):
+    r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
+    kernel = json.loads(r.body.decode())
+    kid = kernel["id"]
+
+    # Open a websocket connection.
+    ws = await jp_ws_fetch("api", "kernels", kid, "channels")
+    session_id = uuid.uuid1().hex
+    message_id = uuid.uuid1().hex
+    await ws.write_message(
+        json.dumps(
+            {
+                "channel": "shell",
+                "header": {
+                    "date": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
+                    "session": session_id,
+                    "msg_id": message_id,
+                    "msg_type": "execute_request",
+                    "username": "",
+                    "version": "5.2",
+                },
+                "parent_header": {},
+                "metadata": {},
+                "content": {
+                    "code": "while True:\n\tpass",
+                    "silent": False,
+                    "allow_stdin": False,
+                    "stop_on_error": True,
+                },
+                "buffers": [],
+            }
+        )
+    )
+    await poll_for_parent_message_status(kid, message_id, "busy", ws)
+    es = await get_execution_state(kid, jp_fetch)
+    assert es == "busy"
+
+    message_id_2 = uuid.uuid1().hex
+    await ws.write_message(
+        json.dumps(
+            {
+                "channel": "control",
+                "header": {
+                    "date": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
+                    "session": session_id,
+                    "msg_id": message_id_2,
+                    "msg_type": "debug_request",
+                    "username": "",
+                    "version": "5.2",
+                },
+                "parent_header": {},
+                "metadata": {},
+                "content": {
+                    "type": "request",
+                    "command": "debugInfo",
+                },
+                "buffers": [],
+            }
+        )
+    )
+    await poll_for_parent_message_status(kid, message_id_2, "idle", ws)
+    es = await get_execution_state(kid, jp_fetch)
+
+    # Verify that the overall kernel status is still "busy" even though an
+    # "idle" status was already seen in response to the debug_request.
+    assert es == "busy"
+
+    await jp_fetch(
+        "api",
+        "kernels",
+        kid,
+        "interrupt",
+        method="POST",
+        allow_nonstandard_methods=True,
+    )
+
+    await poll_for_parent_message_status(kid, message_id, "idle", ws)
+    es = await get_execution_state(kid, jp_fetch)
+    assert es == "idle"
+    ws.close()
+
+
+async def get_execution_state(kid, jp_fetch):
+    # There is an inherent race condition when getting the kernel execution status,
+    # since we might fetch the status right before an expected state change occurs.
+    #
+    # To work around this, we don't return the status until we have fetched the
+    # same value MINIMUM_CONSISTENT_COUNT times in a row.
+    last_execution_states = []
+
+    for _ in range(MAX_POLL_ATTEMPTS):
+        r = await jp_fetch("api", "kernels", kid, method="GET")
+        model = json.loads(r.body.decode())
+        execution_state = model["execution_state"]
+        last_execution_states.append(execution_state)
+        consistent_count = 0
+        last_execution_state = None
+        for es in last_execution_states:
+            if es != last_execution_state:
+                consistent_count = 0
+                last_execution_state = es
+            consistent_count += 1
+            if consistent_count >= MINIMUM_CONSISTENT_COUNT:
+                return es
+        time.sleep(POLL_INTERVAL)
+
+    raise AssertionError("failed to get a consistent execution state")
+
+
+async def poll_for_parent_message_status(kid, parent_message_id, target_status, ws):
+    while True:
+        resp = await ws.read_message()
+        resp_json = json.loads(resp)
+        print(resp_json)
+        parent_message = resp_json.get("parent_header", {}).get("msg_id", None)
+        if parent_message != parent_message_id:
+            continue
+
+        response_type = resp_json.get("header", {}).get("msg_type", None)
+        if response_type != "status":
+            continue
+
+        execution_state = resp_json.get("content", {}).get("execution_state", "")
+        if execution_state == target_status:
+            return
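
Usage note (not part of the patch above): a minimal configuration sketch showing how a deployment might override the new `untracked_message_types` trait introduced in the kernelmanager.py hunk. The config file name and the narrowed message-type list below are illustrative assumptions only; by default the trait uses the full list shown in the diff.

    # jupyter_server_config.py (hypothetical example)
    c = get_config()  # noqa: F821

    # Treat only comm_info traffic and kernel status chatter as "not user activity";
    # any other message type (or a tracked parent type) refreshes last_activity
    # and therefore postpones idle culling.
    c.MappingKernelManager.untracked_message_types = [
        "comm_info_request",
        "comm_info_reply",
        "status",
    ]

With this in place, `track_message_type("execute_request")` returns True (the message counts as activity), while `track_message_type("status")` returns False, matching the check performed in `record_activity`.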