diff --git a/doltpy/sql/sql.py b/doltpy/sql/sql.py index d4a7473..72c272d 100644 --- a/doltpy/sql/sql.py +++ b/doltpy/sql/sql.py @@ -349,6 +349,7 @@ def tables(self) -> List[str]: class DoltSQLEngineContext(DoltSQLContext): def __init__(self, dolt: Dolt, server_config: ServerConfig): self.dolt = dolt + self.database = dolt.repo_name self.server_config = server_config self.engine = self._get_engine() self.verify_connection() diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index e97bae9..6ad7deb 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -1,7 +1,13 @@ import psutil +import os +import shutil +from subprocess import Popen +import tempfile +import time + import pytest from doltpy.cli import Dolt -from doltpy.sql import DoltSQLServerContext +from doltpy.sql import DoltSQLServerContext, DoltSQLEngineContext, ServerConfig from .helpers import TEST_SERVER_CONFIG, TEST_DATA_INITIAL TEST_TABLE_ONE, TEST_TABLE_TWO = 'foo', 'bar' @@ -63,3 +69,30 @@ def test_show_tables(with_test_tables): with DoltSQLServerContext(dolt, TEST_SERVER_CONFIG) as dssc: tables = dssc.tables() assert TEST_TABLE_ONE in tables and TEST_TABLE_TWO in tables + + +@pytest.fixture(scope="function") +def sql_server(): + p = None + d = tempfile.TemporaryDirectory() + try: + db_path = os.path.join(d.name, "tracks") + db = Dolt.init(db_path) + db.sql("create table tracks (TrackId bigint, Name text)") + db.sql("insert into tracks values (0, 'Sue'), (1, 'L'), (2, 'M'), (3, 'Ji'), (4, 'Po')") + db.sql("select dolt_commit('-am', 'Init tracks')") + p = Popen(args=["dolt", "sql-server", "-l", "trace", "--port", "3307"], cwd=db_path) + time.sleep(.5) + yield db + finally: + if p is not None: + p.kill() + if os.path.exists(d.name): + shutil.rmtree(d.name) + +def test_show_tables_engine(sql_server): + dolt = sql_server + conf = ServerConfig(user="root", host="localhost", port="3307") + conn = DoltSQLEngineContext(dolt, conf) + tables = conn.tables() + assert "tracks" in tables