Add "overwrite" mode to save_table in StatementExecutionBackend and MockBackend.

StatementExecutionBackend: after creating the table, issue TRUNCATE TABLE when
mode == "overwrite", then INSERT the batches. The TRUNCATE runs before the
empty-rows early return so that overwriting with zero rows still clears the table.
MockBackend: resets self._save_table when mode == "overwrite".

diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py
index 27ac3c59..91a6598d 100644
--- a/src/databricks/labs/lsql/backends.py
+++ b/src/databricks/labs/lsql/backends.py
@@ -146,15 +146,15 @@ def fetch(self, sql: str) -> Iterator[Row]:
         return self._sql.fetch_all(sql)
 
     def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"):
-        if mode == "overwrite":
-            msg = "Overwrite mode is not yet supported"
-            raise NotImplementedError(msg)
         rows = self._filter_none_rows(rows, klass)
         self.create_table(full_name, klass)
+        # Truncate before the empty-rows early return, so overwriting with no rows still clears the table.
+        if mode == "overwrite":
+            self.execute(f"TRUNCATE TABLE {full_name}")
         if len(rows) == 0:
             return
         fields = dataclasses.fields(klass)
         field_names = [f.name for f in fields]
         for i in range(0, len(rows), self._max_records_per_batch):
             batch = rows[i : i + self._max_records_per_batch]
             vals = "), (".join(self._row_to_sql(r, fields) for r in batch)
@@ -283,10 +283,9 @@ def fetch(self, sql) -> Iterator[Row]:
         return iter(rows)
 
     def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"):
-        if mode == "overwrite":
-            msg = "Overwrite mode is not yet supported"
-            raise NotImplementedError(msg)
         rows = self._filter_none_rows(rows, klass)
+        if mode == "overwrite":
+            self._save_table = []
         if klass.__class__ == type:
             row_factory = self._row_factory(klass)
             rows = [row_factory(*dataclasses.astuple(r)) for r in rows]
diff --git a/tests/integration/test_deployment.py b/tests/integration/test_deployment.py
index 55c83eef..b7fcb615 100644
--- a/tests/integration/test_deployment.py
+++ b/tests/integration/test_deployment.py
@@ -22,3 +22,14 @@ def test_deploys_database(ws, env_or_skip, make_random):
     rows = list(sql_backend.fetch(f"SELECT * FROM {schema}.some"))
 
     assert rows == [Row(name="abc", id=1)]
+
+
+def test_overwrite(ws, env_or_skip, make_random):
+    schema = "default"
+    sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))
+
+    sql_backend.save_table(f"{schema}.foo", [views.Foo("abc", True)], views.Foo, "append")
+    sql_backend.save_table(f"{schema}.foo", [views.Foo("xyz", True)], views.Foo, "overwrite")
+    rows = list(sql_backend.fetch(f"SELECT * FROM {schema}.some"))
+
+    assert rows == [Row(name="xyz", id=1)]
diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py
index 0c7bd726..54ece561 100644
--- a/tests/unit/test_backends.py
+++ b/tests/unit/test_backends.py
@@ -93,10 +93,47 @@ def test_statement_execution_backend_fetch_happy():
     assert [Row(id=1), Row(id=2), Row(id=3)] == result
 
 
-def test_statement_execution_backend_save_table_overwrite(mocker):
-    seb = StatementExecutionBackend(mocker.Mock(), "abc")
-    with pytest.raises(NotImplementedError):
-        seb.save_table("a.b.c", [1, 2, 3], Bar, mode="overwrite")
+def test_statement_execution_backend_save_table_overwrite_empty_table():
+    ws = create_autospec(WorkspaceClient)
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+    seb = StatementExecutionBackend(ws, "abc")
+    seb.save_table("a.b.c", [Baz("1")], Baz, mode="overwrite")
+    ws.statement_execution.execute_statement.assert_has_calls(
+        [
+            mock.call(
+                warehouse_id="abc",
+                statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second STRING) USING DELTA",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            mock.call(
+                warehouse_id="abc",
+                statement="TRUNCATE TABLE a.b.c",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            mock.call(
+                warehouse_id="abc",
+                statement="INSERT INTO a.b.c (first, second) VALUES ('1', NULL)",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+        ]
+    )
 
 
 def test_statement_execution_backend_save_table_empty_records():
@@ -357,3 +394,16 @@ def test_mock_backend_rows_dsl():
         Row(foo=1, bar=2),
         Row(foo=3, bar=4),
     ]
+
+
+def test_mock_backend_overwrite():
+    mock_backend = MockBackend()
+    mock_backend.save_table("a.b.c", [Foo("a1", True), Foo("c2", False)], Foo, "append")
+    mock_backend.save_table("a.b.c", [Foo("aa", True), Foo("bb", False)], Foo, "overwrite")
+    mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, "overwrite")
+
+    assert mock_backend.rows_written_for("a.b.c", "append") == []
+    assert mock_backend.rows_written_for("a.b.c", "overwrite") == [
+        Row(first="aaa", second=True),
+        Row(first="bbb", second=False),
+    ]