diff --git a/.travis.yml b/.travis.yml index 727d224..47ef2ef 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ language: python python: - "3.5" -install: pip install ipykernel jupyter_client jsonschema +install: pip install ipykernel jupyter_kernel_mgmt>=0.1.1 jupyter_protocol jsonschema script: python test_ipykernel.py diff --git a/flit.ini b/flit.ini deleted file mode 100644 index fe1e2e9..0000000 --- a/flit.ini +++ /dev/null @@ -1,13 +0,0 @@ -[metadata] -module = jupyter_kernel_test -author = Jupyter Development Team -author-email = jupyter@googlegroups.com -home-page = https://github.com/jupyter/jupyter_kernel_test -description-file = README.rst -classifiers = License :: OSI Approved :: BSD License - Intended Audience :: Developers - Programming Language :: Python :: 3 - Topic :: Software Development :: Testing -requires-python = >=3.4 -requires = jupyter_client - jsonschema diff --git a/jupyter_kernel_test/__init__.py b/jupyter_kernel_test/__init__.py index d1638d9..e5b8724 100644 --- a/jupyter_kernel_test/__init__.py +++ b/jupyter_kernel_test/__init__.py @@ -4,9 +4,11 @@ # Distributed under the terms of the Modified BSD License. from unittest import TestCase, SkipTest -from queue import Empty -from jupyter_client.manager import start_new_kernel +from jupyter_kernel_mgmt.discovery import KernelFinder +from jupyter_kernel_mgmt.client import BlockingKernelClient, ErrorInKernel +from tornado import gen +from tornado.ioloop import IOLoop from .msgspec_v5 import validate_message TIMEOUT = 15 @@ -16,72 +18,86 @@ class KernelTests(TestCase): kernel_name = "" + @classmethod + def launch_kernel(cls): + kf = KernelFinder.from_entrypoints() + name = cls.kernel_name + if '/' not in name: + name = 'spec/' + name + return kf.launch(name=name) + @classmethod def setUpClass(cls): - cls.km, cls.kc = start_new_kernel(kernel_name=cls.kernel_name) + conn_info, km = cls.launch_kernel() + cls.kc = BlockingKernelClient(conn_info, manager=km) + cls.kc.wait_for_ready(timeout=TIMEOUT) + cls.kc.allow_stdin = False @classmethod def tearDownClass(cls): - cls.kc.stop_channels() - cls.km.shutdown_kernel() - - def flush_channels(self): - for channel in (self.kc.shell_channel, self.kc.iopub_channel): - while True: - try: - msg = channel.get_msg(block=True, timeout=0.1) - except Empty: - break - else: - validate_message(msg) + cls.kc.shutdown_or_terminate() + cls.kc.close() + + def _wait_for_reply(self, future, timeout=TIMEOUT): + return IOLoop.current().run_sync(lambda: future, timeout=timeout) language_name = "" file_extension = "" def test_kernel_info(self): - self.flush_channels() - - msg_id = self.kc.kernel_info() - reply = self.kc.get_shell_msg(timeout=TIMEOUT) - validate_message(reply, 'kernel_info_reply', msg_id) + fut = self.kc.kernel_info(reply=False) + reply = self._wait_for_reply(fut) + validate_message(reply, 'kernel_info_reply', fut.jupyter_msg_id) if self.language_name: - self.assertEqual(reply['content']['language_info']['name'], + self.assertEqual(reply.content['language_info']['name'], self.language_name) if self.file_extension: - self.assertEqual(reply['content']['language_info']['file_extension'], + self.assertEqual(reply.content['language_info']['file_extension'], self.file_extension) - self.assertTrue(reply['content']['language_info']['file_extension'].startswith(".")) + self.assertTrue(reply.content['language_info']['file_extension'].startswith(".")) def execute_helper(self, code, timeout=TIMEOUT, silent=False, store_history=True, stop_on_error=True): - msg_id = self.kc.execute(code=code, silent=silent, - store_history=store_history, - stop_on_error=stop_on_error) - - reply = self.kc.get_shell_msg(timeout=timeout) - validate_message(reply, 'execute_reply', msg_id) - - busy_msg = self.kc.iopub_channel.get_msg(timeout=1) - validate_message(busy_msg, 'status', msg_id) - self.assertEqual(busy_msg['content']['execution_state'], 'busy') - output_msgs = [] - while True: - msg = self.kc.iopub_channel.get_msg(timeout=0.1) - validate_message(msg, msg['msg_type'], msg_id) - if msg['msg_type'] == 'status': - self.assertEqual(msg['content']['execution_state'], 'idle') - break - elif msg['msg_type'] == 'execute_input': - self.assertEqual(msg['content']['code'], code) - continue - output_msgs.append(msg) - - return reply, output_msgs + # Give the client a chance to flush messages already arrived + IOLoop.current().run_sync(lambda: gen.sleep(0.1)) + self.kc.loop_client.add_handler('iopub', output_msgs.append) + try: + fut = self.kc.execute( + code=code, silent=silent, store_history=store_history, + stop_on_error=stop_on_error, interrupt_timeout=timeout, + reply=False, + ) + try: + reply = self._wait_for_reply(fut, timeout+5) + except ErrorInKernel as e: + reply = e.reply_msg + finally: + self.kc.loop_client.add_handler('iopub', output_msgs.append) + + validate_message(reply, 'execute_reply') + + busy_msg = output_msgs[0] + validate_message(busy_msg, 'status', fut.jupyter_msg_id) + self.assertEqual(busy_msg.content['execution_state'], 'busy') + + real_output_msgs = [] # Excluding status and execute_input + + for msg in output_msgs[1:]: + msg_type = msg.header['msg_type'] + validate_message(msg, msg_type, fut.jupyter_msg_id) + if msg_type == 'status': + self.assertEqual(msg.content['execution_state'], 'idle') + elif msg_type == 'execute_input': + self.assertEqual(msg.content['code'], code) + else: + real_output_msgs.append(msg) + + return reply, real_output_msgs code_hello_world = "" @@ -89,15 +105,14 @@ def test_execute_stdout(self): if not self.code_hello_world: raise SkipTest - self.flush_channels() reply, output_msgs = self.execute_helper(code=self.code_hello_world) - self.assertEqual(reply['content']['status'], 'ok') + self.assertEqual(reply.content['status'], 'ok') self.assertGreaterEqual(len(output_msgs), 1) for msg in output_msgs: - if (msg['msg_type'] == 'stream') and (msg['content']['name'] == 'stdout'): - self.assertIn('hello, world', msg['content']['text']) + if (msg.header['msg_type'] == 'stream') and (msg.content['name'] == 'stdout'): + self.assertIn('hello, world', msg.content['text']) break else: self.assertTrue(False, "Expected one output message of type 'stream' and 'content.name'='stdout'") @@ -108,15 +123,14 @@ def test_execute_stderr(self): if not self.code_stderr: raise SkipTest - self.flush_channels() reply, output_msgs = self.execute_helper(code=self.code_stderr) - self.assertEqual(reply['content']['status'], 'ok') + self.assertEqual(reply.content['status'], 'ok') self.assertGreaterEqual(len(output_msgs), 1) for msg in output_msgs: - if (msg['msg_type'] == 'stream') and (msg['content']['name'] == 'stderr'): + if (msg.header['msg_type'] == 'stream') and (msg.content['name'] == 'stderr'): break else: self.assertTrue(False, "Expected one output message of type 'stream' and 'content.name'='stderr'") @@ -129,11 +143,11 @@ def test_completion(self): for sample in self.completion_samples: with self.subTest(text=sample['text']): - msg_id = self.kc.complete(sample['text']) - reply = self.kc.get_shell_msg() - validate_message(reply, 'complete_reply', msg_id) + fut = self.kc.complete(sample['text'], reply=False) + reply = self._wait_for_reply(fut) + validate_message(reply, 'complete_reply', fut.jupyter_msg_id) if 'matches' in sample: - self.assertEqual(set(reply['content']['matches']), + self.assertEqual(set(reply.content['matches']), set(sample['matches'])) complete_code_samples = [] @@ -141,13 +155,13 @@ def test_completion(self): invalid_code_samples = [] def check_is_complete(self, sample, status): - msg_id = self.kc.is_complete(sample) - reply = self.kc.get_shell_msg() - validate_message(reply, 'is_complete_reply', msg_id) - if reply['content']['status'] != status: + fut = self.kc.is_complete(sample, reply=False) + reply = self._wait_for_reply(fut) + validate_message(reply, 'is_complete_reply', fut.jupyter_msg_id) + if reply.content['status'] != status: msg = "For code sample\n {!r}\nExpected {!r}, got {!r}." raise AssertionError(msg.format(sample, status, - reply['content']['status'])) + reply.content['status'])) def test_is_complete(self): if not (self.complete_code_samples @@ -155,8 +169,6 @@ def test_is_complete(self): or self.invalid_code_samples): raise SkipTest - self.flush_channels() - with self.subTest(status="complete"): for sample in self.complete_code_samples: self.check_is_complete(sample, 'complete') @@ -175,11 +187,9 @@ def test_pager(self): if not self.code_page_something: raise SkipTest - self.flush_channels() - reply, output_msgs = self.execute_helper(self.code_page_something) - self.assertEqual(reply['content']['status'], 'ok') - payloads = reply['content']['payload'] + self.assertEqual(reply.content['status'], 'ok') + payloads = reply.content['payload'] self.assertEqual(len(payloads), 1) self.assertEqual(payloads[0]['source'], 'page') mimebundle = payloads[0]['data'] @@ -191,12 +201,10 @@ def test_error(self): if not self.code_generate_error: raise SkipTest - self.flush_channels() - reply, output_msgs = self.execute_helper(self.code_generate_error) - self.assertEqual(reply['content']['status'], 'error') + self.assertEqual(reply.content['status'], 'error') self.assertEqual(len(output_msgs), 1) - self.assertEqual(output_msgs[0]['msg_type'], 'error') + self.assertEqual(output_msgs[0].header['msg_type'], 'error') code_execute_result = [] @@ -206,16 +214,14 @@ def test_execute_result(self): for sample in self.code_execute_result: with self.subTest(code=sample['code']): - self.flush_channels() - reply, output_msgs = self.execute_helper(sample['code']) - self.assertEqual(reply['content']['status'], 'ok') + self.assertEqual(reply.content['status'], 'ok') self.assertGreaterEqual(len(output_msgs), 1) - self.assertEqual(output_msgs[0]['msg_type'], 'execute_result') - self.assertIn('text/plain', output_msgs[0]['content']['data']) - self.assertEqual(output_msgs[0]['content']['data']['text/plain'], + self.assertEqual(output_msgs[0].header['msg_type'], 'execute_result') + self.assertIn('text/plain', output_msgs[0].content['data']) + self.assertEqual(output_msgs[0].content['data']['text/plain'], sample['result']) code_display_data = [] @@ -226,30 +232,26 @@ def test_display_data(self): for sample in self.code_display_data: with self.subTest(code=sample['code']): - self.flush_channels() reply, output_msgs = self.execute_helper(sample['code']) - self.assertEqual(reply['content']['status'], 'ok') + self.assertEqual(reply.content['status'], 'ok') self.assertGreaterEqual(len(output_msgs), 1) - self.assertEqual(output_msgs[0]['msg_type'], 'display_data') - self.assertIn(sample['mime'], output_msgs[0]['content']['data']) + self.assertEqual(output_msgs[0].header['msg_type'], 'display_data') + self.assertIn(sample['mime'], output_msgs[0].content['data']) # this should match one of the values in code_execute_result code_history_pattern = "" supported_history_operations = () def history_helper(self, execute_first, timeout=TIMEOUT, **histargs): - self.flush_channels() - for code in execute_first: reply, output_msgs = self.execute_helper(code) - self.flush_channels() - msg_id = self.kc.history(**histargs) + fut = self.kc.history(**histargs, reply=False) - reply = self.kc.get_shell_msg(timeout=timeout) - validate_message(reply, 'history_reply', msg_id) + reply = self._wait_for_reply(fut, timeout=timeout) + validate_message(reply, 'history_reply', fut.jupyter_msg_id) return reply @@ -268,15 +270,15 @@ def test_history(self): raise SkipTest reply = self.history_helper(codes, output=False, raw=True, hist_access_type="tail", n=n) - self.assertEqual(len(reply['content']['history']), n) - self.assertEqual(len(reply['content']['history'][0]), 3) - self.assertEqual(codes, [h[2] for h in reply['content']['history']]) + self.assertEqual(len(reply.content['history']), n) + self.assertEqual(len(reply.content['history'][0]), 3) + self.assertEqual(codes, [h[2] for h in reply.content['history']]) - session, start = reply['content']['history'][0][0:2] + session, start = reply.content['history'][0][0:2] with self.subTest(output=True): reply = self.history_helper(codes, output=True, raw=True, hist_access_type="tail", n=n) - self.assertEqual(len(reply['content']['history'][0][2]), 2) + self.assertEqual(len(reply.content['history'][0][2]), 2) with self.subTest(hist_access_type="range"): if 'range' not in self.supported_history_operations: @@ -287,9 +289,9 @@ def test_history(self): hist_access_type="range", session=session, start=start, stop=start+1) - self.assertEqual(len(reply['content']['history']), 1) - self.assertEqual(reply['content']['history'][0][0], session) - self.assertEqual(reply['content']['history'][0][1], start) + self.assertEqual(len(reply.content['history']), 1) + self.assertEqual(reply.content['history'][0][0], session) + self.assertEqual(reply.content['history'][0][1], start) with self.subTest(hist_access_type="search"): if not self.code_history_pattern: @@ -301,19 +303,23 @@ def test_history(self): reply = self.history_helper(codes, output=False, raw=True, hist_access_type="search", pattern=self.code_history_pattern) - self.assertGreaterEqual(len(reply['content']['history']), 1) + self.assertGreaterEqual(len(reply.content['history']), 1) with self.subTest(subsearch="unique"): reply = self.history_helper(codes, output=False, raw=True, hist_access_type="search", pattern=self.code_history_pattern, unique=True) - self.assertEqual(len(reply['content']['history']), 1) + entries = [x[2] for x in reply.content['history']] + if len(entries) > len(set(entries)): + raise AssertionError("History search results not unique: {}" + .format(sorted(entries))) + with self.subTest(subsearch="n"): reply = self.history_helper(codes, output=False, raw=True, hist_access_type="search", pattern=self.code_history_pattern, n=3) - self.assertEqual(len(reply['content']['history']), 3) + self.assertEqual(len(reply.content['history']), 3) code_inspect_sample = "" @@ -321,14 +327,13 @@ def test_inspect(self): if not self.code_inspect_sample: raise SkipTest - self.flush_channels() - msg_id = self.kc.inspect(self.code_inspect_sample) - reply = self.kc.get_shell_msg(timeout=TIMEOUT) - validate_message(reply, 'inspect_reply', msg_id) + fut = self.kc.inspect(self.code_inspect_sample, reply=False) + reply = self._wait_for_reply(fut) + validate_message(reply, 'inspect_reply', fut.jupyter_msg_id) - self.assertEqual(reply['content']['status'], 'ok') - self.assertTrue(reply['content']['found']) - self.assertGreaterEqual(len(reply['content']['data']), 1) + self.assertEqual(reply.content['status'], 'ok') + self.assertTrue(reply.content['found']) + self.assertGreaterEqual(len(reply.content['data']), 1) code_clear_output = "" @@ -336,10 +341,9 @@ def test_clear_output(self): if not self.code_clear_output: raise SkipTest - self.flush_channels() reply, output_msgs = self.execute_helper(code=self.code_clear_output) - self.assertEqual(reply['content']['status'], 'ok') + self.assertEqual(reply.content['status'], 'ok') self.assertGreaterEqual(len(output_msgs), 1) - self.assertEqual(output_msgs[0]['msg_type'], 'clear_output') + self.assertEqual(output_msgs[0].header['msg_type'], 'clear_output') diff --git a/jupyter_kernel_test/msgspec_v5.py b/jupyter_kernel_test/msgspec_v5.py index e36f0f8..399a39a 100644 --- a/jupyter_kernel_test/msgspec_v5.py +++ b/jupyter_kernel_test/msgspec_v5.py @@ -4,6 +4,7 @@ # Distributed under the terms of the Modified BSD License. from jsonschema import Draft4Validator, ValidationError +from jupyter_protocol.messages import Message import re protocol_version = (5, 1) @@ -90,21 +91,22 @@ def get_abort_reply_validator(version_minor): } def validate_message(msg, msg_type=None, parent_id=None): - msg_structure_validator.validate(msg) + if not isinstance(msg, Message): + raise ValidationError("Expected a Message instance, got {}".format(type(msg))) - msg_version_s = msg['header']['version'] + msg_version_s = msg.header['version'] m = re.match(r'(\d+)\.(\d+)', msg_version_s) if not m: - raise ValidationError("Version {} not like 'x.y'") + raise ValidationError("Version {} not like 'x.y'".format(msg_version_s)) version_minor = int(m.group(2)) if msg_type is not None: - if msg['header']['msg_type'] != msg_type: + if msg.header['msg_type'] != msg_type: raise ValidationError("Message type {!r} != {!r}".format( - msg['header']['msg_type'], msg_type + msg.header['msg_type'], msg_type )) else: - msg_type = msg['header']['msg_type'] + msg_type = msg.header['msg_type'] # Check for unexpected fields, unless it's a newer protocol version if version_minor <= protocol_version[1]: @@ -112,18 +114,18 @@ def validate_message(msg, msg_type=None, parent_id=None): if unx_top: raise ValidationError("Unexpected keys: {}".format(unx_top)) - unx_header = set(msg['header']) - set(header_part['properties']) + unx_header = set(msg.header) - set(header_part['properties']) if unx_header: raise ValidationError("Unexpected keys in header: {}".format(unx_header)) # Check the parent id - if parent_id and msg['parent_header']['msg_id'] != parent_id: + if parent_id and msg.parent_header['msg_id'] != parent_id: raise ValidationError("Parent header does not match expected") if msg_type in reply_msgs_using_status: # Most _reply messages have common 'error' and 'abort' structures try: - status = msg['content']['status'] + status = msg.content['status'] except KeyError as e: raise ValidationError(str(e)) if status == 'error': @@ -138,7 +140,7 @@ def validate_message(msg, msg_type=None, parent_id=None): else: content_vdor = get_msg_content_validator(msg_type, version_minor) - content_vdor.validate(msg['content']) + content_vdor.validate(msg.content) # Shell messages ---------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4118032 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[build-system] +requires = ["flit"] +build-backend = "flit.buildapi" + +[tool.flit.metadata] +module = "jupyter_kernel_test" +author = "Jupyter Development Team" +author-email = "jupyter@googlegroups.com" +home-page = "https://github.com/jupyter/jupyter_kernel_test" +description-file = "README.rst" +classifiers = [ + "License :: OSI Approved :: BSD License", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Topic :: Software Development :: Testing" +] +requires-python = ">=3.4" +requires = [ + "jupyter_kernel_mgmt>=0.1.1", + "jupyter_protocol", + "jsonschema" +] +