Skip to content

Commit

Permalink
Fix engine context and add test (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-hoffman authored Apr 21, 2021
1 parent 2c0e198 commit d34516f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
1 change: 1 addition & 0 deletions doltpy/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def tables(self) -> List[str]:
class DoltSQLEngineContext(DoltSQLContext):
def __init__(self, dolt: Dolt, server_config: ServerConfig):
self.dolt = dolt
self.database = dolt.repo_name
self.server_config = server_config
self.engine = self._get_engine()
self.verify_connection()
Expand Down
35 changes: 34 additions & 1 deletion tests/sql/test_sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import psutil
import os
import shutil
from subprocess import Popen
import tempfile
import time

import pytest
from doltpy.cli import Dolt
from doltpy.sql import DoltSQLServerContext
from doltpy.sql import DoltSQLServerContext, DoltSQLEngineContext, ServerConfig
from .helpers import TEST_SERVER_CONFIG, TEST_DATA_INITIAL

TEST_TABLE_ONE, TEST_TABLE_TWO = 'foo', 'bar'
Expand Down Expand Up @@ -63,3 +69,30 @@ def test_show_tables(with_test_tables):
with DoltSQLServerContext(dolt, TEST_SERVER_CONFIG) as dssc:
tables = dssc.tables()
assert TEST_TABLE_ONE in tables and TEST_TABLE_TWO in tables


@pytest.fixture(scope="function")
def sql_server():
p = None
d = tempfile.TemporaryDirectory()
try:
db_path = os.path.join(d.name, "tracks")
db = Dolt.init(db_path)
db.sql("create table tracks (TrackId bigint, Name text)")
db.sql("insert into tracks values (0, 'Sue'), (1, 'L'), (2, 'M'), (3, 'Ji'), (4, 'Po')")
db.sql("select dolt_commit('-am', 'Init tracks')")
p = Popen(args=["dolt", "sql-server", "-l", "trace", "--port", "3307"], cwd=db_path)
time.sleep(.5)
yield db
finally:
if p is not None:
p.kill()
if os.path.exists(d.name):
shutil.rmtree(d.name)

def test_show_tables_engine(sql_server):
dolt = sql_server
conf = ServerConfig(user="root", host="localhost", port="3307")
conn = DoltSQLEngineContext(dolt, conf)
tables = conn.tables()
assert "tracks" in tables

0 comments on commit d34516f

Please sign in to comment.