From 128a0fa1315a27719d38e3efe5efcff84f1787df Mon Sep 17 00:00:00 2001 From: Andrew Snare Date: Wed, 2 Oct 2024 19:58:10 +0200 Subject: [PATCH] Minor improvements for `.save_table(mode="overwrite")` (#298) This PR includes some updates for `.save_table()` in overwrite mode: - If no rows are supplied, the table is now truncated instead of this being a no-op. The mock backend already assumed this, but diverged from the concrete implementations, which treated it as a no-op (as when appending). Unit tests now cover this situation for all backends, and the existing integration test for the SQL-statement backend has been updated to cover this. - The SQL-based backends have a slight optimisation: instead of truncating first and then inserting, the truncation is now performed as part of the insert for the first batch. - Type hints on the abstract method now match the concrete implementations. --- src/databricks/labs/lsql/backends.py | 22 ++- tests/integration/test_backends.py | 5 + tests/unit/test_backends.py | 165 +++++++++++++++++-- tests/unit/test_command_execution_backend.py | 87 ++++++++-- 4 files changed, 244 insertions(+), 35 deletions(-) diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py index 52378224..00a9409d 100644 --- a/src/databricks/labs/lsql/backends.py +++ b/src/databricks/labs/lsql/backends.py @@ -135,20 +135,29 @@ def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Any]: raise NotImplementedError - def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"): + def save_table( + self, + full_name: str, + rows: Sequence[DataclassInstance], + klass: Dataclass, + mode: Literal["append", "overwrite"] = "append", + ): rows = self._filter_none_rows(rows, klass) self.create_table(full_name, klass) - if len(rows) == 0: + if not rows: + if mode == "overwrite": + 
self.execute(f"TRUNCATE TABLE {full_name}") return fields = dataclasses.fields(klass) field_names = [f.name for f in fields] - if mode == "overwrite": - self.execute(f"TRUNCATE TABLE {full_name}") + insert_modifier = "OVERWRITE" if mode == "overwrite" else "INTO" for i in range(0, len(rows), self._max_records_per_batch): batch = rows[i : i + self._max_records_per_batch] vals = "), (".join(self._row_to_sql(r, fields) for r in batch) - sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})' + sql = f'INSERT {insert_modifier} {full_name} ({", ".join(field_names)}) VALUES ({vals})' self.execute(sql) + # Only the first batch can truncate; subsequent batches append. + insert_modifier = "INTO" @classmethod def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]): @@ -277,8 +286,7 @@ def save_table( mode: Literal["append", "overwrite"] = "append", ) -> None: rows = self._filter_none_rows(rows, klass) - - if len(rows) == 0: + if not rows and mode == "append": self.create_table(full_name, klass) return # pyspark deals well with lists of dataclass instances, as long as schema is provided diff --git a/tests/integration/test_backends.py b/tests/integration/test_backends.py index 6930aec2..bedd72fe 100644 --- a/tests/integration/test_backends.py +++ b/tests/integration/test_backends.py @@ -162,6 +162,11 @@ def test_statement_execution_backend_overwrites_table(ws, env_or_skip, make_rand rows = list(sql_backend.fetch(f"SELECT * FROM {catalog}.{schema}.foo")) assert rows == [Row(first="xyz", second=True)] + sql_backend.save_table(f"{catalog}.{schema}.foo", [], views.Foo, "overwrite") + + rows = list(sql_backend.fetch(f"SELECT * FROM {catalog}.{schema}.foo")) + assert rows == [] + def test_runtime_backend_use_statements(ws): product_info = ProductInfo.for_testing(SqlBackend) diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py index 71db5207..a1ce0391 100644 --- a/tests/unit/test_backends.py +++ 
b/tests/unit/test_backends.py @@ -2,6 +2,7 @@ import os import sys from dataclasses import dataclass +from typing import Literal from unittest import mock from unittest.mock import MagicMock, call, create_autospec @@ -137,17 +138,7 @@ def test_statement_execution_backend_save_table_overwrite_empty_table(): ), mock.call( warehouse_id="abc", - statement="TRUNCATE TABLE a.b.c", - catalog=None, - schema=None, - disposition=None, - format=Format.JSON_ARRAY, - byte_limit=None, - wait_timeout=None, - ), - mock.call( - warehouse_id="abc", - statement="INSERT INTO a.b.c (first, second) VALUES ('1', NULL)", + statement="INSERT OVERWRITE a.b.c (first, second) VALUES ('1', NULL)", catalog=None, schema=None, disposition=None, @@ -170,7 +161,7 @@ def test_statement_execution_backend_save_table_empty_records(): seb.save_table("a.b.c", [], Bar) - ws.statement_execution.execute_statement.assert_called_with( + ws.statement_execution.execute_statement.assert_called_once_with( warehouse_id="abc", statement="CREATE TABLE IF NOT EXISTS a.b.c " "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA", @@ -183,6 +174,44 @@ def test_statement_execution_backend_save_table_empty_records(): ) +def test_statement_execution_backend_save_table_overwrite_empty_records() -> None: + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = StatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED) + ) + + seb = StatementExecutionBackend(ws, "abc") + + seb.save_table("a.b.c", [], Bar, mode="overwrite") + + ws.statement_execution.execute_statement.assert_has_calls( + [ + call( + warehouse_id="abc", + statement="CREATE TABLE IF NOT EXISTS a.b.c " + "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + call( + warehouse_id="abc", + statement="TRUNCATE TABLE 
a.b.c", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + ] + ) + + def test_statement_execution_backend_save_table_two_records(): ws = create_autospec(WorkspaceClient) @@ -220,7 +249,7 @@ def test_statement_execution_backend_save_table_two_records(): ) -def test_statement_execution_backend_save_table_in_batches_of_two(): +def test_statement_execution_backend_save_table_append_in_batches_of_two() -> None: ws = create_autospec(WorkspaceClient) ws.statement_execution.execute_statement.return_value = StatementResponse( @@ -229,7 +258,7 @@ def test_statement_execution_backend_save_table_in_batches_of_two(): seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2) - seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo) + seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="append") ws.statement_execution.execute_statement.assert_has_calls( [ @@ -267,6 +296,53 @@ def test_statement_execution_backend_save_table_in_batches_of_two(): ) +def test_statement_execution_backend_save_table_overwrite_in_batches_of_two() -> None: + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = StatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED) + ) + + seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2) + + seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="overwrite") + + ws.statement_execution.execute_statement.assert_has_calls( + [ + mock.call( + warehouse_id="abc", + statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT OVERWRITE a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', 
FALSE)", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + ] + ) + + def test_runtime_backend_execute(): with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): pyspark_sql_session = MagicMock() @@ -298,7 +374,8 @@ def test_runtime_backend_fetch(): spark.sql.assert_has_calls(calls) -def test_runtime_backend_save_table(): +@pytest.mark.parametrize("mode", ["append", "overwrite"]) +def test_runtime_backend_save_table(mode: Literal["append", "overwrite"]) -> None: with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): pyspark_sql_session = MagicMock() sys.modules["pyspark.sql.session"] = pyspark_sql_session @@ -306,13 +383,44 @@ def test_runtime_backend_save_table(): runtime_backend = RuntimeBackend() - runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) + runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, mode=mode) - spark.createDataFrame.assert_called_with( + spark.createDataFrame.assert_called_once_with( [Foo(first="aaa", second=True), Foo(first="bbb", second=False)], "first STRING NOT NULL, second BOOLEAN NOT NULL", ) - spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append") + spark.createDataFrame().write.saveAsTable.assert_called_once_with("a.b.c", mode=mode) + + +def test_runtime_backend_save_table_append_empty_records() -> None: + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + runtime_backend = RuntimeBackend() + + runtime_backend.save_table("a.b.c", 
[], Foo, mode="append") + + spark.createDataFrame.assert_not_called() + spark.createDataFrame().write.saveAsTable.assert_not_called() + spark.sql.assert_called_once_with( + "CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA" + ) + + +def test_runtime_backend_save_table_overwrite_empty_records() -> None: + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + runtime_backend = RuntimeBackend() + + runtime_backend.save_table("a.b.c", [], Foo, mode="overwrite") + + spark.createDataFrame.assert_called_once_with([], "first STRING NOT NULL, second BOOLEAN NOT NULL") + spark.createDataFrame().write.saveAsTable.assert_called_once_with("a.b.c", mode="overwrite") def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class(mocker): @@ -427,6 +535,27 @@ def test_mock_backend_save_table_overwrite() -> None: ] +def test_mock_backend_save_table_no_rows() -> None: + mock_backend = MockBackend() + + mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) + mock_backend.save_table("a.b.c", [], Foo) + + assert mock_backend.rows_written_for("a.b.c", mode="append") == [ + Row(first="aaa", second=True), + Row(first="bbb", second=False), + ] + + +def test_mock_backend_save_table_overwrite_no_rows() -> None: + mock_backend = MockBackend() + + mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) + mock_backend.save_table("a.b.c", [], Foo) + + assert mock_backend.rows_written_for("a.b.c", mode="overwrite") == [] + + def test_mock_backend_rows_dsl(): rows = MockBackend.rows("foo", "bar")[ [1, 2], diff --git a/tests/unit/test_command_execution_backend.py b/tests/unit/test_command_execution_backend.py index cf7cc148..1feabe07 100644 --- a/tests/unit/test_command_execution_backend.py +++ 
b/tests/unit/test_command_execution_backend.py @@ -128,13 +128,7 @@ def test_command_context_backend_save_table_overwrite_empty_table(): cluster_id="abc", language=Language.SQL, context_id="abc", - command="TRUNCATE TABLE a.b.c", - ), - mock.call( - cluster_id="abc", - language=Language.SQL, - context_id="abc", - command="INSERT INTO a.b.c (first, second) VALUES ('1', NULL)", + command="INSERT OVERWRITE a.b.c (first, second) VALUES ('1', NULL)", ), ] ) @@ -155,7 +149,7 @@ def test_command_context_backend_save_table_empty_records(): ceb.save_table("a.b.c", [], Bar) - ws.command_execution.execute.assert_called_with( + ws.command_execution.execute.assert_called_once_with( cluster_id="abc", language=Language.SQL, context_id="abc", @@ -164,6 +158,40 @@ def test_command_context_backend_save_table_empty_records(): ) +def test_command_context_backend_save_table_overwrite_empty_records() -> None: + ws = create_autospec(WorkspaceClient) + ws.command_execution.create.return_value = Wait[ContextStatusResponse]( + waiter=lambda callback, timeout: ContextStatusResponse(id="abc") + ) + ws.command_execution.execute.return_value = Wait[CommandStatusResponse]( + waiter=lambda callback, timeout: CommandStatusResponse( + results=Results(data="success"), status=CommandStatus.FINISHED + ) + ) + + ceb = CommandExecutionBackend(ws, "abc") + + ceb.save_table("a.b.c", [], Bar, mode="overwrite") + + ws.command_execution.execute.assert_has_calls( + [ + call( + cluster_id="abc", + language=Language.SQL, + context_id="abc", + command="CREATE TABLE IF NOT EXISTS a.b.c " + "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA", + ), + call( + cluster_id="abc", + language=Language.SQL, + context_id="abc", + command="TRUNCATE TABLE a.b.c", + ), + ] + ) + + def test_command_context_backend_save_table_two_records(): ws = create_autospec(WorkspaceClient) ws.command_execution.create.return_value = Wait[ContextStatusResponse]( @@ -197,7 +225,7 @@ def 
test_command_context_backend_save_table_two_records(): ) -def test_command_context_backend_save_table_in_batches_of_two(mocker): +def test_command_context_backend_save_table_append_in_batches_of_two() -> None: ws = create_autospec(WorkspaceClient) ws.command_execution.create.return_value = Wait[ContextStatusResponse]( waiter=lambda callback, timeout: ContextStatusResponse(id="abc") @@ -210,7 +238,7 @@ def test_command_context_backend_save_table_in_batches_of_two(mocker): ceb = CommandExecutionBackend(ws, "abc", max_records_per_batch=2) - ceb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo) + ceb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="append") ws.command_execution.execute.assert_has_calls( [ @@ -234,3 +262,42 @@ def test_command_context_backend_save_table_in_batches_of_two(mocker): ), ] ) + + +def test_command_context_backend_save_table_overwrite_in_batches_of_two() -> None: + ws = create_autospec(WorkspaceClient) + ws.command_execution.create.return_value = Wait[ContextStatusResponse]( + waiter=lambda callback, timeout: ContextStatusResponse(id="abc") + ) + ws.command_execution.execute.return_value = Wait[CommandStatusResponse]( + waiter=lambda callback, timeout: CommandStatusResponse( + results=Results(data="success"), status=CommandStatus.FINISHED + ) + ) + + ceb = CommandExecutionBackend(ws, "abc", max_records_per_batch=2) + + ceb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="overwrite") + + ws.command_execution.execute.assert_has_calls( + [ + mock.call( + cluster_id="abc", + language=Language.SQL, + context_id="abc", + command="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", + ), + mock.call( + cluster_id="abc", + language=Language.SQL, + context_id="abc", + command="INSERT OVERWRITE a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)", + ), + mock.call( + cluster_id="abc", + 
language=Language.SQL, + context_id="abc", + command="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)", + ), + ] + )