diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index bc441a5567..dbf4faad6f 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -1,6 +1,7 @@ """A base class session manager.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +import pathlib import uuid try: @@ -13,6 +14,9 @@ from traitlets.config.configurable import LoggingConfigurable from traitlets import Instance +from traitlets import Unicode +from traitlets import validate +from traitlets import TraitError from jupyter_server.utils import ensure_async from jupyter_server.traittypes import InstanceFromClasses @@ -20,6 +24,36 @@ class SessionManager(LoggingConfigurable): + database_filepath = Unicode( + default_value=":memory:", + help=( + "Th filesystem path to SQLite Database file " + "(e.g. /path/to/session_database.db). By default, the session " + "database is stored in-memory (i.e. `:memory:` setting from sqlite3) " + "and does not persist when the current Jupyter Server shuts down." + ), + ).tag(config=True) + + @validate("database_filepath") + def _validate_database_filepath(self, proposal): + value = proposal["value"] + if value == ":memory:": + return value + path = pathlib.Path(value) + if path.exists(): + # Verify that the database path is not a directory. + if path.is_dir(): + raise TraitError( + "`database_filepath` expected a file path, but the given path is a directory." + ) + # Verify that database path is an SQLite 3 Database by checking its header. + with open(value, "rb") as f: + header = f.read(100) + + if not header.startswith(b"SQLite format 3") and not header == b"": + raise TraitError("The given file is not an SQLite database file.") + return value + kernel_manager = Instance("jupyter_server.services.kernels.kernelmanager.MappingKernelManager") contents_manager = InstanceFromClasses( [ @@ -39,7 +73,7 @@ def cursor(self): if self._cursor is None: self._cursor = self.connection.cursor() self._cursor.execute( - """CREATE TABLE session + """CREATE TABLE IF NOT EXISTS session (session_id, path, name, type, kernel_id)""" ) return self._cursor @@ -48,7 +82,8 @@ def cursor(self): def connection(self): """Start a database connection""" if self._connection is None: - self._connection = sqlite3.connect(":memory:") + # Set isolation level to None to autocommit all changes to the database. + self._connection = sqlite3.connect(self.database_filepath, isolation_level=None) self._connection.row_factory = sqlite3.Row return self._connection diff --git a/jupyter_server/tests/services/sessions/test_manager.py b/jupyter_server/tests/services/sessions/test_manager.py index 97af3175c4..3ca8da8df1 100644 --- a/jupyter_server/tests/services/sessions/test_manager.py +++ b/jupyter_server/tests/services/sessions/test_manager.py @@ -1,5 +1,6 @@ import pytest from tornado import web +from traitlets import TraitError from jupyter_server._tz import isoformat from jupyter_server._tz import utcnow @@ -264,3 +265,101 @@ async def test_bad_delete_session(session_manager): await session_manager.delete_session(bad_kwarg="23424") # Bad keyword with pytest.raises(web.HTTPError): await session_manager.delete_session(session_id="23424") # nonexistent + + +async def test_bad_database_filepath(jp_runtime_dir): + kernel_manager = DummyMKM() + + # Try to write to a path that's a directory, not a file. + path_id_directory = str(jp_runtime_dir) + # Should raise an error because the path is a directory. + with pytest.raises(TraitError) as err: + SessionManager( + kernel_manager=kernel_manager, + contents_manager=ContentsManager(), + database_filepath=str(path_id_directory), + ) + + # Try writing to file that's not a valid SQLite 3 database file. + non_db_file = jp_runtime_dir.joinpath("non_db_file.db") + non_db_file.write_bytes(b"this is a bad file") + + # Should raise an error because the file doesn't + # start with an SQLite database file header. + with pytest.raises(TraitError) as err: + SessionManager( + kernel_manager=kernel_manager, + contents_manager=ContentsManager(), + database_filepath=str(non_db_file), + ) + + +async def test_good_database_filepath(jp_runtime_dir): + kernel_manager = DummyMKM() + + # Try writing to an empty file. + empty_file = jp_runtime_dir.joinpath("empty.db") + empty_file.write_bytes(b"") + + session_manager = SessionManager( + kernel_manager=kernel_manager, + contents_manager=ContentsManager(), + database_filepath=str(empty_file), + ) + + await session_manager.create_session( + path="/path/to/test.ipynb", kernel_name="python", type="notebook" + ) + # Assert that the database file exists + assert empty_file.exists() + + # Close the current session manager + del session_manager + + # Try writing to a file that already exists. + session_manager = SessionManager( + kernel_manager=kernel_manager, + contents_manager=ContentsManager(), + database_filepath=str(empty_file), + ) + + assert session_manager.database_filepath == str(empty_file) + + +async def test_session_persistence(jp_runtime_dir): + session_db_path = jp_runtime_dir.joinpath("test-session.db") + # Kernel manager needs to persist. + kernel_manager = DummyMKM() + + # Initialize a session and start a connection. + # This should create the session database the first time. + session_manager = SessionManager( + kernel_manager=kernel_manager, + contents_manager=ContentsManager(), + database_filepath=str(session_db_path), + ) + + session = await session_manager.create_session( + path="/path/to/test.ipynb", kernel_name="python", type="notebook" + ) + + # Assert that the database file exists + assert session_db_path.exists() + + with open(session_db_path, "rb") as f: + header = f.read(100) + + assert header.startswith(b"SQLite format 3") + + # Close the current session manager + del session_manager + + # Get a new session_manager + session_manager = SessionManager( + kernel_manager=kernel_manager, + contents_manager=ContentsManager(), + database_filepath=str(session_db_path), + ) + + # Assert that the session database persists. + session = await session_manager.get_session(session_id=session["id"])