diff --git a/docs/source/public_server.rst b/docs/source/public_server.rst
index 3796a2a4fb..edadbe3ffc 100644
--- a/docs/source/public_server.rst
+++ b/docs/source/public_server.rst
@@ -343,6 +343,35 @@ single-tab mode:
});
+Using a gateway server for kernel management
+--------------------------------------------
+
+You can now redirect the management of your kernels to a Gateway server
+(i.e., `Jupyter Kernel Gateway <https://jupyter-kernel-gateway.readthedocs.io/en/latest/>`_ or
+`Jupyter Enterprise Gateway <https://jupyter-enterprise-gateway.readthedocs.io/en/latest/>`_)
+simply by specifying the Gateway url via the following command-line option:
+
+.. code-block:: bash
+
+    $ jupyter notebook --gateway-url=http://my-gateway-server:8888
+
+or via the environment variable:
+
+.. code-block:: bash
+
+    JUPYTER_GATEWAY_URL=http://my-gateway-server:8888
+
+or in :file:`jupyter_notebook_config.py`:
+
+.. code-block:: python
+
+    c.GatewayClient.url = 'http://my-gateway-server:8888'
+
+When provided, all kernel specifications will be retrieved from the specified Gateway server and all
+kernels will be managed by that server. This option lets kernel processes run against managed clusters
+while the management of the notebooks themselves remains local to the Notebook server.
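+
+The connection to the Gateway can be tuned further via additional ``GatewayClient``
+options, each of which can also be set through a corresponding ``JUPYTER_GATEWAY_*``
+environment variable. As an illustrative sketch (the values below, including the token
+and the whitelisted variable names, are placeholders rather than defaults), the same
+config file might also contain:
+
+.. code-block:: python
+
+    c.GatewayClient.connect_timeout = 60.0
+    c.GatewayClient.request_timeout = 60.0
+    c.GatewayClient.auth_token = 'my-gateway-token'
+    c.GatewayClient.validate_cert = False
+    c.GatewayClient.env_whitelist = 'SPARK_HOME,HADOOP_CONF_DIR'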
+
Known issues
------------
diff --git a/notebook/gateway/__init__.py b/notebook/gateway/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/notebook/gateway/handlers.py b/notebook/gateway/handlers.py
new file mode 100644
index 0000000000..8e09b10861
--- /dev/null
+++ b/notebook/gateway/handlers.py
@@ -0,0 +1,207 @@
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+import os
+import logging
+
+from ..base.handlers import IPythonHandler
+from ..utils import url_path_join
+
+from tornado import gen, web
+from tornado.concurrent import Future
+from tornado.ioloop import IOLoop
+from tornado.websocket import WebSocketHandler, websocket_connect
+from tornado.httpclient import HTTPRequest
+from tornado.escape import url_escape, json_decode, utf8
+
+from ipython_genutils.py3compat import cast_unicode
+from jupyter_client.session import Session
+from traitlets.config.configurable import LoggingConfigurable
+
+from .managers import GatewayClient
+
+
+class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler):
+
+ session = None
+ gateway = None
+ kernel_id = None
+
+ def set_default_headers(self):
+ """Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets"""
+ pass
+
+ def get_compression_options(self):
+        # enable websocket compression (deflate) with default options
+ return {}
+
+ def authenticate(self):
+ """Run before finishing the GET request
+
+        Extend this method to add logic that should fire before
+        the websocket handshake completes.
+ """
+ # authenticate the request before opening the websocket
+ if self.get_current_user() is None:
+ self.log.warning("Couldn't authenticate WebSocket connection")
+ raise web.HTTPError(403)
+
+ if self.get_argument('session_id', False):
+ self.session.session = cast_unicode(self.get_argument('session_id'))
+ else:
+ self.log.warning("No session ID specified")
+
+ def initialize(self):
+ self.log.debug("Initializing websocket connection %s", self.request.path)
+ self.session = Session(config=self.config)
+ self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)
+
+ @gen.coroutine
+ def get(self, kernel_id, *args, **kwargs):
+ self.authenticate()
+ self.kernel_id = cast_unicode(kernel_id, 'ascii')
+ super(WebSocketChannelsHandler, self).get(kernel_id=kernel_id, *args, **kwargs)
+
+ def open(self, kernel_id, *args, **kwargs):
+ """Handle web socket connection open to notebook server and delegate to gateway web socket handler """
+ self.gateway.on_open(
+ kernel_id=kernel_id,
+ message_callback=self.write_message,
+ compression_options=self.get_compression_options()
+ )
+
+ def on_message(self, message):
+ """Forward message to gateway web socket handler."""
+ self.log.debug("Sending message to gateway: {}".format(message))
+ self.gateway.on_message(message)
+
+ def write_message(self, message, binary=False):
+ """Send message back to notebook client. This is called via callback from self.gateway._read_messages."""
+ self.log.debug("Receiving message from gateway: {}".format(message))
+ if self.ws_connection: # prevent WebSocketClosedError
+ super(WebSocketChannelsHandler, self).write_message(message, binary=binary)
+ elif self.log.isEnabledFor(logging.DEBUG):
+ msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message)))
+ self.log.debug("Notebook client closed websocket connection - message dropped: {}".format(msg_summary))
+
+ def on_close(self):
+ self.log.debug("Closing websocket connection %s", self.request.path)
+ self.gateway.on_close()
+ super(WebSocketChannelsHandler, self).on_close()
+
+ @staticmethod
+ def _get_message_summary(message):
+ summary = []
+ message_type = message['msg_type']
+ summary.append('type: {}'.format(message_type))
+
+ if message_type == 'status':
+ summary.append(', state: {}'.format(message['content']['execution_state']))
+ elif message_type == 'error':
+ summary.append(', {}:{}:{}'.format(message['content']['ename'],
+ message['content']['evalue'],
+ message['content']['traceback']))
+ else:
+ summary.append(', ...') # don't display potentially sensitive data
+
+ return ''.join(summary)
+
+
+class GatewayWebSocketClient(LoggingConfigurable):
+ """Proxy web socket connection to a kernel/enterprise gateway."""
+
+ def __init__(self, **kwargs):
+ super(GatewayWebSocketClient, self).__init__(**kwargs)
+ self.kernel_id = None
+ self.ws = None
+ self.ws_future = Future()
+ self.ws_future_cancelled = False
+
+ @gen.coroutine
+ def _connect(self, kernel_id):
+ self.kernel_id = kernel_id
+ ws_url = url_path_join(
+ GatewayClient.instance().ws_url,
+ GatewayClient.instance().kernels_endpoint, url_escape(kernel_id), 'channels'
+ )
+ self.log.info('Connecting to {}'.format(ws_url))
+ kwargs = {}
+ kwargs = GatewayClient.instance().load_connection_args(**kwargs)
+
+ request = HTTPRequest(ws_url, **kwargs)
+ self.ws_future = websocket_connect(request)
+ self.ws_future.add_done_callback(self._connection_done)
+
+ def _connection_done(self, fut):
+ if not self.ws_future_cancelled: # prevent concurrent.futures._base.CancelledError
+ self.ws = fut.result()
+ self.log.debug("Connection is ready: ws: {}".format(self.ws))
+ else:
+            self.log.warning("Websocket connection was cancelled by client disconnect before it was established. "
+                             "Kernel with ID '{}' may not be terminated on Gateway server: {}".
+                             format(self.kernel_id, GatewayClient.instance().url))
+
+ def _disconnect(self):
+ if self.ws is not None:
+ # Close connection
+ self.ws.close()
+ elif not self.ws_future.done():
+            # Cancel pending connection. Since future.cancel() is a no-op in tornado, we'll track cancellation locally
+ self.ws_future.cancel()
+ self.ws_future_cancelled = True
+ self.log.debug("_disconnect: ws_future_cancelled: {}".format(self.ws_future_cancelled))
+
+ @gen.coroutine
+ def _read_messages(self, callback):
+ """Read messages from gateway server."""
+ while True:
+ message = None
+ if not self.ws_future_cancelled:
+ try:
+ message = yield self.ws.read_message()
+ except Exception as e:
+ self.log.error("Exception reading message from websocket: {}".format(e)) # , exc_info=True)
+ if message is None:
+ break
+ callback(message) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
+ else: # ws cancelled - stop reading
+ break
+
+ def on_open(self, kernel_id, message_callback, **kwargs):
+ """Web socket connection open against gateway server."""
+ self._connect(kernel_id)
+ loop = IOLoop.current()
+ loop.add_future(
+ self.ws_future,
+ lambda future: self._read_messages(message_callback)
+ )
+
+ def on_message(self, message):
+ """Send message to gateway server."""
+ if self.ws is None:
+ loop = IOLoop.current()
+ loop.add_future(
+ self.ws_future,
+ lambda future: self._write_message(message)
+ )
+ else:
+ self._write_message(message)
+
+ def _write_message(self, message):
+ """Send message to gateway server."""
+ try:
+ if not self.ws_future_cancelled:
+ self.ws.write_message(message)
+ except Exception as e:
+ self.log.error("Exception writing message to websocket: {}".format(e)) # , exc_info=True)
+
+ def on_close(self):
+ """Web socket closed event."""
+ self._disconnect()
+
+
+from ..services.kernels.handlers import _kernel_id_regex
+
+default_handlers = [
+ (r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler),
+]
diff --git a/notebook/gateway/managers.py b/notebook/gateway/managers.py
new file mode 100644
index 0000000000..73af7d9799
--- /dev/null
+++ b/notebook/gateway/managers.py
@@ -0,0 +1,555 @@
+# Copyright (c) Jupyter Development Team.
+# Distributed under the terms of the Modified BSD License.
+
+import os
+import json
+
+from socket import gaierror
+from tornado import gen, web
+from tornado.escape import json_encode, json_decode, url_escape
+from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError
+from tornado.simple_httpclient import HTTPTimeoutError
+
+from ..services.kernels.kernelmanager import MappingKernelManager
+from ..services.sessions.sessionmanager import SessionManager
+
+from jupyter_client.kernelspec import KernelSpecManager
+from ..utils import url_path_join
+
+from traitlets import Instance, Unicode, Float, Bool, default, validate, TraitError
+from traitlets.config import SingletonConfigurable
+
+
+class GatewayClient(SingletonConfigurable):
+ """This class manages the configuration. It's its own singleton class so that we
+ can share these values across all objects. It also contains some helper methods
+ to build request arguments out of the various config options.
+
+ """
+
+ url = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The url of the Kernel or Enterprise Gateway server where
+ kernel specifications are defined and kernel management takes place.
+ If defined, this Notebook server acts as a proxy for all kernel
+ management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var)
+ """
+ )
+
+ url_env = 'JUPYTER_GATEWAY_URL'
+ @default('url')
+ def _url_default(self):
+ return os.environ.get(self.url_env)
+
+ @validate('url')
+ def _url_validate(self, proposal):
+ value = proposal['value']
+ # Ensure value, if present, starts with 'http'
+ if value is not None and len(value) > 0:
+ if not str(value).lower().startswith('http'):
+ raise TraitError("GatewayClient url must start with 'http': '%r'" % value)
+ return value
+
+ ws_url = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value
+ will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var)
+ """
+ )
+
+ ws_url_env = 'JUPYTER_GATEWAY_WS_URL'
+ @default('ws_url')
+ def _ws_url_default(self):
+ default_value = os.environ.get(self.ws_url_env)
+ if default_value is None:
+ if self.gateway_enabled:
+ default_value = self.url.lower().replace('http', 'ws')
+ return default_value
+
+ @validate('ws_url')
+ def _ws_url_validate(self, proposal):
+ value = proposal['value']
+ # Ensure value, if present, starts with 'ws'
+ if value is not None and len(value) > 0:
+ if not str(value).lower().startswith('ws'):
+ raise TraitError("GatewayClient ws_url must start with 'ws': '%r'" % value)
+ return value
+
+ kernels_endpoint_default_value = '/api/kernels'
+ kernels_endpoint_env = 'JUPYTER_GATEWAY_KERNELS_ENDPOINT'
+ kernels_endpoint = Unicode(default_value=kernels_endpoint_default_value, config=True,
+ help="""The gateway API endpoint for accessing kernel resources (JUPYTER_GATEWAY_KERNELS_ENDPOINT env var)""")
+
+ @default('kernels_endpoint')
+ def _kernels_endpoint_default(self):
+ return os.environ.get(self.kernels_endpoint_env, self.kernels_endpoint_default_value)
+
+ kernelspecs_endpoint_default_value = '/api/kernelspecs'
+ kernelspecs_endpoint_env = 'JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT'
+ kernelspecs_endpoint = Unicode(default_value=kernelspecs_endpoint_default_value, config=True,
+ help="""The gateway API endpoint for accessing kernelspecs (JUPYTER_GATEWAY_KERNELSPECS_ENDPOINT env var)""")
+
+ @default('kernelspecs_endpoint')
+ def _kernelspecs_endpoint_default(self):
+ return os.environ.get(self.kernelspecs_endpoint_env, self.kernelspecs_endpoint_default_value)
+
+ connect_timeout_default_value = 20.0
+ connect_timeout_env = 'JUPYTER_GATEWAY_CONNECT_TIMEOUT'
+ connect_timeout = Float(default_value=connect_timeout_default_value, config=True,
+ help="""The time allowed for HTTP connection establishment with the Gateway server.
+ (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""")
+
+ @default('connect_timeout')
+ def connect_timeout_default(self):
+        return float(os.environ.get(self.connect_timeout_env, self.connect_timeout_default_value))
+
+ request_timeout_default_value = 20.0
+ request_timeout_env = 'JUPYTER_GATEWAY_REQUEST_TIMEOUT'
+ request_timeout = Float(default_value=request_timeout_default_value, config=True,
+ help="""The time allowed for HTTP request completion. (JUPYTER_GATEWAY_REQUEST_TIMEOUT env var)""")
+
+ @default('request_timeout')
+ def request_timeout_default(self):
+        return float(os.environ.get(self.request_timeout_env, self.request_timeout_default_value))
+
+ client_key = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The filename for client SSL key, if any. (JUPYTER_GATEWAY_CLIENT_KEY env var)
+ """
+ )
+ client_key_env = 'JUPYTER_GATEWAY_CLIENT_KEY'
+
+ @default('client_key')
+ def _client_key_default(self):
+ return os.environ.get(self.client_key_env)
+
+ client_cert = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The filename for client SSL certificate, if any. (JUPYTER_GATEWAY_CLIENT_CERT env var)
+ """
+ )
+ client_cert_env = 'JUPYTER_GATEWAY_CLIENT_CERT'
+
+ @default('client_cert')
+ def _client_cert_default(self):
+ return os.environ.get(self.client_cert_env)
+
+ ca_certs = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The filename of CA certificates or None to use defaults. (JUPYTER_GATEWAY_CA_CERTS env var)
+ """
+ )
+ ca_certs_env = 'JUPYTER_GATEWAY_CA_CERTS'
+
+ @default('ca_certs')
+ def _ca_certs_default(self):
+ return os.environ.get(self.ca_certs_env)
+
+ http_user = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The username for HTTP authentication. (JUPYTER_GATEWAY_HTTP_USER env var)
+ """
+ )
+ http_user_env = 'JUPYTER_GATEWAY_HTTP_USER'
+
+ @default('http_user')
+ def _http_user_default(self):
+ return os.environ.get(self.http_user_env)
+
+ http_pwd = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The password for HTTP authentication. (JUPYTER_GATEWAY_HTTP_PWD env var)
+ """
+ )
+ http_pwd_env = 'JUPYTER_GATEWAY_HTTP_PWD'
+
+ @default('http_pwd')
+ def _http_pwd_default(self):
+ return os.environ.get(self.http_pwd_env)
+
+ headers_default_value = '{}'
+ headers_env = 'JUPYTER_GATEWAY_HEADERS'
+    headers = Unicode(default_value=headers_default_value, allow_none=True, config=True,
+ help="""Additional HTTP headers to pass on the request. This value will be converted to a dict.
+ (JUPYTER_GATEWAY_HEADERS env var)
+ """
+ )
+
+ @default('headers')
+ def _headers_default(self):
+ return os.environ.get(self.headers_env, self.headers_default_value)
+
+ auth_token = Unicode(default_value=None, allow_none=True, config=True,
+ help="""The authorization token used in the HTTP headers. (JUPYTER_GATEWAY_AUTH_TOKEN env var)
+ """
+ )
+ auth_token_env = 'JUPYTER_GATEWAY_AUTH_TOKEN'
+
+ @default('auth_token')
+ def _auth_token_default(self):
+ return os.environ.get(self.auth_token_env)
+
+ validate_cert_default_value = True
+ validate_cert_env = 'JUPYTER_GATEWAY_VALIDATE_CERT'
+ validate_cert = Bool(default_value=validate_cert_default_value, config=True,
+ help="""For HTTPS requests, determines if server's certificate should be validated or not.
+ (JUPYTER_GATEWAY_VALIDATE_CERT env var)"""
+ )
+
+ @default('validate_cert')
+ def validate_cert_default(self):
+ return bool(os.environ.get(self.validate_cert_env, str(self.validate_cert_default_value)) not in ['no', 'false'])
+
+ def __init__(self, **kwargs):
+ super(GatewayClient, self).__init__(**kwargs)
+ self._static_args = {} # initialized on first use
+
+ env_whitelist_default_value = ''
+ env_whitelist_env = 'JUPYTER_GATEWAY_ENV_WHITELIST'
+ env_whitelist = Unicode(default_value=env_whitelist_default_value, config=True,
+ help="""A comma-separated list of environment variable names that will be included, along with
+ their values, in the kernel startup request. The corresponding `env_whitelist` configuration
+ value must also be set on the Gateway server - since that configuration value indicates which
+ environmental values to make available to the kernel. (JUPYTER_GATEWAY_ENV_WHITELIST env var)""")
+
+ @default('env_whitelist')
+ def _env_whitelist_default(self):
+ return os.environ.get(self.env_whitelist_env, self.env_whitelist_default_value)
+
+ @property
+ def gateway_enabled(self):
+ return bool(self.url is not None and len(self.url) > 0)
+
+ def init_static_args(self):
+ """Initialize arguments used on every request. Since these are static values, we'll
+ perform this operation once.
+
+ """
+        self._static_args['headers'] = json.loads(self.headers)
+        if self.auth_token:  # only send an Authorization header when a token is configured
+            self._static_args['headers'].update({'Authorization': 'token {}'.format(self.auth_token)})
+ self._static_args['connect_timeout'] = self.connect_timeout
+ self._static_args['request_timeout'] = self.request_timeout
+ self._static_args['validate_cert'] = self.validate_cert
+ if self.client_cert:
+ self._static_args['client_cert'] = self.client_cert
+ self._static_args['client_key'] = self.client_key
+ if self.ca_certs:
+ self._static_args['ca_certs'] = self.ca_certs
+ if self.http_user:
+ self._static_args['auth_username'] = self.http_user
+ if self.http_pwd:
+ self._static_args['auth_password'] = self.http_pwd
+
+ def load_connection_args(self, **kwargs):
+ """Merges the static args relative to the connection, with the given keyword arguments. If statics
+ have yet to be initialized, we'll do that here.
+
+ """
+ if len(self._static_args) == 0:
+ self.init_static_args()
+
+ kwargs.update(self._static_args)
+ return kwargs
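+
+    # Illustrative usage sketch (nothing here executes it; gateway_request() below follows
+    # the same pattern): merge the shared connection settings into the per-request kwargs
+    # before issuing the HTTP call, e.g.
+    #
+    #   kwargs = GatewayClient.instance().load_connection_args(method='GET')
+    #   response = yield AsyncHTTPClient().fetch(endpoint_url, **kwargs)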
+
+
+@gen.coroutine
+def gateway_request(endpoint, **kwargs):
+ """Make an async request to kernel gateway endpoint, returns a response """
+ client = AsyncHTTPClient()
+ kwargs = GatewayClient.instance().load_connection_args(**kwargs)
+ try:
+ response = yield client.fetch(endpoint, **kwargs)
+ # Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect
+ # or the server is not running.
+ # NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes
+ # of the tree view.
+ except ConnectionRefusedError:
+ raise web.HTTPError(503, "Connection refused from Gateway server url '{}'. "
+ "Check to be sure the Gateway instance is running.".format(GatewayClient.instance().url))
+ except HTTPTimeoutError:
+ # This can occur if the host is valid (e.g., foo.com) but there's nothing there.
+ raise web.HTTPError(504, "Timeout error attempting to connect to Gateway server url '{}'. " \
+ "Ensure gateway url is valid and the Gateway instance is running.".format(
+ GatewayClient.instance().url))
+ except gaierror as e:
+ raise web.HTTPError(404, "The Gateway server specified in the gateway_url '{}' doesn't appear to be valid. "
+ "Ensure gateway url is valid and the Gateway instance is running.".format(
+ GatewayClient.instance().url))
+
+ raise gen.Return(response)
+
+
+class GatewayKernelManager(MappingKernelManager):
+ """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway."""
+
+ # We'll maintain our own set of kernel ids
+ _kernels = {}
+
+ def __init__(self, **kwargs):
+ super(GatewayKernelManager, self).__init__(**kwargs)
+ self.base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernels_endpoint)
+
+ def __contains__(self, kernel_id):
+ return kernel_id in self._kernels
+
+ def remove_kernel(self, kernel_id):
+ """Complete override since we want to be more tolerant of missing keys """
+ try:
+ return self._kernels.pop(kernel_id)
+ except KeyError:
+ pass
+
+ def _get_kernel_endpoint_url(self, kernel_id=None):
+ """Builds a url for the kernels endpoint
+
+ Parameters
+ ----------
+ kernel_id: kernel UUID (optional)
+ """
+ if kernel_id:
+ return url_path_join(self.base_endpoint, url_escape(str(kernel_id)))
+
+ return self.base_endpoint
+
+ @gen.coroutine
+ def start_kernel(self, kernel_id=None, path=None, **kwargs):
+ """Start a kernel for a session and return its kernel_id.
+
+ Parameters
+ ----------
+ kernel_id : uuid
+ The uuid to associate the new kernel with. If this
+ is not None, this kernel will be persistent whenever it is
+ requested.
+ path : API path
+ The API path (unicode, '/' delimited) for the cwd.
+ Will be transformed to an OS path relative to root_dir.
+ """
+ self.log.info('Request start kernel: kernel_id=%s, path="%s"', kernel_id, path)
+
+ if kernel_id is None:
+ if path is not None:
+ kwargs['cwd'] = self.cwd_for_path(path)
+ kernel_name = kwargs.get('kernel_name', 'python3')
+ kernel_url = self._get_kernel_endpoint_url()
+ self.log.debug("Request new kernel at: %s" % kernel_url)
+
+            # Let KERNEL_USERNAME take precedence over the http_user config option.
+ if os.environ.get('KERNEL_USERNAME') is None and GatewayClient.instance().http_user:
+ os.environ['KERNEL_USERNAME'] = GatewayClient.instance().http_user
+
+ kernel_env = {k: v for (k, v) in dict(os.environ).items() if k.startswith('KERNEL_')
+ or k in GatewayClient.instance().env_whitelist.split(",")}
+ json_body = json_encode({'name': kernel_name, 'env': kernel_env})
+
+ response = yield gateway_request(kernel_url, method='POST', body=json_body)
+ kernel = json_decode(response.body)
+ kernel_id = kernel['id']
+ self.log.info("Kernel started: %s" % kernel_id)
+ self.log.debug("Kernel args: %r" % kwargs)
+ else:
+ kernel = yield self.get_kernel(kernel_id)
+ kernel_id = kernel['id']
+ self.log.info("Using existing kernel: %s" % kernel_id)
+
+ self._kernels[kernel_id] = kernel
+ raise gen.Return(kernel_id)
+
+ @gen.coroutine
+ def get_kernel(self, kernel_id=None, **kwargs):
+ """Get kernel for kernel_id.
+
+ Parameters
+ ----------
+ kernel_id : uuid
+ The uuid of the kernel.
+ """
+ kernel_url = self._get_kernel_endpoint_url(kernel_id)
+ self.log.debug("Request kernel at: %s" % kernel_url)
+ try:
+ response = yield gateway_request(kernel_url, method='GET')
+ except HTTPError as error:
+ if error.code == 404:
+                self.log.warning("Kernel not found at: %s" % kernel_url)
+ self.remove_kernel(kernel_id)
+ kernel = None
+ else:
+ raise
+ else:
+ kernel = json_decode(response.body)
+ self._kernels[kernel_id] = kernel
+ self.log.debug("Kernel retrieved: %s" % kernel)
+ raise gen.Return(kernel)
+
+ @gen.coroutine
+ def kernel_model(self, kernel_id):
+ """Return a dictionary of kernel information described in the
+ JSON standard model.
+
+ Parameters
+ ----------
+ kernel_id : uuid
+ The uuid of the kernel.
+ """
+ self.log.debug("RemoteKernelManager.kernel_model: %s", kernel_id)
+ model = yield self.get_kernel(kernel_id)
+ raise gen.Return(model)
+
+ @gen.coroutine
+ def list_kernels(self, **kwargs):
+ """Get a list of kernels."""
+ kernel_url = self._get_kernel_endpoint_url()
+ self.log.debug("Request list kernels: %s", kernel_url)
+ response = yield gateway_request(kernel_url, method='GET')
+ kernels = json_decode(response.body)
+ self._kernels = {x['id']:x for x in kernels}
+ raise gen.Return(kernels)
+
+ @gen.coroutine
+ def shutdown_kernel(self, kernel_id, now=False, restart=False):
+ """Shutdown a kernel by its kernel uuid.
+
+ Parameters
+        ----------
+ kernel_id : uuid
+ The id of the kernel to shutdown.
+ """
+ kernel_url = self._get_kernel_endpoint_url(kernel_id)
+ self.log.debug("Request shutdown kernel at: %s", kernel_url)
+ response = yield gateway_request(kernel_url, method='DELETE')
+ self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason)
+ self.remove_kernel(kernel_id)
+
+ @gen.coroutine
+ def restart_kernel(self, kernel_id, now=False, **kwargs):
+ """Restart a kernel by its kernel uuid.
+
+ Parameters
+        ----------
+ kernel_id : uuid
+ The id of the kernel to restart.
+ """
+ kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart'
+ self.log.debug("Request restart kernel at: %s", kernel_url)
+ response = yield gateway_request(kernel_url, method='POST', body=json_encode({}))
+ self.log.debug("Restart kernel response: %d %s", response.code, response.reason)
+
+ @gen.coroutine
+ def interrupt_kernel(self, kernel_id, **kwargs):
+ """Interrupt a kernel by its kernel uuid.
+
+ Parameters
+        ----------
+ kernel_id : uuid
+ The id of the kernel to interrupt.
+ """
+ kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt'
+ self.log.debug("Request interrupt kernel at: %s", kernel_url)
+ response = yield gateway_request(kernel_url, method='POST', body=json_encode({}))
+ self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason)
+
+ def shutdown_all(self, now=False):
+ """Shutdown all kernels."""
+ # Note: We have to make this sync because the NotebookApp does not wait for async.
+ shutdown_kernels = []
+ kwargs = {'method': 'DELETE'}
+ kwargs = GatewayClient.instance().load_connection_args(**kwargs)
+ client = HTTPClient()
+ for kernel_id in self._kernels.keys():
+ kernel_url = self._get_kernel_endpoint_url(kernel_id)
+ self.log.debug("Request delete kernel at: %s", kernel_url)
+ try:
+ response = client.fetch(kernel_url, **kwargs)
+ except HTTPError:
+ pass
+ else:
+ self.log.debug("Delete kernel response: %d %s", response.code, response.reason)
+ shutdown_kernels.append(kernel_id) # avoid changing dict size during iteration
+ client.close()
+ for kernel_id in shutdown_kernels:
+ self.remove_kernel(kernel_id)
+
+
+class GatewayKernelSpecManager(KernelSpecManager):
+
+ def __init__(self, **kwargs):
+ super(GatewayKernelSpecManager, self).__init__(**kwargs)
+ self.base_endpoint = url_path_join(GatewayClient.instance().url, GatewayClient.instance().kernelspecs_endpoint)
+
+ def _get_kernelspecs_endpoint_url(self, kernel_name=None):
+ """Builds a url for the kernels endpoint
+
+ Parameters
+ ----------
+ kernel_name: kernel name (optional)
+ """
+ if kernel_name:
+ return url_path_join(self.base_endpoint, url_escape(kernel_name))
+
+ return self.base_endpoint
+
+ @gen.coroutine
+ def get_all_specs(self):
+ fetched_kspecs = yield self.list_kernel_specs()
+
+        # Get the default kernel name and compare it to that of this server.
+        # If they differ, log a message and update the default to the Gateway
+        # server's value. However, the caller of this method will still return
+        # this server's value until the next fetch of kernelspecs, at which time they'll match.
+ km = self.parent.kernel_manager
+ remote_default_kernel_name = fetched_kspecs.get('default')
+ if remote_default_kernel_name != km.default_kernel_name:
+ self.log.info("Default kernel name on Gateway server ({gateway_default}) differs from "
+ "Notebook server ({notebook_default}). Updating to Gateway server's value.".
+ format(gateway_default=remote_default_kernel_name,
+ notebook_default=km.default_kernel_name))
+ km.default_kernel_name = remote_default_kernel_name
+
+        # The gateway doesn't support kernelspec resources (they would require transfer for
+        # use by the Notebook client), so give each kernelspec a `resource_dir` value of
+        # 'not supported in gateway mode' if one isn't already present.
+        remote_kspecs = fetched_kspecs.get('kernelspecs')
+        for kspec_info in remote_kspecs.values():
+            if not kspec_info.get('resource_dir'):
+                kspec_info['resource_dir'] = 'not supported in gateway mode'
+
+ raise gen.Return(remote_kspecs)
+
+ @gen.coroutine
+ def list_kernel_specs(self):
+ """Get a list of kernel specs."""
+ kernel_spec_url = self._get_kernelspecs_endpoint_url()
+ self.log.debug("Request list kernel specs at: %s", kernel_spec_url)
+ response = yield gateway_request(kernel_spec_url, method='GET')
+ kernel_specs = json_decode(response.body)
+ raise gen.Return(kernel_specs)
+
+ @gen.coroutine
+ def get_kernel_spec(self, kernel_name, **kwargs):
+ """Get kernel spec for kernel_name.
+
+ Parameters
+ ----------
+ kernel_name : str
+ The name of the kernel.
+ """
+ kernel_spec_url = self._get_kernelspecs_endpoint_url(kernel_name=str(kernel_name))
+ self.log.debug("Request kernel spec at: %s" % kernel_spec_url)
+ try:
+ response = yield gateway_request(kernel_spec_url, method='GET')
+ except HTTPError as error:
+ if error.code == 404:
+ # Convert not found to KeyError since that's what the Notebook handler expects
+ # message is not used, but might as well make it useful for troubleshooting
+ raise KeyError('kernelspec {kernel_name} not found on Gateway server at: {gateway_url}'.
+ format(kernel_name=kernel_name, gateway_url=GatewayClient.instance().url))
+ else:
+ raise
+ else:
+ kernel_spec = json_decode(response.body)
+ # Convert to instance of Kernelspec
+ kspec_instance = self.kernel_spec_class(resource_dir=u'', **kernel_spec['spec'])
+ raise gen.Return(kspec_instance)
+
+
+class GatewaySessionManager(SessionManager):
+ kernel_manager = Instance('notebook.gateway.managers.GatewayKernelManager')
+
+ @gen.coroutine
+ def kernel_culled(self, kernel_id):
+ """Checks if the kernel is still considered alive and returns true if its not found. """
+ kernel = yield self.kernel_manager.get_kernel(kernel_id)
+ raise gen.Return(kernel is None)
diff --git a/notebook/notebookapp.py b/notebook/notebookapp.py
index 05e44cf29f..2639b4faa8 100755
--- a/notebook/notebookapp.py
+++ b/notebook/notebookapp.py
@@ -84,6 +84,7 @@
from .services.contents.filemanager import FileContentsManager
from .services.contents.largefilemanager import LargeFileManager
from .services.sessions.sessionmanager import SessionManager
+from .gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient
from .auth.login import LoginHandler
from .auth.logout import LogoutHandler
@@ -96,7 +97,7 @@
)
from jupyter_core.paths import jupyter_config_path
from jupyter_client import KernelManager
-from jupyter_client.kernelspec import KernelSpecManager, NoSuchKernel, NATIVE_KERNEL_NAME
+from jupyter_client.kernelspec import KernelSpecManager
from jupyter_client.session import Session
from nbformat.sign import NotebookNotary
from traitlets import (
@@ -144,6 +145,7 @@ def load_handlers(name):
# The Tornado web application
#-----------------------------------------------------------------------------
+
class NotebookWebApplication(web.Application):
def __init__(self, jupyter_app, kernel_manager, contents_manager,
@@ -151,7 +153,6 @@ def __init__(self, jupyter_app, kernel_manager, contents_manager,
config_manager, extra_services, log,
base_url, default_url, settings_overrides, jinja_env_options):
-
settings = self.init_settings(
jupyter_app, kernel_manager, contents_manager,
session_manager, kernel_spec_manager, config_manager,
@@ -305,15 +306,27 @@ def init_handlers(self, settings):
handlers.extend(load_handlers('notebook.edit.handlers'))
handlers.extend(load_handlers('notebook.services.api.handlers'))
handlers.extend(load_handlers('notebook.services.config.handlers'))
- handlers.extend(load_handlers('notebook.services.kernels.handlers'))
handlers.extend(load_handlers('notebook.services.contents.handlers'))
handlers.extend(load_handlers('notebook.services.sessions.handlers'))
handlers.extend(load_handlers('notebook.services.nbconvert.handlers'))
- handlers.extend(load_handlers('notebook.services.kernelspecs.handlers'))
handlers.extend(load_handlers('notebook.services.security.handlers'))
handlers.extend(load_handlers('notebook.services.shutdown'))
+ handlers.extend(load_handlers('notebook.services.kernels.handlers'))
+ handlers.extend(load_handlers('notebook.services.kernelspecs.handlers'))
+
handlers.extend(settings['contents_manager'].get_extra_handlers())
+ # If gateway mode is enabled, replace appropriate handlers to perform redirection
+ if GatewayClient.instance().gateway_enabled:
+ # for each handler required for gateway, locate its pattern
+ # in the current list and replace that entry...
+ gateway_handlers = load_handlers('notebook.gateway.handlers')
+ for i, gwh in enumerate(gateway_handlers):
+ for j, h in enumerate(handlers):
+ if gwh[0] == h[0]:
+ handlers[j] = (gwh[0], gwh[1])
+ break
+
handlers.append(
(r"/nbextensions/(.*)", FileFindHandler, {
'path': settings['nbextensions_path'],
@@ -547,6 +560,7 @@ def start(self):
'notebook-dir': 'NotebookApp.notebook_dir',
'browser': 'NotebookApp.browser',
'pylab': 'NotebookApp.pylab',
+ 'gateway-url': 'GatewayClient.url',
})
#-----------------------------------------------------------------------------
@@ -565,9 +579,9 @@ class NotebookApp(JupyterApp):
flags = flags
classes = [
- KernelManager, Session, MappingKernelManager,
+ KernelManager, Session, MappingKernelManager, KernelSpecManager,
ContentsManager, FileContentsManager, NotebookNotary,
- KernelSpecManager,
+ GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient,
]
flags = Dict(flags)
aliases = Dict(aliases)
@@ -1316,6 +1330,16 @@ def parse_command_line(self, argv=None):
self.update_config(c)
def init_configurables(self):
+
+ # If gateway server is configured, replace appropriate managers to perform redirection. To make
+ # this determination, instantiate the GatewayClient config singleton.
+ self.gateway_config = GatewayClient.instance(parent=self)
+
+ if self.gateway_config.gateway_enabled:
+ self.kernel_manager_class = 'notebook.gateway.managers.GatewayKernelManager'
+ self.session_manager_class = 'notebook.gateway.managers.GatewaySessionManager'
+ self.kernel_spec_manager_class = 'notebook.gateway.managers.GatewayKernelSpecManager'
+
self.kernel_spec_manager = self.kernel_spec_manager_class(
parent=self,
)
@@ -1661,6 +1685,8 @@ def notebook_info(self, kernel_count=True):
info += "\n"
# Format the info so that the URL fits on a single line in 80 char display
info += _("The Jupyter Notebook is running at:\n%s") % self.display_url
+ if self.gateway_config.gateway_enabled:
+ info += _("\nKernels will be managed by the Gateway server running at:\n%s") % self.gateway_config.url
return info
def server_info(self):
diff --git a/notebook/services/kernels/handlers.py b/notebook/services/kernels/handlers.py
index cfef2a4a0e..897fa51db2 100644
--- a/notebook/services/kernels/handlers.py
+++ b/notebook/services/kernels/handlers.py
@@ -45,7 +45,7 @@ def post(self):
model.setdefault('name', km.default_kernel_name)
kernel_id = yield gen.maybe_future(km.start_kernel(kernel_name=model['name']))
- model = km.kernel_model(kernel_id)
+ model = yield gen.maybe_future(km.kernel_model(kernel_id))
location = url_path_join(self.base_url, 'api', 'kernels', url_escape(kernel_id))
self.set_header('Location', location)
self.set_status(201)
@@ -57,7 +57,6 @@ class KernelHandler(APIHandler):
@web.authenticated
def get(self, kernel_id):
km = self.kernel_manager
- km._check_kernel_id(kernel_id)
model = km.kernel_model(kernel_id)
self.finish(json.dumps(model, default=date_default))
@@ -87,7 +86,7 @@ def post(self, kernel_id, action):
self.log.error("Exception restarting kernel", exc_info=True)
self.set_status(500)
else:
- model = km.kernel_model(kernel_id)
+ model = yield gen.maybe_future(km.kernel_model(kernel_id))
self.write(json.dumps(model, default=date_default))
self.finish()
diff --git a/notebook/services/kernelspecs/handlers.py b/notebook/services/kernelspecs/handlers.py
index d272db2f71..c0157e4c57 100644
--- a/notebook/services/kernelspecs/handlers.py
+++ b/notebook/services/kernelspecs/handlers.py
@@ -11,7 +11,7 @@
import os
pjoin = os.path.join
-from tornado import web
+from tornado import web, gen
from ...base.handlers import APIHandler
from ...utils import url_path_join, url_unescape
@@ -48,13 +48,15 @@ def kernelspec_model(handler, name, spec_dict, resource_dir):
class MainKernelSpecHandler(APIHandler):
@web.authenticated
+ @gen.coroutine
def get(self):
ksm = self.kernel_spec_manager
km = self.kernel_manager
model = {}
model['default'] = km.default_kernel_name
model['kernelspecs'] = specs = {}
- for kernel_name, kernel_info in ksm.get_all_specs().items():
+ kspecs = yield gen.maybe_future(ksm.get_all_specs())
+ for kernel_name, kernel_info in kspecs.items():
try:
d = kernelspec_model(self, kernel_name, kernel_info['spec'],
kernel_info['resource_dir'])
@@ -69,11 +71,12 @@ def get(self):
class KernelSpecHandler(APIHandler):
@web.authenticated
+ @gen.coroutine
def get(self, kernel_name):
ksm = self.kernel_spec_manager
kernel_name = url_unescape(kernel_name)
try:
- spec = ksm.get_kernel_spec(kernel_name)
+ spec = yield gen.maybe_future(ksm.get_kernel_spec(kernel_name))
except KeyError:
raise web.HTTPError(404, u'Kernel spec %s not found' % kernel_name)
model = kernelspec_model(self, kernel_name, spec.to_dict(), spec.resource_dir)
diff --git a/notebook/services/sessions/sessionmanager.py b/notebook/services/sessions/sessionmanager.py
index ee70eb0810..4497cfbc33 100644
--- a/notebook/services/sessions/sessionmanager.py
+++ b/notebook/services/sessions/sessionmanager.py
@@ -56,21 +56,22 @@ def __del__(self):
"""Close connection once SessionManager closes"""
self.close()
+ @gen.coroutine
def session_exists(self, path):
"""Check to see if the session of a given name exists"""
+ exists = False
self.cursor.execute("SELECT * FROM session WHERE path=?", (path,))
row = self.cursor.fetchone()
- if row is None:
- return False
- else:
+ if row is not None:
# Note, although we found a row for the session, the associated kernel may have
# been culled or died unexpectedly. If that's the case, we should delete the
# row, thereby terminating the session. This can be done via a call to
# row_to_model that tolerates that condition. If row_to_model returns None,
# we'll return false, since, at that point, the session doesn't exist anyway.
- if self.row_to_model(row, tolerate_culled=True) is None:
- return False
- return True
+ model = yield gen.maybe_future(self.row_to_model(row, tolerate_culled=True))
+ if model is not None:
+ exists = True
+ raise gen.Return(exists)
def new_session_id(self):
"Create a uuid for a new session"
@@ -101,6 +102,7 @@ def start_kernel_for_session(self, session_id, path, name, type, kernel_name):
# py2-compat
raise gen.Return(kernel_id)
+ @gen.coroutine
def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None):
"""Saves the items for the session with the given session_id
@@ -129,8 +131,10 @@ def save_session(self, session_id, path=None, name=None, type=None, kernel_id=No
self.cursor.execute("INSERT INTO session VALUES (?,?,?,?,?)",
(session_id, path, name, type, kernel_id)
)
- return self.get_session(session_id=session_id)
+ result = yield gen.maybe_future(self.get_session(session_id=session_id))
+ raise gen.Return(result)
+ @gen.coroutine
def get_session(self, **kwargs):
"""Returns the model for a particular session.
@@ -174,8 +178,10 @@ def get_session(self, **kwargs):
raise web.HTTPError(404, u'Session not found: %s' % (', '.join(q)))
- return self.row_to_model(row)
+ model = yield gen.maybe_future(self.row_to_model(row))
+ raise gen.Return(model)
+ @gen.coroutine
def update_session(self, session_id, **kwargs):
"""Updates the values in the session database.
@@ -191,7 +197,7 @@ def update_session(self, session_id, **kwargs):
and the value replaces the current value in the session
with session_id.
"""
- self.get_session(session_id=session_id)
+ yield gen.maybe_future(self.get_session(session_id=session_id))
if not kwargs:
# no changes
@@ -205,9 +211,15 @@ def update_session(self, session_id, **kwargs):
query = "UPDATE session SET %s WHERE session_id=?" % (', '.join(sets))
self.cursor.execute(query, list(kwargs.values()) + [session_id])
+ def kernel_culled(self, kernel_id):
+ """Checks if the kernel is still considered alive and returns true if its not found. """
+ return kernel_id not in self.kernel_manager
+
+ @gen.coroutine
def row_to_model(self, row, tolerate_culled=False):
"""Takes sqlite database session row and turns it into a dictionary"""
- if row['kernel_id'] not in self.kernel_manager:
+ kernel_culled = yield gen.maybe_future(self.kernel_culled(row['kernel_id']))
+ if kernel_culled:
# The kernel was culled or died without deleting the session.
# We can't use delete_session here because that tries to find
# and shut down the kernel - so we'll delete the row directly.
@@ -222,21 +234,23 @@ def row_to_model(self, row, tolerate_culled=False):
format(kernel_id=row['kernel_id'],session_id=row['session_id'])
if tolerate_culled:
self.log.warning(msg + " Continuing...")
- return None
+ raise gen.Return(None)
raise KeyError(msg)
+ kernel_model = yield gen.maybe_future(self.kernel_manager.kernel_model(row['kernel_id']))
model = {
'id': row['session_id'],
'path': row['path'],
'name': row['name'],
'type': row['type'],
- 'kernel': self.kernel_manager.kernel_model(row['kernel_id'])
+ 'kernel': kernel_model
}
if row['type'] == 'notebook':
# Provide the deprecated API.
model['notebook'] = {'path': row['path'], 'name': row['name']}
- return model
+ raise gen.Return(model)
+ @gen.coroutine
def list_sessions(self):
"""Returns a list of dictionaries containing all the information from
the session database"""
@@ -246,14 +260,15 @@ def list_sessions(self):
# which messes up the cursor if we're iterating over rows.
for row in c.fetchall():
try:
- result.append(self.row_to_model(row))
+ model = yield gen.maybe_future(self.row_to_model(row))
+ result.append(model)
except KeyError:
pass
- return result
+ raise gen.Return(result)
@gen.coroutine
def delete_session(self, session_id):
"""Deletes the row in the session database with given session_id"""
- session = self.get_session(session_id=session_id)
+ session = yield gen.maybe_future(self.get_session(session_id=session_id))
yield gen.maybe_future(self.kernel_manager.shutdown_kernel(session['kernel']['id']))
self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
diff --git a/notebook/services/sessions/tests/test_sessionmanager.py b/notebook/services/sessions/tests/test_sessionmanager.py
index 96847a868a..97331ebf9b 100644
--- a/notebook/services/sessions/tests/test_sessionmanager.py
+++ b/notebook/services/sessions/tests/test_sessionmanager.py
@@ -62,11 +62,11 @@ def co_add():
def create_session(self, **kwargs):
return self.create_sessions(kwargs)[0]
-
+
def test_get_session(self):
sm = self.sm
session_id = self.create_session(path='/path/to/test.ipynb', kernel_name='bar')['id']
- model = sm.get_session(session_id=session_id)
+ model = self.loop.run_sync(lambda: sm.get_session(session_id=session_id))
expected = {'id':session_id,
'path': u'/path/to/test.ipynb',
'notebook': {'path': u'/path/to/test.ipynb', 'name': None},
@@ -86,7 +86,8 @@ def test_bad_get_session(self):
sm = self.sm
session_id = self.create_session(path='/path/to/test.ipynb',
kernel_name='foo')['id']
- self.assertRaises(TypeError, sm.get_session, bad_id=session_id) # Bad keyword
+ with self.assertRaises(TypeError):
+ self.loop.run_sync(lambda: sm.get_session(bad_id=session_id)) # Bad keyword
def test_get_session_dead_kernel(self):
sm = self.sm
@@ -94,9 +95,9 @@ def test_get_session_dead_kernel(self):
# kill the kernel
sm.kernel_manager.shutdown_kernel(session['kernel']['id'])
with self.assertRaises(KeyError):
- sm.get_session(session_id=session['id'])
+ self.loop.run_sync(lambda: sm.get_session(session_id=session['id']))
# no sessions left
- listed = sm.list_sessions()
+ listed = self.loop.run_sync(lambda: sm.list_sessions())
self.assertEqual(listed, [])
def test_list_sessions(self):
@@ -107,7 +108,7 @@ def test_list_sessions(self):
dict(path='/path/to/3', name='foo', type='console', kernel_name='python'),
)
- sessions = sm.list_sessions()
+ sessions = self.loop.run_sync(lambda: sm.list_sessions())
expected = [
{
'id':sessions[0]['id'],
@@ -158,7 +159,7 @@ def test_list_sessions_dead_kernel(self):
)
# kill one of the kernels
sm.kernel_manager.shutdown_kernel(sessions[0]['kernel']['id'])
- listed = sm.list_sessions()
+ listed = self.loop.run_sync(lambda: sm.list_sessions())
expected = [
{
'id': sessions[1]['id'],
@@ -181,8 +182,8 @@ def test_update_session(self):
sm = self.sm
session_id = self.create_session(path='/path/to/test.ipynb',
kernel_name='julia')['id']
- sm.update_session(session_id, path='/path/to/new_name.ipynb')
- model = sm.get_session(session_id=session_id)
+ self.loop.run_sync(lambda: sm.update_session(session_id, path='/path/to/new_name.ipynb'))
+ model = self.loop.run_sync(lambda: sm.get_session(session_id=session_id))
expected = {'id':session_id,
'path': u'/path/to/new_name.ipynb',
'type': 'notebook',
@@ -203,7 +204,8 @@ def test_bad_update_session(self):
sm = self.sm
session_id = self.create_session(path='/path/to/test.ipynb',
kernel_name='ir')['id']
- self.assertRaises(TypeError, sm.update_session, session_id=session_id, bad_kw='test.ipynb') # Bad keyword
+ with self.assertRaises(TypeError):
+ self.loop.run_sync(lambda: sm.update_session(session_id=session_id, bad_kw='test.ipynb')) # Bad keyword
def test_delete_session(self):
sm = self.sm
@@ -212,8 +214,8 @@ def test_delete_session(self):
dict(path='/path/to/2/test2.ipynb', kernel_name='python'),
dict(path='/path/to/3', name='foo', type='console', kernel_name='python'),
)
- sm.delete_session(sessions[1]['id'])
- new_sessions = sm.list_sessions()
+ self.loop.run_sync(lambda: sm.delete_session(sessions[1]['id']))
+ new_sessions = self.loop.run_sync(lambda: sm.list_sessions())
expected = [{
'id': sessions[0]['id'],
'path': u'/path/to/1/test1.ipynb',
diff --git a/notebook/tests/launchnotebook.py b/notebook/tests/launchnotebook.py
index 1b685df0ca..9e84a5964b 100644
--- a/notebook/tests/launchnotebook.py
+++ b/notebook/tests/launchnotebook.py
@@ -91,6 +91,22 @@ def request(cls, verb, path, **kwargs):
url_path_join(cls.base_url(), path),
**kwargs)
return response
+
+ @classmethod
+ def get_patch_env(cls):
+ return {
+ 'HOME': cls.home_dir,
+ 'PYTHONPATH': os.pathsep.join(sys.path),
+ 'IPYTHONDIR': pjoin(cls.home_dir, '.ipython'),
+ 'JUPYTER_NO_CONFIG': '1', # needed in the future
+ 'JUPYTER_CONFIG_DIR' : cls.config_dir,
+ 'JUPYTER_DATA_DIR' : cls.data_dir,
+ 'JUPYTER_RUNTIME_DIR': cls.runtime_dir,
+ }
+
+ @classmethod
+ def get_argv(cls):
+ return []
@classmethod
def setup_class(cls):
@@ -109,15 +125,7 @@ def tmp(*parts):
config_dir = cls.config_dir = tmp('config')
runtime_dir = cls.runtime_dir = tmp('runtime')
cls.notebook_dir = tmp('notebooks')
- cls.env_patch = patch.dict('os.environ', {
- 'HOME': cls.home_dir,
- 'PYTHONPATH': os.pathsep.join(sys.path),
- 'IPYTHONDIR': pjoin(cls.home_dir, '.ipython'),
- 'JUPYTER_NO_CONFIG': '1', # needed in the future
- 'JUPYTER_CONFIG_DIR' : config_dir,
- 'JUPYTER_DATA_DIR' : data_dir,
- 'JUPYTER_RUNTIME_DIR': runtime_dir,
- })
+ cls.env_patch = patch.dict('os.environ', cls.get_patch_env())
cls.env_patch.start()
cls.path_patch = patch.multiple(
jupyter_core.paths,
@@ -157,7 +165,7 @@ def start_thread():
# needs to be redone after initialize, which reconfigures logging
app.log.propagate = True
app.log.handlers = []
- app.initialize(argv=[])
+ app.initialize(argv=cls.get_argv())
app.log.propagate = True
app.log.handlers = []
loop = IOLoop.current()
diff --git a/notebook/tests/test_gateway.py b/notebook/tests/test_gateway.py
new file mode 100644
index 0000000000..ef3cd7ef56
--- /dev/null
+++ b/notebook/tests/test_gateway.py
@@ -0,0 +1,354 @@
+"""Test GatewayClient"""
+import os
+import json
+import uuid
+from datetime import datetime
+from tornado import gen
+from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError
+from traitlets.config import Config
+from .launchnotebook import NotebookTestBase
+from notebook.gateway.managers import GatewayClient
+
+try:
+ from unittest.mock import patch, Mock
+except ImportError:
+ from mock import patch, Mock # py2
+
+try:
+    from io import StringIO
+except ImportError:
+    from StringIO import StringIO  # py2
+
+import nose.tools as nt
+
+
+def generate_kernelspec(name):
+ argv_stanza = ['python', '-m', 'ipykernel_launcher', '-f', '{connection_file}']
+ spec_stanza = {'spec': {'argv': argv_stanza, 'env': {}, 'display_name': name, 'language': 'python', 'interrupt_mode': 'signal', 'metadata': {}}}
+ kernelspec_stanza = {'name': name, 'spec': spec_stanza, 'resources': {}}
+ return kernelspec_stanza
+
+
+# We'll mock up two kernelspecs - kspec_foo and kspec_bar
+kernelspecs = {'default': 'kspec_foo', 'kernelspecs': {'kspec_foo': generate_kernelspec('kspec_foo'), 'kspec_bar': generate_kernelspec('kspec_bar')}}
+
+
+# maintain a dictionary of expected running kernels. Key = kernel_id, Value = model.
+running_kernels = dict()
+
+
+def generate_model(name):
+ """Generate a mocked kernel model. Caller is responsible for adding model to running_kernels dictionary."""
+ dt = datetime.utcnow().isoformat() + 'Z'
+ kernel_id = str(uuid.uuid4())
+ model = {'id': kernel_id, 'name': name, 'last_activity': str(dt), 'execution_state': 'idle', 'connections': 1}
+ return model
+
+
+@gen.coroutine
+def mock_gateway_request(url, **kwargs):
+    method = kwargs.get('method', 'GET')
+
+ request = HTTPRequest(url=url, **kwargs)
+
+ endpoint = str(url)
+
+ # Fetch all kernelspecs
+ if endpoint.endswith('/api/kernelspecs') and method == 'GET':
+ response_buf = StringIO(json.dumps(kernelspecs))
+ response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf))
+ raise gen.Return(response)
+
+ # Fetch named kernelspec
+ if endpoint.rfind('/api/kernelspecs/') >= 0 and method == 'GET':
+ requested_kernelspec = endpoint.rpartition('/')[2]
+ kspecs = kernelspecs.get('kernelspecs')
+ if requested_kernelspec in kspecs:
+ response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec)))
+ response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf))
+ raise gen.Return(response)
+ else:
+ raise HTTPError(404, message='Kernelspec does not exist: %s' % requested_kernelspec)
+
+ # Create kernel
+ if endpoint.endswith('/api/kernels') and method == 'POST':
+ json_body = json.loads(kwargs['body'])
+ name = json_body.get('name')
+ env = json_body.get('env')
+ kspec_name = env.get('KERNEL_KSPEC_NAME')
+ nt.assert_equal(name, kspec_name) # Ensure that KERNEL_ env values get propagated
+ model = generate_model(name)
+ running_kernels[model.get('id')] = model # Register model as a running kernel
+ response_buf = StringIO(json.dumps(model))
+ response = yield gen.maybe_future(HTTPResponse(request, 201, buffer=response_buf))
+ raise gen.Return(response)
+
+ # Fetch list of running kernels
+ if endpoint.endswith('/api/kernels') and method == 'GET':
+ kernels = []
+ for kernel_id in running_kernels.keys():
+ model = running_kernels.get(kernel_id)
+ kernels.append(model)
+ response_buf = StringIO(json.dumps(kernels))
+ response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf))
+ raise gen.Return(response)
+
+ # Interrupt or restart existing kernel
+ if endpoint.rfind('/api/kernels/') >= 0 and method == 'POST':
+ requested_kernel_id, sep, action = endpoint.rpartition('/api/kernels/')[2].rpartition('/')
+
+ if action == 'interrupt':
+ if requested_kernel_id in running_kernels:
+ response = yield gen.maybe_future(HTTPResponse(request, 204))
+ raise gen.Return(response)
+ else:
+ raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)
+ elif action == 'restart':
+ if requested_kernel_id in running_kernels:
+ response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id)))
+ response = yield gen.maybe_future(HTTPResponse(request, 204, buffer=response_buf))
+ raise gen.Return(response)
+ else:
+ raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)
+ else:
+ raise HTTPError(404, message='Bad action detected: %s' % action)
+
+ # Shutdown existing kernel
+ if endpoint.rfind('/api/kernels/') >= 0 and method == 'DELETE':
+ requested_kernel_id = endpoint.rpartition('/')[2]
+ running_kernels.pop(requested_kernel_id) # Simulate shutdown by removing kernel from running set
+ response = yield gen.maybe_future(HTTPResponse(request, 204))
+ raise gen.Return(response)
+
+ # Fetch existing kernel
+ if endpoint.rfind('/api/kernels/') >= 0 and method == 'GET':
+ requested_kernel_id = endpoint.rpartition('/')[2]
+ if requested_kernel_id in running_kernels:
+ response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id)))
+ response = yield gen.maybe_future(HTTPResponse(request, 200, buffer=response_buf))
+ raise gen.Return(response)
+ else:
+ raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)
+
+
+mocked_gateway = patch('notebook.gateway.managers.gateway_request', mock_gateway_request)
+
+
+class TestGateway(NotebookTestBase):
+
+ mock_gateway_url = 'http://mock-gateway-server:8889'
+ mock_http_user = 'alice'
+
+ @classmethod
+ def setup_class(cls):
+ GatewayClient.clear_instance()
+ super(TestGateway, cls).setup_class()
+
+ @classmethod
+ def teardown_class(cls):
+ GatewayClient.clear_instance()
+ super(TestGateway, cls).teardown_class()
+
+ @classmethod
+ def get_patch_env(cls):
+ test_env = super(TestGateway, cls).get_patch_env()
+ test_env.update({'JUPYTER_GATEWAY_URL': TestGateway.mock_gateway_url,
+ 'JUPYTER_GATEWAY_REQUEST_TIMEOUT': '44.4'})
+ return test_env
+
+ @classmethod
+ def get_argv(cls):
+ argv = super(TestGateway, cls).get_argv()
+ argv.extend(['--GatewayClient.connect_timeout=44.4', '--GatewayClient.http_user=' + TestGateway.mock_http_user])
+ return argv
+
+ def test_gateway_options(self):
+ nt.assert_equal(self.notebook.gateway_config.gateway_enabled, True)
+ nt.assert_equal(self.notebook.gateway_config.url, TestGateway.mock_gateway_url)
+ nt.assert_equal(self.notebook.gateway_config.http_user, TestGateway.mock_http_user)
+        nt.assert_equal(self.notebook.gateway_config.connect_timeout, self.notebook.gateway_config.request_timeout)
+        nt.assert_equal(self.notebook.gateway_config.connect_timeout, 44.4)
+
+ def test_gateway_class_mappings(self):
+ # Ensure appropriate class mappings are in place.
+ nt.assert_equal(self.notebook.kernel_manager_class.__name__, 'GatewayKernelManager')
+ nt.assert_equal(self.notebook.session_manager_class.__name__, 'GatewaySessionManager')
+ nt.assert_equal(self.notebook.kernel_spec_manager_class.__name__, 'GatewayKernelSpecManager')
+
+ def test_gateway_get_kernelspecs(self):
+ # Validate that kernelspecs come from gateway.
+ with mocked_gateway:
+ response = self.request('GET', '/api/kernelspecs')
+ self.assertEqual(response.status_code, 200)
+ content = json.loads(response.content.decode('utf-8'), encoding='utf-8')
+ kspecs = content.get('kernelspecs')
+ self.assertEqual(len(kspecs), 2)
+ self.assertEqual(kspecs.get('kspec_bar').get('name'), 'kspec_bar')
+
+ def test_gateway_get_named_kernelspec(self):
+ # Validate that a specific kernelspec can be retrieved from gateway.
+ with mocked_gateway:
+ response = self.request('GET', '/api/kernelspecs/kspec_foo')
+ self.assertEqual(response.status_code, 200)
+ kspec_foo = json.loads(response.content.decode('utf-8'), encoding='utf-8')
+ self.assertEqual(kspec_foo.get('name'), 'kspec_foo')
+
+ response = self.request('GET', '/api/kernelspecs/no_such_spec')
+ self.assertEqual(response.status_code, 404)
+
+ def test_gateway_session_lifecycle(self):
+ # Validate session lifecycle functions; create and delete.
+
+ # create
+ session_id, kernel_id = self.create_session('kspec_foo')
+
+ # ensure kernel still considered running
+ self.assertTrue(self.is_kernel_running(kernel_id))
+
+ # interrupt
+ self.interrupt_kernel(kernel_id)
+
+ # ensure kernel still considered running
+ self.assertTrue(self.is_kernel_running(kernel_id))
+
+ # restart
+ self.restart_kernel(kernel_id)
+
+ # ensure kernel still considered running
+ self.assertTrue(self.is_kernel_running(kernel_id))
+
+ # delete
+ self.delete_session(session_id)
+ self.assertFalse(self.is_kernel_running(kernel_id))
+
+ def test_gateway_kernel_lifecycle(self):
+ # Validate kernel lifecycle functions; create, interrupt, restart and delete.
+
+ # create
+ kernel_id = self.create_kernel('kspec_bar')
+
+ # ensure kernel still considered running
+ self.assertTrue(self.is_kernel_running(kernel_id))
+
+ # interrupt
+ self.interrupt_kernel(kernel_id)
+
+ # ensure kernel still considered running
+ self.assertTrue(self.is_kernel_running(kernel_id))
+
+ # restart
+ self.restart_kernel(kernel_id)
+
+ # ensure kernel still considered running
+ self.assertTrue(self.is_kernel_running(kernel_id))
+
+ # delete
+ self.delete_kernel(kernel_id)
+ self.assertFalse(self.is_kernel_running(kernel_id))
+
+ def create_session(self, kernel_name):
+ """Creates a session for a kernel. The session is created against the notebook server
+ which then uses the gateway for kernel management.
+ """
+ with mocked_gateway:
+ nb_path = os.path.join(self.notebook_dir, 'testgw.ipynb')
+ kwargs = dict()
+ kwargs['json'] = {'path': nb_path, 'type': 'notebook', 'kernel': {'name': kernel_name}}
+
+            # Add a KERNEL_ value to the current env so we can verify it gets propagated to the mocked gateway
+ os.environ['KERNEL_KSPEC_NAME'] = kernel_name
+
+ # Create the kernel... (also tests get_kernel)
+ response = self.request('POST', '/api/sessions', **kwargs)
+ self.assertEqual(response.status_code, 201)
+ model = json.loads(response.content.decode('utf-8'), encoding='utf-8')
+ self.assertEqual(model.get('path'), nb_path)
+ kernel_id = model.get('kernel').get('id')
+            # ensure it's in running_kernels and the name matches.
+ running_kernel = running_kernels.get(kernel_id)
+ self.assertEqual(kernel_id, running_kernel.get('id'))
+ self.assertEqual(model.get('kernel').get('name'), running_kernel.get('name'))
+ session_id = model.get('id')
+
+ # restore env
+ os.environ.pop('KERNEL_KSPEC_NAME')
+ return session_id, kernel_id
+
+ def delete_session(self, session_id):
+ """Deletes a session corresponding to the given session id.
+ """
+ with mocked_gateway:
+ # Delete the session (and kernel)
+ response = self.request('DELETE', '/api/sessions/' + session_id)
+ self.assertEqual(response.status_code, 204)
+ self.assertEqual(response.reason, 'No Content')
+
+ def is_kernel_running(self, kernel_id):
+ """Issues request to get the set of running kernels
+ """
+ with mocked_gateway:
+ # Get list of running kernels
+ response = self.request('GET', '/api/kernels')
+ self.assertEqual(response.status_code, 200)
+ kernels = json.loads(response.content.decode('utf-8'), encoding='utf-8')
+ self.assertEqual(len(kernels), len(running_kernels))
+ for model in kernels:
+ if model.get('id') == kernel_id:
+ return True
+ return False
+
+ def create_kernel(self, kernel_name):
+ """Issues request to retart the given kernel
+ """
+ with mocked_gateway:
+ kwargs = dict()
+ kwargs['json'] = {'name': kernel_name}
+
+            # Add a KERNEL_ value to the current env so we can verify it gets propagated to the mocked gateway
+ os.environ['KERNEL_KSPEC_NAME'] = kernel_name
+
+ response = self.request('POST', '/api/kernels', **kwargs)
+ self.assertEqual(response.status_code, 201)
+ model = json.loads(response.content.decode('utf-8'), encoding='utf-8')
+ kernel_id = model.get('id')
+            # ensure it's in running_kernels and the name matches.
+ running_kernel = running_kernels.get(kernel_id)
+ self.assertEqual(kernel_id, running_kernel.get('id'))
+ self.assertEqual(model.get('name'), kernel_name)
+
+ # restore env
+ os.environ.pop('KERNEL_KSPEC_NAME')
+ return kernel_id
+
+ def interrupt_kernel(self, kernel_id):
+ """Issues request to interrupt the given kernel
+ """
+ with mocked_gateway:
+ response = self.request('POST', '/api/kernels/' + kernel_id + '/interrupt')
+ self.assertEqual(response.status_code, 204)
+ self.assertEqual(response.reason, 'No Content')
+
+ def restart_kernel(self, kernel_id):
+ """Issues request to retart the given kernel
+ """
+ with mocked_gateway:
+ response = self.request('POST', '/api/kernels/' + kernel_id + '/restart')
+ self.assertEqual(response.status_code, 200)
+ model = json.loads(response.content.decode('utf-8'), encoding='utf-8')
+ restarted_kernel_id = model.get('id')
+            # ensure it's in running_kernels and the name matches.
+ running_kernel = running_kernels.get(restarted_kernel_id)
+ self.assertEqual(restarted_kernel_id, running_kernel.get('id'))
+ self.assertEqual(model.get('name'), running_kernel.get('name'))
+
+ def delete_kernel(self, kernel_id):
+ """Deletes kernel corresponding to the given kernel id.
+ """
+ with mocked_gateway:
+            # Delete the kernel
+ response = self.request('DELETE', '/api/kernels/' + kernel_id)
+ self.assertEqual(response.status_code, 204)
+ self.assertEqual(response.reason, 'No Content')
+