diff --git a/doltpy/cli/write.py b/doltpy/cli/write.py index 57d0617..1163e25 100644 --- a/doltpy/cli/write.py +++ b/doltpy/cli/write.py @@ -43,8 +43,10 @@ def write_pandas( """ def writer(filepath: str): - clean = df.dropna(subset=primary_key) - clean.to_csv(filepath, index=False) + filtered = df + if import_mode != "update": + filtered = df.dropna(subset=primary_key) + filtered.to_csv(filepath, index=False) return filepath _import_helper( diff --git a/doltpy/sql/sql.py b/doltpy/sql/sql.py index 40cd879..9028cf3 100644 --- a/doltpy/sql/sql.py +++ b/doltpy/sql/sql.py @@ -350,8 +350,8 @@ def get_query(table: str) -> str: def tables(self) -> List[str]: with self.engine.connect() as conn: - result = conn.execute("SHOW TABLES") - return [row["Table"] for row in result] + result = conn.execute("select table_name from information_schema.tables where table_schema = DATABASE();") + return [row["table_name"] for row in result] class DoltSQLEngineContext(DoltSQLContext): diff --git a/tests/cli/test_write.py b/tests/cli/test_write.py index 8d869f8..a74e95d 100644 --- a/tests/cli/test_write.py +++ b/tests/cli/test_write.py @@ -1,6 +1,7 @@ -from doltpy.cli.write import write_pandas, CREATE +from doltpy.cli.write import write_pandas, CREATE, UPDATE from doltpy.cli.read import read_rows from .helpers import compare_rows +import numpy as np import pandas as pd @@ -19,4 +20,16 @@ def test_write_pandas(init_empty_test_repo): compare_rows(expected, actual, "id") +def test_write_pandas_accept_nulls_on_update(init_empty_test_repo): + NULL_ROWS = [ + {"name": "Anna", "adjective": "tragic", "id": "1", "date_of_death": "1877-01-01"}, + {"name": "Vronksy", "adjective": "honorable", "id": "2", "date_of_death": ""}, + {"name": "Oblonsky", "adjective": "buffoon", "id": "3", "date_of_death": np.NaN}, + ] + dolt = init_empty_test_repo + write_pandas(dolt, "characters", pd.DataFrame(NULL_ROWS[:2]), CREATE, ["id"]) + write_pandas(dolt, "characters", pd.DataFrame(NULL_ROWS[2:]), UPDATE) + actual = read_rows(dolt, "characters") + expected = pd.DataFrame(NULL_ROWS).to_dict("records") + compare_rows(expected, actual, "id")