From 6a38964677b9886dc898b4b63dfb578f96776f67 Mon Sep 17 00:00:00 2001 From: Sam Richards Date: Tue, 19 Dec 2023 18:11:46 +0000 Subject: [PATCH 1/4] Update dolt.py --- doltcli/dolt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doltcli/dolt.py b/doltcli/dolt.py index 150874d..a3a6a7a 100644 --- a/doltcli/dolt.py +++ b/doltcli/dolt.py @@ -267,6 +267,8 @@ class Dolt(DoltT): """ def __init__(self, repo_dir: str, print_output: Optional[bool] = None): + # allow ~ to be used in paths + repo_dir = os.path.expanduser(repo_dir) self.repo_dir = repo_dir self._print_output = print_output or False From 952035eb5211e0497a7c18399dcbc5835c81b1be Mon Sep 17 00:00:00 2001 From: Sam Richards Date: Fri, 5 Jan 2024 10:11:46 -0500 Subject: [PATCH 2/4] Update schema_export to match cli the --filename flag appears to have been removed from the cli --- doltcli/dolt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doltcli/dolt.py b/doltcli/dolt.py index a3a6a7a..1cfacf7 100644 --- a/doltcli/dolt.py +++ b/doltcli/dolt.py @@ -1368,7 +1368,7 @@ def schema_export(self, table: str, filename: Optional[str] = None): args = ["schema", "export", table] if filename: - args.extend(["--filename", filename]) + args.extend([filename]) _execute(args, self.repo_dir) return True else: From 9ef8c252ccb6166ab54483286d4a2780be6e78d0 Mon Sep 17 00:00:00 2001 From: Sam Richards Date: Fri, 5 Jan 2024 12:51:12 -0500 Subject: [PATCH 3/4] Add test --- tests/test_dolt.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_dolt.py b/tests/test_dolt.py index e363660..6f35705 100644 --- a/tests/test_dolt.py +++ b/tests/test_dolt.py @@ -111,6 +111,16 @@ def test_init(tmp_path): Dolt.init(repo_path) assert os.path.exists(repo_data_dir) shutil.rmtree(repo_data_dir) + +def test_home_path(): + path = "~/.dolt_test" + os.mkdir(path) + + repo_path, repo_data_dir = get_repo_path_tmp_path("path") + assert not os.path.exists(repo_data_dir) + Dolt.init(repo_path) + assert os.path.exists(repo_data_dir) + shutil.rmtree(repo_data_dir) def test_bad_repo_path(tmp_path): From d760f767db68c45fc7345477f4ec385959693e80 Mon Sep 17 00:00:00 2001 From: Sam Richards Date: Fri, 5 Jan 2024 13:33:23 -0500 Subject: [PATCH 4/4] Fix the test --- tests/test_dolt.py | 58 +++++++++++++++++----------------------------- 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/tests/test_dolt.py b/tests/test_dolt.py index 6f35705..6a3c628 100644 --- a/tests/test_dolt.py +++ b/tests/test_dolt.py @@ -111,16 +111,18 @@ def test_init(tmp_path): Dolt.init(repo_path) assert os.path.exists(repo_data_dir) shutil.rmtree(repo_data_dir) - + + def test_home_path(): path = "~/.dolt_test" - os.mkdir(path) - - repo_path, repo_data_dir = get_repo_path_tmp_path("path") - assert not os.path.exists(repo_data_dir) - Dolt.init(repo_path) - assert os.path.exists(repo_data_dir) - shutil.rmtree(repo_data_dir) + if os.path.exists(os.path.expanduser(path)): + shutil.rmtree(os.path.expanduser(path)) + os.mkdir(os.path.expanduser(path)) + # Create empty file + open(os.path.expanduser(path + "/.dolt"), "a").close() + Dolt(path) + assert os.path.exists(path) + shutil.rmtree(path) def test_bad_repo_path(tmp_path): @@ -215,10 +217,10 @@ def test_merge_conflict(create_test_table: Tuple[Dolt, str]): with pytest.raises(DoltException): repo.merge("other", message_merge) - #commits = list(repo.log().values()) - #head_of_main = commits[0] + # commits = list(repo.log().values()) + # head_of_main = commits[0] - #assert head_of_main.message == message_two + # assert head_of_main.message == message_two def test_dolt_log(create_test_table: Tuple[Dolt, str]): @@ -410,10 +412,7 @@ def test_branch(create_test_table: Tuple[Dolt, str]): repo.checkout("dosac", checkout_branch=True) repo.checkout("main") next_active_branch, next_branches = repo.branch() - assert ( - set(branch.name for branch in next_branches) == {"main", "dosac"} - and next_active_branch.name == "main" - ) + assert set(branch.name for branch in next_branches) == {"main", "dosac"} and next_active_branch.name == "main" repo.checkout("dosac") different_active_branch, _ = repo.branch() @@ -562,17 +561,13 @@ def test_sql(create_test_table: Tuple[Dolt, str]): def test_sql_json(create_test_table: Tuple[Dolt, str]): repo, test_table = create_test_table - result = repo.sql( - query="SELECT * FROM `{table}`".format(table=test_table), result_format="json" - )["rows"] + result = repo.sql(query="SELECT * FROM `{table}`".format(table=test_table), result_format="json")["rows"] _verify_against_base_rows(result) def test_sql_csv(create_test_table: Tuple[Dolt, str]): repo, test_table = create_test_table - result = repo.sql( - query="SELECT * FROM `{table}`".format(table=test_table), result_format="csv" - ) + result = repo.sql(query="SELECT * FROM `{table}`".format(table=test_table), result_format="csv") _verify_against_base_rows(result) @@ -614,10 +609,7 @@ def test_config_global(init_empty_test_repo: Dolt): Dolt.config_global(add=True, name="user.name", value=test_username) Dolt.config_global(add=True, name="user.email", value=test_email) updated_config = Dolt.config_global(list=True) - assert ( - updated_config["user.name"] == test_username - and updated_config["user.email"] == test_email - ) + assert updated_config["user.name"] == test_username and updated_config["user.email"] == test_email Dolt.config_global(add=True, name="user.name", value=current_global_config["user.name"]) Dolt.config_global(add=True, name="user.email", value=current_global_config["user.email"]) reset_config = Dolt.config_global(list=True) @@ -633,9 +625,7 @@ def test_config_local(init_empty_test_repo: Dolt): repo.config_local(add=True, name="user.email", value=test_email) local_config = repo.config_local(list=True) global_config = Dolt.config_global(list=True) - assert ( - local_config["user.name"] == test_username and local_config["user.email"] == test_email - ) + assert local_config["user.name"] == test_username and local_config["user.email"] == test_email assert global_config["user.name"] == current_global_config["user.name"] assert global_config["user.email"] == current_global_config["user.email"] @@ -687,18 +677,14 @@ def test_clone_new_dir(tmp_path): def test_dolt_sql_csv(init_empty_test_repo: Dolt): dolt = init_empty_test_repo write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True) - result = dolt.sql( - "SELECT `name` as name, `id` as id FROM test_table ORDER BY id", result_format="csv" - ) + result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ORDER BY id", result_format="csv") compare_rows_helper(BASE_TEST_ROWS, result) def test_dolt_sql_json(init_empty_test_repo: Dolt): dolt = init_empty_test_repo write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True) - result = dolt.sql( - "SELECT `name` as name, `id` as id FROM test_table ", result_format="json" - ) + result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ", result_format="json") # JSON return value preserves some type information, we cast back to a string for row in result["rows"]: row["id"] = str(row["id"]) @@ -710,9 +696,7 @@ def test_dolt_sql_file(init_empty_test_repo: Dolt): with tempfile.NamedTemporaryFile() as f: write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True) - result = dolt.sql( - "SELECT `name` as name, `id` as id FROM test_table ", result_file=f.name - ) + result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ", result_file=f.name) res = read_csv_to_dict(f.name) compare_rows_helper(BASE_TEST_ROWS, res)