diff --git a/src/qcodes/dataset/sqlite/connection.py b/src/qcodes/dataset/sqlite/connection.py index b919e26dada..5dd20bba494 100644 --- a/src/qcodes/dataset/sqlite/connection.py +++ b/src/qcodes/dataset/sqlite/connection.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +import sqlite3 from contextlib import contextmanager from typing import TYPE_CHECKING, Any @@ -15,7 +16,6 @@ from qcodes.utils import DelayedKeyboardInterrupt if TYPE_CHECKING: - import sqlite3 from collections.abc import Iterator log = logging.getLogger(__name__) @@ -57,6 +57,36 @@ def __init__(self, sqlite3_connection: sqlite3.Connection): self.path_to_dbfile = path_to_dbfile(sqlite3_connection) +class ConnectionPlusPlus(sqlite3.Connection): + """ + A class to extend the sqlite3.Connection object. Since sqlite3.Connection + has no __dict__, we can not directly add attributes to its instance + directly. + + It is not allowed to instantiate a new `ConnectionPlus` object from a + `ConnectionPlus` object. + + It is recommended to create a ConnectionPlus using the function :func:`connect` + + """ + + atomic_in_progress: bool = False + """ + a bool describing whether the connection is + currently in the middle of an atomic block of transactions, thus + allowing to nest `atomic` context managers + """ + path_to_dbfile: str = "" + """ + Path to the database file of the connection. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.path_to_dbfile = path_to_dbfile(self) + + def make_connection_plus_from( conn: sqlite3.Connection | ConnectionPlus, ) -> ConnectionPlus: @@ -97,7 +127,7 @@ def atomic(conn: ConnectionPlus) -> Iterator[ConnectionPlus]: """ with DelayedKeyboardInterrupt(context={"reason": "sqlite atomic operation"}): - if not isinstance(conn, ConnectionPlus): + if not isinstance(conn, ConnectionPlus | ConnectionPlusPlus): raise ValueError( "atomic context manager only accepts " "ConnectionPlus database connection objects." diff --git a/src/qcodes/dataset/sqlite/database.py b/src/qcodes/dataset/sqlite/database.py index a9644ea383e..4573f062a92 100644 --- a/src/qcodes/dataset/sqlite/database.py +++ b/src/qcodes/dataset/sqlite/database.py @@ -18,7 +18,7 @@ import qcodes from qcodes.dataset.experiment_settings import reset_default_experiment_id -from qcodes.dataset.sqlite.connection import ConnectionPlus +from qcodes.dataset.sqlite.connection import ConnectionPlus, ConnectionPlusPlus from qcodes.dataset.sqlite.db_upgrades import ( _latest_available_version, perform_db_upgrade, @@ -119,7 +119,9 @@ def _adapt_complex(value: complex | np.complexfloating) -> sqlite3.Binary: return sqlite3.Binary(out.read()) -def connect(name: str | Path, debug: bool = False, version: int = -1) -> ConnectionPlus: +def connect( + name: str | Path, debug: bool = False, version: int = -1 +) -> ConnectionPlusPlus: """ Connect or create database. If debug the queries will be echoed back. This function takes care of registering the numpy/sqlite type @@ -141,10 +143,12 @@ def connect(name: str | Path, debug: bool = False, version: int = -1) -> Connect # register binary(TEXT) -> numpy converter sqlite3.register_converter("array", _convert_array) - sqlite3_conn = sqlite3.connect( - name, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=True + conn = sqlite3.connect( + name, + detect_types=sqlite3.PARSE_DECLTYPES, + check_same_thread=True, + factory=ConnectionPlusPlus, ) - conn = ConnectionPlus(sqlite3_conn) latest_supported_version = _latest_available_version() db_version = get_user_version(conn) diff --git a/tests/dataset/test_database_creation_and_upgrading.py b/tests/dataset/test_database_creation_and_upgrading.py index e35bf8e7fde..8a036373f2c 100644 --- a/tests/dataset/test_database_creation_and_upgrading.py +++ b/tests/dataset/test_database_creation_and_upgrading.py @@ -11,7 +11,6 @@ import qcodes.dataset.descriptions.versioning.serialization as serial import tests.dataset from qcodes.dataset import ( - ConnectionPlus, connect, initialise_database, initialise_or_create_database_at, @@ -26,7 +25,7 @@ from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.descriptions.versioning.v0 import InterDependencies from qcodes.dataset.guids import parse_guid -from qcodes.dataset.sqlite.connection import atomic_transaction +from qcodes.dataset.sqlite.connection import ConnectionPlusPlus, atomic_transaction from qcodes.dataset.sqlite.database import get_db_version_and_newest_available_version from qcodes.dataset.sqlite.db_upgrades import ( _latest_available_version, @@ -703,7 +702,7 @@ def test_perform_actual_upgrade_6_to_7() -> None: skip_if_no_fixtures(dbname_old) with temporarily_copied_DB(dbname_old, debug=False, version=6) as conn: - assert isinstance(conn, ConnectionPlus) + assert isinstance(conn, ConnectionPlusPlus) perform_db_upgrade_6_to_7(conn) assert get_user_version(conn) == 7 @@ -762,7 +761,7 @@ def test_perform_actual_upgrade_6_to_newest_add_new_data() -> None: skip_if_no_fixtures(dbname_old) with temporarily_copied_DB(dbname_old, debug=False, version=6) as conn: - assert isinstance(conn, ConnectionPlus) + assert isinstance(conn, ConnectionPlusPlus) perform_db_upgrade(conn) assert get_user_version(conn) >= 7 no_of_runs_query = "SELECT max(run_id) FROM runs"