diff --git a/docs/source/public_server.rst b/docs/source/public_server.rst index 3796a2a4fb..edadbe3ffc 100644 --- a/docs/source/public_server.rst +++ b/docs/source/public_server.rst @@ -343,6 +343,35 @@ single-tab mode: }); +Using a gateway server for kernel management +-------------------------------------------- + +You are now able to redirect the management of your kernels to a Gateway Server +(i.e., `Jupyter Kernel Gateway `_ or +`Jupyter Enterprise Gateway `_) +simply by specifying a Gateway url via the following command-line option: + + .. code-block:: bash + + $ jupyter notebook --gateway-url=http://my-gateway-server:8888 + +the environment: + + .. code-block:: bash + + JUPYTER_GATEWAY_URL=http://my-gateway-server:8888 + +or in :file:`jupyter_notebook_config.py`: + + .. code-block:: python + + c.GatewayClient.url = http://my-gateway-server:8888 + +When provided, all kernel specifications will be retrieved from the specified Gateway server and all +kernels will be managed by that server. This option enables the ability to target kernel processes +against managed clusters while allowing for the notebook's management to remain local to the Notebook +server. + Known issues ------------ diff --git a/notebook/gateway/__init__.py b/notebook/gateway/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/notebook/gateway/handlers.py b/notebook/gateway/handlers.py new file mode 100644 index 0000000000..8e09b10861 --- /dev/null +++ b/notebook/gateway/handlers.py @@ -0,0 +1,207 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import os +import logging + +from ..base.handlers import IPythonHandler +from ..utils import url_path_join + +from tornado import gen, web +from tornado.concurrent import Future +from tornado.ioloop import IOLoop +from tornado.websocket import WebSocketHandler, websocket_connect +from tornado.httpclient import HTTPRequest +from tornado.escape import url_escape, json_decode, utf8 + +from ipython_genutils.py3compat import cast_unicode +from jupyter_client.session import Session +from traitlets.config.configurable import LoggingConfigurable + +from .managers import GatewayClient + + +class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler): + + session = None + gateway = None + kernel_id = None + + def set_default_headers(self): + """Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets""" + pass + + def get_compression_options(self): + # use deflate compress websocket + return {} + + def authenticate(self): + """Run before finishing the GET request + + Extend this method to add logic that should fire before + the websocket finishes completing. + """ + # authenticate the request before opening the websocket + if self.get_current_user() is None: + self.log.warning("Couldn't authenticate WebSocket connection") + raise web.HTTPError(403) + + if self.get_argument('session_id', False): + self.session.session = cast_unicode(self.get_argument('session_id')) + else: + self.log.warning("No session ID specified") + + def initialize(self): + self.log.debug("Initializing websocket connection %s", self.request.path) + self.session = Session(config=self.config) + self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url) + + @gen.coroutine + def get(self, kernel_id, *args, **kwargs): + self.authenticate() + self.kernel_id = cast_unicode(kernel_id, 'ascii') + super(WebSocketChannelsHandler, self).get(kernel_id=kernel_id, *args, **kwargs) + + def open(self, kernel_id, *args, **kwargs): + """Handle web socket connection open to notebook server and delegate to gateway web socket handler """ + self.gateway.on_open( + kernel_id=kernel_id, + message_callback=self.write_message, + compression_options=self.get_compression_options() + ) + + def on_message(self, message): + """Forward message to gateway web socket handler.""" + self.log.debug("Sending message to gateway: {}".format(message)) + self.gateway.on_message(message) + + def write_message(self, message, binary=False): + """Send message back to notebook client. This is called via callback from self.gateway._read_messages.""" + self.log.debug("Receiving message from gateway: {}".format(message)) + if self.ws_connection: # prevent WebSocketClosedError + super(WebSocketChannelsHandler, self).write_message(message, binary=binary) + elif self.log.isEnabledFor(logging.DEBUG): + msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message))) + self.log.debug("Notebook client closed websocket connection - message dropped: {}".format(msg_summary)) + + def on_close(self): + self.log.debug("Closing websocket connection %s", self.request.path) + self.gateway.on_close() + super(WebSocketChannelsHandler, self).on_close() + + @staticmethod + def _get_message_summary(message): + summary = [] + message_type = message['msg_type'] + summary.append('type: {}'.format(message_type)) + + if message_type == 'status': + summary.append(', state: {}'.format(message['content']['execution_state'])) + elif message_type == 'error': + summary.append(', {}:{}:{}'.format(message['content']['ename'], + message['content']['evalue'], + message['content']['traceback'])) + else: + summary.append(', ...') # don't display potentially sensitive data + + return ''.join(summary) + + +class GatewayWebSocketClient(LoggingConfigurable): + """Proxy web socket connection to a kernel/enterprise gateway.""" + + def __init__(self, **kwargs): + super(GatewayWebSocketClient, self).__init__(**kwargs) + self.kernel_id = None + self.ws = None + self.ws_future = Future() + self.ws_future_cancelled = False + + @gen.coroutine + def _connect(self, kernel_id): + self.kernel_id = kernel_id + ws_url = url_path_join( + GatewayClient.instance().ws_url, + GatewayClient.instance().kernels_endpoint, url_escape(kernel_id), 'channels' + ) + self.log.info('Connecting to {}'.format(ws_url)) + kwargs = {} + kwargs = GatewayClient.instance().load_connection_args(**kwargs) + + request = HTTPRequest(ws_url, **kwargs) + self.ws_future = websocket_connect(request) + self.ws_future.add_done_callback(self._connection_done) + + def _connection_done(self, fut): + if not self.ws_future_cancelled: # prevent concurrent.futures._base.CancelledError + self.ws = fut.result() + self.log.debug("Connection is ready: ws: {}".format(self.ws)) + else: + self.log.warning("Websocket connection has been cancelled via client disconnect before its establishment. " + "Kernel with ID '{}' may not be terminated on GatewayClient: {}". + format(self.kernel_id, GatewayClient.instance().url)) + + def _disconnect(self): + if self.ws is not None: + # Close connection + self.ws.close() + elif not self.ws_future.done(): + # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally + self.ws_future.cancel() + self.ws_future_cancelled = True + self.log.debug("_disconnect: ws_future_cancelled: {}".format(self.ws_future_cancelled)) + + @gen.coroutine + def _read_messages(self, callback): + """Read messages from gateway server.""" + while True: + message = None + if not self.ws_future_cancelled: + try: + message = yield self.ws.read_message() + except Exception as e: + self.log.error("Exception reading message from websocket: {}".format(e)) # , exc_info=True) + if message is None: + break + callback(message) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open) + else: # ws cancelled - stop reading + break + + def on_open(self, kernel_id, message_callback, **kwargs): + """Web socket connection open against gateway server.""" + self._connect(kernel_id) + loop = IOLoop.current() + loop.add_future( + self.ws_future, + lambda future: self._read_messages(message_callback) + ) + + def on_message(self, message): + """Send message to gateway server.""" + if self.ws is None: + loop = IOLoop.current() + loop.add_future( + self.ws_future, + lambda future: self._write_message(message) + ) + else: + self._write_message(message) + + def _write_message(self, message): + """Send message to gateway server.""" + try: + if not self.ws_future_cancelled: + self.ws.write_message(message) + except Exception as e: + self.log.error("Exception writing message to websocket: {}".format(e)) # , exc_info=True) + + def on_close(self): + """Web socket closed event.""" + self._disconnect() + + +from ..services.kernels.handlers import _kernel_id_regex + +default_handlers = [ + (r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler), +] diff --git a/notebook/gateway/managers.py b/notebook/gateway/managers.py new file mode 100644 index 0000000000..73af7d9799 --- /dev/null +++ b/notebook/gateway/managers.py @@ -0,0 +1,555 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import os +import json + +from socket import gaierror +from tornado import gen, web +from tornado.escape import json_encode, json_decode, url_escape +from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError +from tornado.simple_httpclient import HTTPTimeoutError + +from ..services.kernels.kernelmanager import MappingKernelManager +from ..services.sessions.sessionmanager import SessionManager + +from jupyter_client.kernelspec import KernelSpecManager +from ..utils import url_path_join + +from traitlets import Instance, Unicode, Float, Bool, default, validate, TraitError +from traitlets.config import SingletonConfigurable + + +class GatewayClient(SingletonConfigurable): + """This class manages the configuration. It's its own singleton class so that we + can share these values across all objects. It also contains some helper methods + to build request arguments out of the various config options. + + """ + + url = Unicode(default_value=None, allow_none=True, config=True, + help="""The url of the Kernel or Enterprise Gateway server where + kernel specifications are defined and kernel management takes place. + If defined, this Notebook server acts as a proxy for all kernel + management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var) + """ + ) + + url_env = 'JUPYTER_GATEWAY_URL' + @default('url') + def _url_default(self): + return os.environ.get(self.url_env) + + @validate('url') + def _url_validate(self, proposal): + value = proposal['value'] + # Ensure value, if present, starts with 'http' + if value is not None and len(value) > 0: + if not str(value).lower().startswith('http'): + raise TraitError("GatewayClient url must start with 'http': '%r'" % value) + return value + + ws_url = Unicode(default_value=None, allow_none=True, config=True, + help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value + will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) + """ + ) + + ws_url_env = 'JUPYTER_GATEWAY_WS_URL' + @default('ws_url') + def _ws_url_default(self): + default_value = os.environ.get(self.ws_url_env) + if default_value is None: + if self.gateway_enabled: + default_value = self.url.lower().replace('http', 'ws') + return default_value + + @validate('ws_url') + def _ws_url_validate(self, proposal): + value = proposal['value'] + # Ensure value, if present, starts with 'ws' + if value is not None and len(value) > 0: + if not str(value).lower().startswith('ws'): + raise TraitError("GatewayClient ws_url must start with 'ws': '%r'" % value) + return value + + kernels_endpoint_default_value = '/api/kernels' + kernels_endpoint_env = 'JUPYTER_GATEWAY_KERNELS_ENDPOINT' + kernels_endpoint = Unicode(default_value=kernels_endpoint_default_value, config=True, + help="""The gateway API endpoint for accessing kernel resources (JUPYTER_GATEWAY_KERNELS_ENDPOINT env var)""") + + @default('kernels_endpoint') + def _kernels_endpoint_default(self): + return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value) + + kernelspecs_endpoint_default_value = '/api/kernelspecs' + kernelspecs_endpoint_env = 'JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT' + kernelspecs_endpoint = Unicode(default_value=kernelspecs_endpoint_default_value, config=True, + help="""The gateway API endpoint for accessing kernelspecs (JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT env var)""") + + @default('kernelspecs_endpoint') + def _kernelspecs_endpoint_default(self): + return os.environ.get(self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value) + + connect_timeout_default_value = 20.0 + connect_timeout_env = 'JUPYTER_GATEWAY_CONNECT_TIMEOUT' + connect_timeout = Float(default_value=connect_timeout_default_value, config=True, + help="""The time allowed for HTTP connection establishment with the Gateway server. + (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""") + + @default('connect_timeout') + def connect_timeout_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_CONNECT_TIMEOUT', self.connect_timeout_default_value)) + + request_timeout_default_value = 20.0 + request_timeout_env = 'JUPYTER_GATEWAY_REQUEST_TIMEOUT' + request_timeout = Float(default_value=request_timeout_default_value, config=True, + help="""The time allowed for HTTP request completion. (JUPYTER_GATEWAY_REQUEST_TIMEOUT env var)""") + + @default('request_timeout') + def request_timeout_default(self): + return float(os.environ.get('JUPYTER_GATEWAY_REQUEST_TIMEOUT', self.request_timeout_default_value)) + + client_key = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename for client SSL key, if any. (JUPYTER_GATEWAY_CLIENT_KEY env var) + """ + ) + client_key_env = 'JUPYTER_GATEWAY_CLIENT_KEY' + + @default('client_key') + def _client_key_default(self): + return os.environ.get(self.client_key_env) + + client_cert = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename for client SSL certificate, if any. (JUPYTER_GATEWAY_CLIENT_CERT env var) + """ + ) + client_cert_env = 'JUPYTER_GATEWAY_CLIENT_CERT' + + @default('client_cert') + def _client_cert_default(self): + return os.environ.get(self.client_cert_env) + + ca_certs = Unicode(default_value=None, allow_none=True, config=True, + help="""The filename of CA certificates or None to use defaults. (JUPYTER_GATEWAY_CA_CERTS env var) + """ + ) + ca_certs_env = 'JUPYTER_GATEWAY_CA_CERTS' + + @default('ca_certs') + def _ca_certs_default(self): + return os.environ.get(self.ca_certs_env) + + http_user = Unicode(default_value=None, allow_none=True, config=True, + help="""The username for HTTP authentication. (JUPYTER_GATEWAY_HTTP_USER env var) + """ + ) + http_user_env = 'JUPYTER_GATEWAY_HTTP_USER' + + @default('http_user') + def _http_user_default(self): + return os.environ.get(self.http_user_env) + + http_pwd = Unicode(default_value=None, allow_none=True, config=True, + help="""The password for HTTP authentication. (JUPYTER_GATEWAY_HTTP_PWD env var) + """ + ) + http_pwd_env = 'JUPYTER_GATEWAY_HTTP_PWD' + + @default('http_pwd') + def _http_pwd_default(self): + return os.environ.get(self.http_pwd_env) + + headers_default_value = '{}' + headers_env = 'JUPYTER_GATEWAY_HEADERS' + headers = Unicode(default_value=headers_default_value, allow_none=True,config=True, + help="""Additional HTTP headers to pass on the request. This value will be converted to a dict. + (JUPYTER_GATEWAY_HEADERS env var) + """ + ) + + @default('headers') + def _headers_default(self): + return os.environ.get(self.headers_env, self.headers_default_value) + + auth_token = Unicode(default_value=None, allow_none=True, config=True, + help="""The authorization token used in the HTTP headers. (JUPYTER_GATEWAY_AUTH_TOKEN env var) + """ + ) + auth_token_env = 'JUPYTER_GATEWAY_AUTH_TOKEN' + + @default('auth_token') + def _auth_token_default(self): + return os.environ.get(self.auth_token_env) + + validate_cert_default_value = True + validate_cert_env = 'JUPYTER_GATEWAY_VALIDATE_CERT' + validate_cert = Bool(default_value=validate_cert_default_value, config=True, + help="""For HTTPS requests, determines if server's certificate should be validated or not. + (JUPYTER_GATEWAY_VALIDATE_CERT env var)""" + ) + + @default('validate_cert') + def validate_cert_default(self): + return bool(os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value)) not in ['no', 'false']) + + def __init__(self, **kwargs): + super(GatewayClient, self).__init__(**kwargs) + self._static_args = {} # initialized on first use + + env_whitelist_default_value = '' + env_whitelist_env = 'JUPYTER_GATEWAY_ENV_WHITELIST' + env_whitelist = Unicode(default_value=env_whitelist_default_value, config=True, + help="""A comma-separated list of environment variable names that will be included, along with + their values, in the kernel startup request. The corresponding `env_whitelist` configuration + value must also be set on the Gateway server - since that configuration value indicates which + environmental values to make available to the kernel. (JUPYTER_GATEWAY_ENV_WHITELIST env var)""") + + @default('env_whitelist') + def _env_whitelist_default(self): + return os.environ.get(self.env_whitelist_env, self.env_whitelist_default_value) + + @property + def gateway_enabled(self): + return bool(self.url is not None and len(self.url) > 0) + + def init_static_args(self): + """Initialize arguments used on every request. Since these are static values, we'll + perform this operation once. + + """ + self._static_args['headers'] = json.loads(self.headers) + self._static_args['headers'].update({'Authorization': 'token {}'.format(self.auth_token)}) + self._static_args['connect_timeout'] = self.connect_timeout + self._static_args['request_timeout'] = self.request_timeout + self._static_args['validate_cert'] = self.validate_cert + if self.client_cert: + self._static_args['client_cert'] = self.client_cert + self._static_args['client_key'] = self.client_key + if self.ca_certs: + self._static_args['ca_certs'] = self.ca_certs + if self.http_user: + self._static_args['auth_username'] = self.http_user + if self.http_pwd: + self._static_args['auth_password'] = self.http_pwd + + def load_connection_args(self, **kwargs): + """Merges the static args relative to the connection, with the given keyword arguments. If statics + have yet to be initialized, we'll do that here. + + """ + if len(self._static_args) == 0: + self.init_static_args() + + kwargs.update(self._static_args) + return kwargs + + +@gen.coroutine +def gateway_request(endpoint, **kwargs): + """Make an async request to kernel gateway endpoint, returns a response """ + client = AsyncHTTPClient() + kwargs = GatewayClient.instance().load_connection_args(**kwargs) + try: + response = yield client.fetch(endpoint, **kwargs) + # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect + # or the server is not running. + # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes + # of the tree view. + except ConnectionRefusedError: + raise web.HTTPError(503, "Connection refused from Gateway server url '{}'. " + "Check to be sure the Gateway instance is running.".format(GatewayClient.instance().url)) + except HTTPTimeoutError: + # This can occur if the host is valid (e.g., foo.com) but there's nothing there. + raise web.HTTPError(504, "Timeout error attempting to connect to Gateway server url '{}'. " \ + "Ensure gateway url is valid and the Gateway instance is running.".format( + GatewayClient.instance().url)) + except gaierror as e: + raise web.HTTPError(404, "The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. " + "Ensure gateway url is valid and the Gateway instance is running.".format( + GatewayClient.instance().url)) + + raise gen.Return(response) + + +class GatewayKernelManager(MappingKernelManager): + """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway.""" + + # We'll maintain our own set of kernel ids + _kernels = {} + + def __init__(self, **kwargs): + super(GatewayKernelManager, self).__init__(**kwargs) + self.base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint) + + def __contains__(self, kernel_id): + return kernel_id in self._kernels + + def remove_kernel(self, kernel_id): + """Complete override since we want to be more tolerant of missing keys """ + try: + return self._kernels.pop(kernel_id) + except KeyError: + pass + + def _get_kernel_endpoint_url(self, kernel_id=None): + """Builds a url for the kernels endpoint + + Parameters + ---------- + kernel_id: kernel UUID (optional) + """ + if kernel_id: + return url_path_join(self.base_endpoint, url_escape(str(kernel_id))) + + return self.base_endpoint + + @gen.coroutine + def start_kernel(self, kernel_id=None, path=None, **kwargs): + """Start a kernel for a session and return its kernel_id. + + Parameters + ---------- + kernel_id : uuid + The uuid to associate the new kernel with. If this + is not None, this kernel will be persistent whenever it is + requested. + path : API path + The API path (unicode, '/' delimited) for the cwd. + Will be transformed to an OS path relative to root_dir. + """ + self.log.info('Request start kernel: kernel_id=%s, path="%s"', kernel_id, path) + + if kernel_id is None: + if path is not None: + kwargs['cwd'] = self.cwd_for_path(path) + kernel_name = kwargs.get('kernel_name', 'python3') + kernel_url = self._get_kernel_endpoint_url() + self.log.debug("Request new kernel at: %s" % kernel_url) + + # Let KERNEL_USERNAME take precedent over http_user config option. + if os.environ.get('KERNEL_USERNAME') is None and GatewayClient.instance().http_user: + os.environ['KERNEL_USERNAME'] = GatewayClient.instance().http_user + + kernel_env = {k: v for (k, v) in dict(os.environ).items() if k.startswith('KERNEL_') + or k in GatewayClient.instance().env_whitelist.split(",")} + json_body = json_encode({'name': kernel_name, 'env': kernel_env}) + + response = yield gateway_request(kernel_url, method='POST', body=json_body) + kernel = json_decode(response.body) + kernel_id = kernel['id'] + self.log.info("Kernel started: %s" % kernel_id) + self.log.debug("Kernel args: %r" % kwargs) + else: + kernel = yield self.get_kernel(kernel_id) + kernel_id = kernel['id'] + self.log.info("Using existing kernel: %s" % kernel_id) + + self._kernels[kernel_id] = kernel + raise gen.Return(kernel_id) + + @gen.coroutine + def get_kernel(self, kernel_id=None, **kwargs): + """Get kernel for kernel_id. + + Parameters + ---------- + kernel_id : uuid + The uuid of the kernel. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request kernel at: %s" % kernel_url) + try: + response = yield gateway_request(kernel_url, method='GET') + except HTTPError as error: + if error.code == 404: + self.log.warn("Kernel not found at: %s" % kernel_url) + self.remove_kernel(kernel_id) + kernel = None + else: + raise + else: + kernel = json_decode(response.body) + self._kernels[kernel_id] = kernel + self.log.debug("Kernel retrieved: %s" % kernel) + raise gen.Return(kernel) + + @gen.coroutine + def kernel_model(self, kernel_id): + """Return a dictionary of kernel information described in the + JSON standard model. + + Parameters + ---------- + kernel_id : uuid + The uuid of the kernel. + """ + self.log.debug("RemoteKernelManager.kernel_model: %s", kernel_id) + model = yield self.get_kernel(kernel_id) + raise gen.Return(model) + + @gen.coroutine + def list_kernels(self, **kwargs): + """Get a list of kernels.""" + kernel_url = self._get_kernel_endpoint_url() + self.log.debug("Request list kernels: %s", kernel_url) + response = yield gateway_request(kernel_url, method='GET') + kernels = json_decode(response.body) + self._kernels = {x['id']:x for x in kernels} + raise gen.Return(kernels) + + @gen.coroutine + def shutdown_kernel(self, kernel_id, now=False, restart=False): + """Shutdown a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to shutdown. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request shutdown kernel at: %s", kernel_url) + response = yield gateway_request(kernel_url, method='DELETE') + self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason) + self.remove_kernel(kernel_id) + + @gen.coroutine + def restart_kernel(self, kernel_id, now=False, **kwargs): + """Restart a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to restart. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart' + self.log.debug("Request restart kernel at: %s", kernel_url) + response = yield gateway_request(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Restart kernel response: %d %s", response.code, response.reason) + + @gen.coroutine + def interrupt_kernel(self, kernel_id, **kwargs): + """Interrupt a kernel by its kernel uuid. + + Parameters + ========== + kernel_id : uuid + The id of the kernel to interrupt. + """ + kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt' + self.log.debug("Request interrupt kernel at: %s", kernel_url) + response = yield gateway_request(kernel_url, method='POST', body=json_encode({})) + self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason) + + def shutdown_all(self, now=False): + """Shutdown all kernels.""" + # Note: We have to make this sync because the NotebookApp does not wait for async. + shutdown_kernels = [] + kwargs = {'method': 'DELETE'} + kwargs = GatewayClient.instance().load_connection_args(**kwargs) + client = HTTPClient() + for kernel_id in self._kernels.keys(): + kernel_url = self._get_kernel_endpoint_url(kernel_id) + self.log.debug("Request delete kernel at: %s", kernel_url) + try: + response = client.fetch(kernel_url, **kwargs) + except HTTPError: + pass + else: + self.log.debug("Delete kernel response: %d %s", response.code, response.reason) + shutdown_kernels.append(kernel_id) # avoid changing dict size during iteration + client.close() + for kernel_id in shutdown_kernels: + self.remove_kernel(kernel_id) + + +class GatewayKernelSpecManager(KernelSpecManager): + + def __init__(self, **kwargs): + super(GatewayKernelSpecManager, self).__init__(**kwargs) + self.base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernelspecs_endpoint) + + def _get_kernelspecs_endpoint_url(self, kernel_name=None): + """Builds a url for the kernels endpoint + + Parameters + ---------- + kernel_name: kernel name (optional) + """ + if kernel_name: + return url_path_join(self.base_endpoint, url_escape(kernel_name)) + + return self.base_endpoint + + @gen.coroutine + def get_all_specs(self): + fetched_kspecs = yield self.list_kernel_specs() + + # get the default kernel name and compare to that of this server. + # If different log a warning and reset the default. However, the + # caller of this method will still return this server's value until + # the next fetch of kernelspecs - at which time they'll match. + km = self.parent.kernel_manager + remote_default_kernel_name = fetched_kspecs.get('default') + if remote_default_kernel_name != km.default_kernel_name: + self.log.info("Default kernel name on Gateway server ({gateway_default}) differs from " + "Notebook server ({notebook_default}). Updating to Gateway server's value.". + format(gateway_default=remote_default_kernel_name, + notebook_default=km.default_kernel_name)) + km.default_kernel_name = remote_default_kernel_name + + # gateway doesn't support resources (requires transfer for use by NB client) + # so add `resource_dir` to each kernelspec and value of 'not supported in gateway mode' + remote_kspecs = fetched_kspecs.get('kernelspecs') + for kernel_name, kspec_info in remote_kspecs.items(): + if not kspec_info.get('resource_dir'): + kspec_info['resource_dir'] = 'not supported in gateway mode' + remote_kspecs[kernel_name].update(kspec_info) + + raise gen.Return(remote_kspecs) + + @gen.coroutine + def list_kernel_specs(self): + """Get a list of kernel specs.""" + kernel_spec_url = self._get_kernelspecs_endpoint_url() + self.log.debug("Request list kernel specs at: %s", kernel_spec_url) + response = yield gateway_request(kernel_spec_url, method='GET') + kernel_specs = json_decode(response.body) + raise gen.Return(kernel_specs) + + @gen.coroutine + def get_kernel_spec(self, kernel_name, **kwargs): + """Get kernel spec for kernel_name. + + Parameters + ---------- + kernel_name : str + The name of the kernel. + """ + kernel_spec_url = self._get_kernelspecs_endpoint_url(kernel_name=str(kernel_name)) + self.log.debug("Request kernel spec at: %s" % kernel_spec_url) + try: + response = yield gateway_request(kernel_spec_url, method='GET') + except HTTPError as error: + if error.code == 404: + # Convert not found to KeyError since that's what the Notebook handler expects + # message is not used, but might as well make it useful for troubleshooting + raise KeyError('kernelspec {kernel_name} not found on Gateway server at: {gateway_url}'. + format(kernel_name=kernel_name, gateway_url=GatewayClient.instance().url)) + else: + raise + else: + kernel_spec = json_decode(response.body) + # Convert to instance of Kernelspec + kspec_instance = self.kernel_spec_class(resource_dir=u'', **kernel_spec['spec']) + raise gen.Return(kspec_instance) + + +class GatewaySessionManager(SessionManager): + kernel_manager = Instance('notebook.gateway.managers.GatewayKernelManager') + + @gen.coroutine + def kernel_culled(self, kernel_id): + """Checks if the kernel is still considered alive and returns true if its not found. """ + kernel = yield self.kernel_manager.get_kernel(kernel_id) + raise gen.Return(kernel is None) diff --git a/notebook/notebookapp.py b/notebook/notebookapp.py index 05e44cf29f..2639b4faa8 100755 --- a/notebook/notebookapp.py +++ b/notebook/notebookapp.py @@ -84,6 +84,7 @@ from .services.contents.filemanager import FileContentsManager from .services.contents.largefilemanager import LargeFileManager from .services.sessions.sessionmanager import SessionManager +from .gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient from .auth.login import LoginHandler from .auth.logout import LogoutHandler @@ -96,7 +97,7 @@ ) from jupyter_core.paths import jupyter_config_path from jupyter_client import KernelManager -from jupyter_client.kernelspec import KernelSpecManager, NoSuchKernel, NATIVE_KERNEL_NAME +from jupyter_client.kernelspec import KernelSpecManager from jupyter_client.session import Session from nbformat.sign import NotebookNotary from traitlets import ( @@ -144,6 +145,7 @@ def load_handlers(name): # The Tornado web application #----------------------------------------------------------------------------- + class NotebookWebApplication(web.Application): def __init__(self, jupyter_app, kernel_manager, contents_manager, @@ -151,7 +153,6 @@ def __init__(self, jupyter_app, kernel_manager, contents_manager, config_manager, extra_services, log, base_url, default_url, settings_overrides, jinja_env_options): - settings = self.init_settings( jupyter_app, kernel_manager, contents_manager, session_manager, kernel_spec_manager, config_manager, @@ -305,15 +306,27 @@ def init_handlers(self, settings): handlers.extend(load_handlers('notebook.edit.handlers')) handlers.extend(load_handlers('notebook.services.api.handlers')) handlers.extend(load_handlers('notebook.services.config.handlers')) - handlers.extend(load_handlers('notebook.services.kernels.handlers')) handlers.extend(load_handlers('notebook.services.contents.handlers')) handlers.extend(load_handlers('notebook.services.sessions.handlers')) handlers.extend(load_handlers('notebook.services.nbconvert.handlers')) - handlers.extend(load_handlers('notebook.services.kernelspecs.handlers')) handlers.extend(load_handlers('notebook.services.security.handlers')) handlers.extend(load_handlers('notebook.services.shutdown')) + handlers.extend(load_handlers('notebook.services.kernels.handlers')) + handlers.extend(load_handlers('notebook.services.kernelspecs.handlers')) + handlers.extend(settings['contents_manager'].get_extra_handlers()) + # If gateway mode is enabled, replace appropriate handlers to perform redirection + if GatewayClient.instance().gateway_enabled: + # for each handler required for gateway, locate its pattern + # in the current list and replace that entry... + gateway_handlers = load_handlers('notebook.gateway.handlers') + for i, gwh in enumerate(gateway_handlers): + for j, h in enumerate(handlers): + if gwh[0] == h[0]: + handlers[j] = (gwh[0], gwh[1]) + break + handlers.append( (r"/nbextensions/(.*)", FileFindHandler, { 'path': settings['nbextensions_path'], @@ -547,6 +560,7 @@ def start(self): 'notebook-dir': 'NotebookApp.notebook_dir', 'browser': 'NotebookApp.browser', 'pylab': 'NotebookApp.pylab', + 'gateway-url': 'GatewayClient.url', }) #----------------------------------------------------------------------------- @@ -565,9 +579,9 @@ class NotebookApp(JupyterApp): flags = flags classes = [ - KernelManager, Session, MappingKernelManager, + KernelManager, Session, MappingKernelManager, KernelSpecManager, ContentsManager, FileContentsManager, NotebookNotary, - KernelSpecManager, + GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient, ] flags = Dict(flags) aliases = Dict(aliases) @@ -1316,6 +1330,16 @@ def parse_command_line(self, argv=None): self.update_config(c) def init_configurables(self): + + # If gateway server is configured, replace appropriate managers to perform redirection. To make + # this determination, instantiate the GatewayClient config singleton. + self.gateway_config = GatewayClient.instance(parent=self) + + if self.gateway_config.gateway_enabled: + self.kernel_manager_class = 'notebook.gateway.managers.GatewayKernelManager' + self.session_manager_class = 'notebook.gateway.managers.GatewaySessionManager' + self.kernel_spec_manager_class = 'notebook.gateway.managers.GatewayKernelSpecManager' + self.kernel_spec_manager = self.kernel_spec_manager_class( parent=self, ) @@ -1661,6 +1685,8 @@ def notebook_info(self, kernel_count=True): info += "\n" # Format the info so that the URL fits on a single line in 80 char display info += _("The Jupyter Notebook is running at:\n%s") % self.display_url + if self.gateway_config.gateway_enabled: + info += _("\nKernels will be managed by the Gateway server running at:\n%s") % self.gateway_config.url return info def server_info(self): diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py index cfef2a4a0e..897fa51db2 100644 --- a/notebook/services/kernels/handlers.py +++ b/notebook/services/kernels/handlers.py @@ -45,7 +45,7 @@ def post(self): model.setdefault('name', km.default_kernel_name) kernel_id = yield gen.maybe_future(km.start_kernel(kernel_name=model['name'])) - model = km.kernel_model(kernel_id) + model = yield gen.maybe_future(km.kernel_model(kernel_id)) location = url_path_join(self.base_url, 'api', 'kernels', url_escape(kernel_id)) self.set_header('Location', location) self.set_status(201) @@ -57,7 +57,6 @@ class KernelHandler(APIHandler): @web.authenticated def get(self, kernel_id): km = self.kernel_manager - km._check_kernel_id(kernel_id) model = km.kernel_model(kernel_id) self.finish(json.dumps(model, default=date_default)) @@ -87,7 +86,7 @@ def post(self, kernel_id, action): self.log.error("Exception restarting kernel", exc_info=True) self.set_status(500) else: - model = km.kernel_model(kernel_id) + model = yield gen.maybe_future(km.kernel_model(kernel_id)) self.write(json.dumps(model, default=date_default)) self.finish() diff --git a/notebook/services/kernelspecs/handlers.py b/notebook/services/kernelspecs/handlers.py index d272db2f71..c0157e4c57 100644 --- a/notebook/services/kernelspecs/handlers.py +++ b/notebook/services/kernelspecs/handlers.py @@ -11,7 +11,7 @@ import os pjoin = os.path.join -from tornado import web +from tornado import web, gen from ...base.handlers import APIHandler from ...utils import url_path_join, url_unescape @@ -48,13 +48,15 @@ def kernelspec_model(handler, name, spec_dict, resource_dir): class MainKernelSpecHandler(APIHandler): @web.authenticated + @gen.coroutine def get(self): ksm = self.kernel_spec_manager km = self.kernel_manager model = {} model['default'] = km.default_kernel_name model['kernelspecs'] = specs = {} - for kernel_name, kernel_info in ksm.get_all_specs().items(): + kspecs = yield gen.maybe_future(ksm.get_all_specs()) + for kernel_name, kernel_info in kspecs.items(): try: d = kernelspec_model(self, kernel_name, kernel_info['spec'], kernel_info['resource_dir']) @@ -69,11 +71,12 @@ def get(self): class KernelSpecHandler(APIHandler): @web.authenticated + @gen.coroutine def get(self, kernel_name): ksm = self.kernel_spec_manager kernel_name = url_unescape(kernel_name) try: - spec = ksm.get_kernel_spec(kernel_name) + spec = yield gen.maybe_future(ksm.get_kernel_spec(kernel_name)) except KeyError: raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name) model = kernelspec_model(self, kernel_name, spec.to_dict(), spec.resource_dir) diff --git a/notebook/services/sessions/sessionmanager.py b/notebook/services/sessions/sessionmanager.py index ee70eb0810..4497cfbc33 100644 --- a/notebook/services/sessions/sessionmanager.py +++ b/notebook/services/sessions/sessionmanager.py @@ -56,21 +56,22 @@ def __del__(self): """Close connection once SessionManager closes""" self.close() + @gen.coroutine def session_exists(self, path): """Check to see if the session of a given name exists""" + exists = False self.cursor.execute("SELECT * FROM session WHERE path=?", (path,)) row = self.cursor.fetchone() - if row is None: - return False - else: + if row is not None: # Note, although we found a row for the session, the associated kernel may have # been culled or died unexpectedly. If that's the case, we should delete the # row, thereby terminating the session. This can be done via a call to # row_to_model that tolerates that condition. If row_to_model returns None, # we'll return false, since, at that point, the session doesn't exist anyway. - if self.row_to_model(row, tolerate_culled=True) is None: - return False - return True + model = yield gen.maybe_future(self.row_to_model(row, tolerate_culled=True)) + if model is not None: + exists = True + raise gen.Return(exists) def new_session_id(self): "Create a uuid for a new session" @@ -101,6 +102,7 @@ def start_kernel_for_session(self, session_id, path, name, type, kernel_name): # py2-compat raise gen.Return(kernel_id) + @gen.coroutine def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None): """Saves the items for the session with the given session_id @@ -129,8 +131,10 @@ def save_session(self, session_id, path=None, name=None, type=None, kernel_id=No self.cursor.execute("INSERT INTO session VALUES (?,?,?,?,?)", (session_id, path, name, type, kernel_id) ) - return self.get_session(session_id=session_id) + result = yield gen.maybe_future(self.get_session(session_id=session_id)) + raise gen.Return(result) + @gen.coroutine def get_session(self, **kwargs): """Returns the model for a particular session. @@ -174,8 +178,10 @@ def get_session(self, **kwargs): raise web.HTTPError(404, u'Session not found: %s' % (', '.join(q))) - return self.row_to_model(row) + model = yield gen.maybe_future(self.row_to_model(row)) + raise gen.Return(model) + @gen.coroutine def update_session(self, session_id, **kwargs): """Updates the values in the session database. @@ -191,7 +197,7 @@ def update_session(self, session_id, **kwargs): and the value replaces the current value in the session with session_id. """ - self.get_session(session_id=session_id) + yield gen.maybe_future(self.get_session(session_id=session_id)) if not kwargs: # no changes @@ -205,9 +211,15 @@ def update_session(self, session_id, **kwargs): query = "UPDATE session SET %s WHERE session_id=?" % (', '.join(sets)) self.cursor.execute(query, list(kwargs.values()) + [session_id]) + def kernel_culled(self, kernel_id): + """Checks if the kernel is still considered alive and returns true if its not found. """ + return kernel_id not in self.kernel_manager + + @gen.coroutine def row_to_model(self, row, tolerate_culled=False): """Takes sqlite database session row and turns it into a dictionary""" - if row['kernel_id'] not in self.kernel_manager: + kernel_culled = yield gen.maybe_future(self.kernel_culled(row['kernel_id'])) + if kernel_culled: # The kernel was culled or died without deleting the session. # We can't use delete_session here because that tries to find # and shut down the kernel - so we'll delete the row directly. @@ -222,21 +234,23 @@ def row_to_model(self, row, tolerate_culled=False): format(kernel_id=row['kernel_id'],session_id=row['session_id']) if tolerate_culled: self.log.warning(msg + " Continuing...") - return None + raise gen.Return(None) raise KeyError(msg) + kernel_model = yield gen.maybe_future(self.kernel_manager.kernel_model(row['kernel_id'])) model = { 'id': row['session_id'], 'path': row['path'], 'name': row['name'], 'type': row['type'], - 'kernel': self.kernel_manager.kernel_model(row['kernel_id']) + 'kernel': kernel_model } if row['type'] == 'notebook': # Provide the deprecated API. model['notebook'] = {'path': row['path'], 'name': row['name']} - return model + raise gen.Return(model) + @gen.coroutine def list_sessions(self): """Returns a list of dictionaries containing all the information from the session database""" @@ -246,14 +260,15 @@ def list_sessions(self): # which messes up the cursor if we're iterating over rows. for row in c.fetchall(): try: - result.append(self.row_to_model(row)) + model = yield gen.maybe_future(self.row_to_model(row)) + result.append(model) except KeyError: pass - return result + raise gen.Return(result) @gen.coroutine def delete_session(self, session_id): """Deletes the row in the session database with given session_id""" - session = self.get_session(session_id=session_id) + session = yield gen.maybe_future(self.get_session(session_id=session_id)) yield gen.maybe_future(self.kernel_manager.shutdown_kernel(session['kernel']['id'])) self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,)) diff --git a/notebook/services/sessions/tests/test_sessionmanager.py b/notebook/services/sessions/tests/test_sessionmanager.py index 96847a868a..97331ebf9b 100644 --- a/notebook/services/sessions/tests/test_sessionmanager.py +++ b/notebook/services/sessions/tests/test_sessionmanager.py @@ -62,11 +62,11 @@ def co_add(): def create_session(self, **kwargs): return self.create_sessions(kwargs)[0] - + def test_get_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='bar')['id'] - model = sm.get_session(session_id=session_id) + model = self.loop.run_sync(lambda: sm.get_session(session_id=session_id)) expected = {'id':session_id, 'path': u'/path/to/test.ipynb', 'notebook': {'path': u'/path/to/test.ipynb', 'name': None}, @@ -86,7 +86,8 @@ def test_bad_get_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='foo')['id'] - self.assertRaises(TypeError, sm.get_session, bad_id=session_id) # Bad keyword + with self.assertRaises(TypeError): + self.loop.run_sync(lambda: sm.get_session(bad_id=session_id)) # Bad keyword def test_get_session_dead_kernel(self): sm = self.sm @@ -94,9 +95,9 @@ def test_get_session_dead_kernel(self): # kill the kernel sm.kernel_manager.shutdown_kernel(session['kernel']['id']) with self.assertRaises(KeyError): - sm.get_session(session_id=session['id']) + self.loop.run_sync(lambda: sm.get_session(session_id=session['id'])) # no sessions left - listed = sm.list_sessions() + listed = self.loop.run_sync(lambda: sm.list_sessions()) self.assertEqual(listed, []) def test_list_sessions(self): @@ -107,7 +108,7 @@ def test_list_sessions(self): dict(path='/path/to/3', name='foo', type='console', kernel_name='python'), ) - sessions = sm.list_sessions() + sessions = self.loop.run_sync(lambda: sm.list_sessions()) expected = [ { 'id':sessions[0]['id'], @@ -158,7 +159,7 @@ def test_list_sessions_dead_kernel(self): ) # kill one of the kernels sm.kernel_manager.shutdown_kernel(sessions[0]['kernel']['id']) - listed = sm.list_sessions() + listed = self.loop.run_sync(lambda: sm.list_sessions()) expected = [ { 'id': sessions[1]['id'], @@ -181,8 +182,8 @@ def test_update_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='julia')['id'] - sm.update_session(session_id, path='/path/to/new_name.ipynb') - model = sm.get_session(session_id=session_id) + self.loop.run_sync(lambda: sm.update_session(session_id, path='/path/to/new_name.ipynb')) + model = self.loop.run_sync(lambda: sm.get_session(session_id=session_id)) expected = {'id':session_id, 'path': u'/path/to/new_name.ipynb', 'type': 'notebook', @@ -203,7 +204,8 @@ def test_bad_update_session(self): sm = self.sm session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='ir')['id'] - self.assertRaises(TypeError, sm.update_session, session_id=session_id, bad_kw='test.ipynb') # Bad keyword + with self.assertRaises(TypeError): + self.loop.run_sync(lambda: sm.update_session(session_id=session_id, bad_kw='test.ipynb')) # Bad keyword def test_delete_session(self): sm = self.sm @@ -212,8 +214,8 @@ def test_delete_session(self): dict(path='/path/to/2/test2.ipynb', kernel_name='python'), dict(path='/path/to/3', name='foo', type='console', kernel_name='python'), ) - sm.delete_session(sessions[1]['id']) - new_sessions = sm.list_sessions() + self.loop.run_sync(lambda: sm.delete_session(sessions[1]['id'])) + new_sessions = self.loop.run_sync(lambda: sm.list_sessions()) expected = [{ 'id': sessions[0]['id'], 'path': u'/path/to/1/test1.ipynb', diff --git a/notebook/tests/launchnotebook.py b/notebook/tests/launchnotebook.py index 1b685df0ca..9e84a5964b 100644 --- a/notebook/tests/launchnotebook.py +++ b/notebook/tests/launchnotebook.py @@ -91,6 +91,22 @@ def request(cls, verb, path, **kwargs): url_path_join(cls.base_url(), path), **kwargs) return response + + @classmethod + def get_patch_env(cls): + return { + 'HOME': cls.home_dir, + 'PYTHONPATH': os.pathsep.join(sys.path), + 'IPYTHONDIR': pjoin(cls.home_dir, '.ipython'), + 'JUPYTER_NO_CONFIG': '1', # needed in the future + 'JUPYTER_CONFIG_DIR' : cls.config_dir, + 'JUPYTER_DATA_DIR' : cls.data_dir, + 'JUPYTER_RUNTIME_DIR': cls.runtime_dir, + } + + @classmethod + def get_argv(cls): + return [] @classmethod def setup_class(cls): @@ -109,15 +125,7 @@ def tmp(*parts): config_dir = cls.config_dir = tmp('config') runtime_dir = cls.runtime_dir = tmp('runtime') cls.notebook_dir = tmp('notebooks') - cls.env_patch = patch.dict('os.environ', { - 'HOME': cls.home_dir, - 'PYTHONPATH': os.pathsep.join(sys.path), - 'IPYTHONDIR': pjoin(cls.home_dir, '.ipython'), - 'JUPYTER_NO_CONFIG': '1', # needed in the future - 'JUPYTER_CONFIG_DIR' : config_dir, - 'JUPYTER_DATA_DIR' : data_dir, - 'JUPYTER_RUNTIME_DIR': runtime_dir, - }) + cls.env_patch = patch.dict('os.environ', cls.get_patch_env()) cls.env_patch.start() cls.path_patch = patch.multiple( jupyter_core.paths, @@ -157,7 +165,7 @@ def start_thread(): # needs to be redone after initialize, which reconfigures logging app.log.propagate = True app.log.handlers = [] - app.initialize(argv=[]) + app.initialize(argv=cls.get_argv()) app.log.propagate = True app.log.handlers = [] loop = IOLoop.current() diff --git a/notebook/tests/test_gateway.py b/notebook/tests/test_gateway.py new file mode 100644 index 0000000000..ef3cd7ef56 --- /dev/null +++ b/notebook/tests/test_gateway.py @@ -0,0 +1,354 @@ +"""Test GatewayClient""" +import os +import json +import uuid +from datetime import datetime +from tornado import gen +from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError +from traitlets.config import Config +from .launchnotebook import NotebookTestBase +from notebook.gateway.managers import GatewayClient + +try: + from unittest.mock import patch, Mock +except ImportError: + from mock import patch, Mock # py2 + +try: + from io import StringIO +except ImportError: + import StringIO + +import nose.tools as nt + + +def generate_kernelspec(name): + argv_stanza = ['python', '-m', 'ipykernel_launcher', '-f', '{connection_file}'] + spec_stanza = {'spec': {'argv': argv_stanza, 'env': {}, 'display_name': name, 'language': 'python', 'interrupt_mode': 'signal', 'metadata': {}}} + kernelspec_stanza = {'name': name, 'spec': spec_stanza, 'resources': {}} + return kernelspec_stanza + + +# We'll mock up two kernelspecs - kspec_foo and kspec_bar +kernelspecs = {'default': 'kspec_foo', 'kernelspecs': {'kspec_foo': generate_kernelspec('kspec_foo'), 'kspec_bar': generate_kernelspec('kspec_bar')}} + + +# maintain a dictionary of expected running kernels. Key = kernel_id, Value = model. +running_kernels = dict() + + +def generate_model(name): + """Generate a mocked kernel model. Caller is responsible for adding model to running_kernels dictionary.""" + dt = datetime.utcnow().isoformat() + 'Z' + kernel_id = str(uuid.uuid4()) + model = {'id': kernel_id, 'name': name, 'last_activity': str(dt), 'execution_state': 'idle', 'connections': 1} + return model + + +@gen.coroutine +def mock_gateway_request(url, **kwargs): + method = 'GET' + if kwargs['method']: + method = kwargs['method'] + + request = HTTPRequest(url=url, **kwargs) + + endpoint = str(url) + + # Fetch all kernelspecs + if endpoint.endswith('/api/kernelspecs') and method == 'GET': + response_buf = StringIO(json.dumps(kernelspecs)) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + + # Fetch named kernelspec + if endpoint.rfind('/api/kernelspecs/') >= 0 and method == 'GET': + requested_kernelspec = endpoint.rpartition('/')[2] + kspecs = kernelspecs.get('kernelspecs') + if requested_kernelspec in kspecs: + response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec))) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernelspec does not exist: %s' % requested_kernelspec) + + # Create kernel + if endpoint.endswith('/api/kernels') and method == 'POST': + json_body = json.loads(kwargs['body']) + name = json_body.get('name') + env = json_body.get('env') + kspec_name = env.get('KERNEL_KSPEC_NAME') + nt.assert_equal(name, kspec_name) # Ensure that KERNEL_ env values get propagated + model = generate_model(name) + running_kernels[model.get('id')] = model # Register model as a running kernel + response_buf = StringIO(json.dumps(model)) + response = yield gen.maybe_future(HTTPResponse(request, 201, buffer=response_buf)) + raise gen.Return(response) + + # Fetch list of running kernels + if endpoint.endswith('/api/kernels') and method == 'GET': + kernels = [] + for kernel_id in running_kernels.keys(): + model = running_kernels.get(kernel_id) + kernels.append(model) + response_buf = StringIO(json.dumps(kernels)) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + + # Interrupt or restart existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'POST': + requested_kernel_id, sep, action = endpoint.rpartition('/api/kernels/')[2].rpartition('/') + + if action == 'interrupt': + if requested_kernel_id in running_kernels: + response = yield gen.maybe_future(HTTPResponse(request, 204)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + elif action == 'restart': + if requested_kernel_id in running_kernels: + response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response = yield gen.maybe_future(HTTPResponse(request, 204, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + else: + raise HTTPError(404, message='Bad action detected: %s' % action) + + # Shutdown existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'DELETE': + requested_kernel_id = endpoint.rpartition('/')[2] + running_kernels.pop(requested_kernel_id) # Simulate shutdown by removing kernel from running set + response = yield gen.maybe_future(HTTPResponse(request, 204)) + raise gen.Return(response) + + # Fetch existing kernel + if endpoint.rfind('/api/kernels/') >= 0 and method == 'GET': + requested_kernel_id = endpoint.rpartition('/')[2] + if requested_kernel_id in running_kernels: + response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf)) + raise gen.Return(response) + else: + raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id) + + +mocked_gateway = patch('notebook.gateway.managers.gateway_request', mock_gateway_request) + + +class TestGateway(NotebookTestBase): + + mock_gateway_url = 'http://mock-gateway-server:8889' + mock_http_user = 'alice' + + @classmethod + def setup_class(cls): + GatewayClient.clear_instance() + super(TestGateway, cls).setup_class() + + @classmethod + def teardown_class(cls): + GatewayClient.clear_instance() + super(TestGateway, cls).teardown_class() + + @classmethod + def get_patch_env(cls): + test_env = super(TestGateway, cls).get_patch_env() + test_env.update({'JUPYTER_GATEWAY_URL': TestGateway.mock_gateway_url, + 'JUPYTER_GATEWAY_REQUEST_TIMEOUT': '44.4'}) + return test_env + + @classmethod + def get_argv(cls): + argv = super(TestGateway, cls).get_argv() + argv.extend(['--GatewayClient.connect_timeout=44.4', '--GatewayClient.http_user=' + TestGateway.mock_http_user]) + return argv + + def test_gateway_options(self): + nt.assert_equal(self.notebook.gateway_config.gateway_enabled, True) + nt.assert_equal(self.notebook.gateway_config.url, TestGateway.mock_gateway_url) + nt.assert_equal(self.notebook.gateway_config.http_user, TestGateway.mock_http_user) + nt.assert_equal(self.notebook.gateway_config.connect_timeout, self.notebook.gateway_config.connect_timeout) + nt.assert_equal(self.notebook.gateway_config.connect_timeout, 44.4) + + def test_gateway_class_mappings(self): + # Ensure appropriate class mappings are in place. + nt.assert_equal(self.notebook.kernel_manager_class.__name__, 'GatewayKernelManager') + nt.assert_equal(self.notebook.session_manager_class.__name__, 'GatewaySessionManager') + nt.assert_equal(self.notebook.kernel_spec_manager_class.__name__, 'GatewayKernelSpecManager') + + def test_gateway_get_kernelspecs(self): + # Validate that kernelspecs come from gateway. + with mocked_gateway: + response = self.request('GET', '/api/kernelspecs') + self.assertEqual(response.status_code, 200) + content = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kspecs = content.get('kernelspecs') + self.assertEqual(len(kspecs), 2) + self.assertEqual(kspecs.get('kspec_bar').get('name'), 'kspec_bar') + + def test_gateway_get_named_kernelspec(self): + # Validate that a specific kernelspec can be retrieved from gateway. + with mocked_gateway: + response = self.request('GET', '/api/kernelspecs/kspec_foo') + self.assertEqual(response.status_code, 200) + kspec_foo = json.loads(response.content.decode('utf-8'), encoding='utf-8') + self.assertEqual(kspec_foo.get('name'), 'kspec_foo') + + response = self.request('GET', '/api/kernelspecs/no_such_spec') + self.assertEqual(response.status_code, 404) + + def test_gateway_session_lifecycle(self): + # Validate session lifecycle functions; create and delete. + + # create + session_id, kernel_id = self.create_session('kspec_foo') + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # interrupt + self.interrupt_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # restart + self.restart_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # delete + self.delete_session(session_id) + self.assertFalse(self.is_kernel_running(kernel_id)) + + def test_gateway_kernel_lifecycle(self): + # Validate kernel lifecycle functions; create, interrupt, restart and delete. + + # create + kernel_id = self.create_kernel('kspec_bar') + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # interrupt + self.interrupt_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # restart + self.restart_kernel(kernel_id) + + # ensure kernel still considered running + self.assertTrue(self.is_kernel_running(kernel_id)) + + # delete + self.delete_kernel(kernel_id) + self.assertFalse(self.is_kernel_running(kernel_id)) + + def create_session(self, kernel_name): + """Creates a session for a kernel. The session is created against the notebook server + which then uses the gateway for kernel management. + """ + with mocked_gateway: + nb_path = os.path.join(self.notebook_dir, 'testgw.ipynb') + kwargs = dict() + kwargs['json'] = {'path': nb_path, 'type': 'notebook', 'kernel': {'name': kernel_name}} + + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + os.environ['KERNEL_KSPEC_NAME'] = kernel_name + + # Create the kernel... (also tests get_kernel) + response = self.request('POST', '/api/sessions', **kwargs) + self.assertEqual(response.status_code, 201) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + self.assertEqual(model.get('path'), nb_path) + kernel_id = model.get('kernel').get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + self.assertEqual(kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('kernel').get('name'), running_kernel.get('name')) + session_id = model.get('id') + + # restore env + os.environ.pop('KERNEL_KSPEC_NAME') + return session_id, kernel_id + + def delete_session(self, session_id): + """Deletes a session corresponding to the given session id. + """ + with mocked_gateway: + # Delete the session (and kernel) + response = self.request('DELETE', '/api/sessions/' + session_id) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') + + def is_kernel_running(self, kernel_id): + """Issues request to get the set of running kernels + """ + with mocked_gateway: + # Get list of running kernels + response = self.request('GET', '/api/kernels') + self.assertEqual(response.status_code, 200) + kernels = json.loads(response.content.decode('utf-8'), encoding='utf-8') + self.assertEqual(len(kernels), len(running_kernels)) + for model in kernels: + if model.get('id') == kernel_id: + return True + return False + + def create_kernel(self, kernel_name): + """Issues request to retart the given kernel + """ + with mocked_gateway: + kwargs = dict() + kwargs['json'] = {'name': kernel_name} + + # add a KERNEL_ value to the current env and we'll ensure that that value exists in the mocked method + os.environ['KERNEL_KSPEC_NAME'] = kernel_name + + response = self.request('POST', '/api/kernels', **kwargs) + self.assertEqual(response.status_code, 201) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + kernel_id = model.get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(kernel_id) + self.assertEqual(kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('name'), kernel_name) + + # restore env + os.environ.pop('KERNEL_KSPEC_NAME') + return kernel_id + + def interrupt_kernel(self, kernel_id): + """Issues request to interrupt the given kernel + """ + with mocked_gateway: + response = self.request('POST', '/api/kernels/' + kernel_id + '/interrupt') + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') + + def restart_kernel(self, kernel_id): + """Issues request to retart the given kernel + """ + with mocked_gateway: + response = self.request('POST', '/api/kernels/' + kernel_id + '/restart') + self.assertEqual(response.status_code, 200) + model = json.loads(response.content.decode('utf-8'), encoding='utf-8') + restarted_kernel_id = model.get('id') + # ensure its in the running_kernels and name matches. + running_kernel = running_kernels.get(restarted_kernel_id) + self.assertEqual(restarted_kernel_id, running_kernel.get('id')) + self.assertEqual(model.get('name'), running_kernel.get('name')) + + def delete_kernel(self, kernel_id): + """Deletes kernel corresponding to the given kernel id. + """ + with mocked_gateway: + # Delete the session (and kernel) + response = self.request('DELETE', '/api/kernels/' + kernel_id) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.reason, 'No Content') +