Minor improvements for .save_table(mode="overwrite") (#298)
This PR includes several updates for `.save_table()` in overwrite mode:

- If no rows are supplied, the table is now truncated rather than left
untouched. The mock backend already assumed this behaviour but diverged
from the concrete implementations, which treated an empty overwrite as a
no-op (as when appending). Unit tests now cover this situation for all
backends, and the existing integration test for the SQL-statement
backend has been updated to cover it.
- The SQL-based backends gain a slight optimisation: instead of
truncating first and then inserting, the truncation is now performed as
part of the insert for the first batch (`INSERT OVERWRITE`), with
subsequent batches appending; a sketch of this pattern follows below.
- Type hints on the abstract method now match the concrete
implementations.
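
For illustration, a minimal sketch of that first-batch optimisation. The helper name `batched_insert_statements` and the pre-rendered `row_literals` parameter are hypothetical, not part of the library; the real logic lives in `save_table` in the diff below:

from collections.abc import Iterator, Sequence


def batched_insert_statements(
    full_name: str,
    column_names: Sequence[str],
    row_literals: Sequence[str],  # SQL literals for one row each, e.g. "'aaa', TRUE"
    *,
    overwrite: bool,
    max_records_per_batch: int,
) -> Iterator[str]:
    # INSERT OVERWRITE replaces the table contents, so only the first batch
    # may carry it; every later batch must append with INSERT INTO.
    insert_modifier = "OVERWRITE" if overwrite else "INTO"
    columns = ", ".join(column_names)
    for i in range(0, len(row_literals), max_records_per_batch):
        batch = row_literals[i : i + max_records_per_batch]
        vals = "), (".join(batch)
        yield f"INSERT {insert_modifier} {full_name} ({columns}) VALUES ({vals})"
        insert_modifier = "INTO"


for statement in batched_insert_statements(
    "a.b.c",
    ["first", "second"],
    ["'aaa', TRUE", "'bbb', FALSE", "'ccc', TRUE"],
    overwrite=True,
    max_records_per_batch=2,
):
    print(statement)
# INSERT OVERWRITE a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)
# INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)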
asnare authored Oct 2, 2024
1 parent dbfc823 commit 128a0fa
Showing 4 changed files with 244 additions and 35 deletions.
22 changes: 15 additions & 7 deletions src/databricks/labs/lsql/backends.py
@@ -135,20 +135,29 @@ def execute(self, sql: str, *, catalog: str | None = None, schema: str | None =
     def fetch(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> Iterator[Any]:
         raise NotImplementedError

-    def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"):
+    def save_table(
+        self,
+        full_name: str,
+        rows: Sequence[DataclassInstance],
+        klass: Dataclass,
+        mode: Literal["append", "overwrite"] = "append",
+    ):
         rows = self._filter_none_rows(rows, klass)
         self.create_table(full_name, klass)
-        if len(rows) == 0:
+        if not rows:
+            if mode == "overwrite":
+                self.execute(f"TRUNCATE TABLE {full_name}")
             return
         fields = dataclasses.fields(klass)
         field_names = [f.name for f in fields]
-        if mode == "overwrite":
-            self.execute(f"TRUNCATE TABLE {full_name}")
+        insert_modifier = "OVERWRITE" if mode == "overwrite" else "INTO"
         for i in range(0, len(rows), self._max_records_per_batch):
             batch = rows[i : i + self._max_records_per_batch]
             vals = "), (".join(self._row_to_sql(r, fields) for r in batch)
-            sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})'
+            sql = f'INSERT {insert_modifier} {full_name} ({", ".join(field_names)}) VALUES ({vals})'
             self.execute(sql)
+            # Only the first batch can truncate; subsequent batches append.
+            insert_modifier = "INTO"

     @classmethod
     def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]):
@@ -277,8 +286,7 @@ def save_table(
         mode: Literal["append", "overwrite"] = "append",
     ) -> None:
         rows = self._filter_none_rows(rows, klass)
-
-        if len(rows) == 0:
+        if not rows and mode == "append":
             self.create_table(full_name, klass)
             return
         # pyspark deals well with lists of dataclass instances, as long as schema is provided
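On the Spark path (second hunk above), an empty overwrite now falls through to the DataFrame write instead of returning early. A hedged sketch of the underlying PySpark idiom — assuming an active SparkSession named `spark`, with the same schema as the `Foo` test fixture below:

# Writing zero rows with mode="overwrite" replaces (effectively truncates) the table.
df = spark.createDataFrame([], "first STRING NOT NULL, second BOOLEAN NOT NULL")
df.write.saveAsTable("a.b.c", mode="overwrite")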
5 changes: 5 additions & 0 deletions tests/integration/test_backends.py
@@ -162,6 +162,11 @@ def test_statement_execution_backend_overwrites_table(ws, env_or_skip, make_rand
     rows = list(sql_backend.fetch(f"SELECT * FROM {catalog}.{schema}.foo"))
     assert rows == [Row(first="xyz", second=True)]

+    sql_backend.save_table(f"{catalog}.{schema}.foo", [], views.Foo, "overwrite")
+
+    rows = list(sql_backend.fetch(f"SELECT * FROM {catalog}.{schema}.foo"))
+    assert rows == []
+

 def test_runtime_backend_use_statements(ws):
     product_info = ProductInfo.for_testing(SqlBackend)
165 changes: 147 additions & 18 deletions tests/unit/test_backends.py
@@ -2,6 +2,7 @@
 import os
 import sys
 from dataclasses import dataclass
+from typing import Literal
 from unittest import mock
 from unittest.mock import MagicMock, call, create_autospec

@@ -137,17 +138,7 @@ def test_statement_execution_backend_save_table_overwrite_empty_table():
             ),
             mock.call(
                 warehouse_id="abc",
-                statement="TRUNCATE TABLE a.b.c",
-                catalog=None,
-                schema=None,
-                disposition=None,
-                format=Format.JSON_ARRAY,
-                byte_limit=None,
-                wait_timeout=None,
-            ),
-            mock.call(
-                warehouse_id="abc",
-                statement="INSERT INTO a.b.c (first, second) VALUES ('1', NULL)",
+                statement="INSERT OVERWRITE a.b.c (first, second) VALUES ('1', NULL)",
                 catalog=None,
                 schema=None,
                 disposition=None,
@@ -170,7 +161,7 @@ def test_statement_execution_backend_save_table_empty_records():

     seb.save_table("a.b.c", [], Bar)

-    ws.statement_execution.execute_statement.assert_called_with(
+    ws.statement_execution.execute_statement.assert_called_once_with(
         warehouse_id="abc",
         statement="CREATE TABLE IF NOT EXISTS a.b.c "
         "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA",
@@ -183,6 +174,44 @@
     )


+def test_statement_execution_backend_save_table_overwrite_empty_records() -> None:
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = StatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc")
+
+    seb.save_table("a.b.c", [], Bar, mode="overwrite")
+
+    ws.statement_execution.execute_statement.assert_has_calls(
+        [
+            call(
+                warehouse_id="abc",
+                statement="CREATE TABLE IF NOT EXISTS a.b.c "
+                "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            call(
+                warehouse_id="abc",
+                statement="TRUNCATE TABLE a.b.c",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+        ]
+    )
+
+
 def test_statement_execution_backend_save_table_two_records():
     ws = create_autospec(WorkspaceClient)

@@ -220,7 +249,7 @@ def test_statement_execution_backend_save_table_two_records():
     )


-def test_statement_execution_backend_save_table_in_batches_of_two():
+def test_statement_execution_backend_save_table_append_in_batches_of_two() -> None:
     ws = create_autospec(WorkspaceClient)

     ws.statement_execution.execute_statement.return_value = StatementResponse(
@@ -229,7 +258,7 @@ def test_statement_execution_backend_save_table_in_batches_of_two():
     )

     seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
-    seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo)
+    seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="append")

     ws.statement_execution.execute_statement.assert_has_calls(
         [
@@ -267,6 +296,53 @@ def test_statement_execution_backend_save_table_in_batches_of_two():
     )


+def test_statement_execution_backend_save_table_overwrite_in_batches_of_two() -> None:
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = StatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
+
+    seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo, mode="overwrite")
+
+    ws.statement_execution.execute_statement.assert_has_calls(
+        [
+            mock.call(
+                warehouse_id="abc",
+                statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            mock.call(
+                warehouse_id="abc",
+                statement="INSERT OVERWRITE a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            mock.call(
+                warehouse_id="abc",
+                statement="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+        ]
+    )
+
+
 def test_runtime_backend_execute():
     with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
         pyspark_sql_session = MagicMock()
@@ -298,21 +374,53 @@ def test_runtime_backend_fetch():
         spark.sql.assert_has_calls(calls)


-def test_runtime_backend_save_table():
+@pytest.mark.parametrize("mode", ["append", "overwrite"])
+def test_runtime_backend_save_table(mode: Literal["append", "overwrite"]) -> None:
     with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
         pyspark_sql_session = MagicMock()
         sys.modules["pyspark.sql.session"] = pyspark_sql_session
         spark = pyspark_sql_session.SparkSession.builder.getOrCreate()

         runtime_backend = RuntimeBackend()

-        runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
+        runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, mode=mode)

-        spark.createDataFrame.assert_called_with(
+        spark.createDataFrame.assert_called_once_with(
             [Foo(first="aaa", second=True), Foo(first="bbb", second=False)],
             "first STRING NOT NULL, second BOOLEAN NOT NULL",
         )
-        spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append")
+        spark.createDataFrame().write.saveAsTable.assert_called_once_with("a.b.c", mode=mode)
+
+
+def test_runtime_backend_save_table_append_empty_records() -> None:
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+        spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
+
+        runtime_backend = RuntimeBackend()
+
+        runtime_backend.save_table("a.b.c", [], Foo, mode="append")
+
+        spark.createDataFrame.assert_not_called()
+        spark.createDataFrame().write.saveAsTable.assert_not_called()
+        spark.sql.assert_called_once_with(
+            "CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA"
+        )
+
+
+def test_runtime_backend_save_table_overwrite_empty_records() -> None:
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+        spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
+
+        runtime_backend = RuntimeBackend()
+
+        runtime_backend.save_table("a.b.c", [], Foo, mode="overwrite")
+
+        spark.createDataFrame.assert_called_once_with([], "first STRING NOT NULL, second BOOLEAN NOT NULL")
+        spark.createDataFrame().write.saveAsTable.assert_called_once_with("a.b.c", mode="overwrite")
+
+
 def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class(mocker):
@@ -427,6 +535,27 @@ def test_mock_backend_save_table_overwrite() -> None:
     ]


+def test_mock_backend_save_table_no_rows() -> None:
+    mock_backend = MockBackend()
+
+    mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
+    mock_backend.save_table("a.b.c", [], Foo)
+
+    assert mock_backend.rows_written_for("a.b.c", mode="append") == [
+        Row(first="aaa", second=True),
+        Row(first="bbb", second=False),
+    ]
+
+
+def test_mock_backend_save_table_overwrite_no_rows() -> None:
+    mock_backend = MockBackend()
+
+    mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo, mode="overwrite")
+    mock_backend.save_table("a.b.c", [], Foo, mode="overwrite")
+
+    assert mock_backend.rows_written_for("a.b.c", mode="overwrite") == []
+
+
 def test_mock_backend_rows_dsl():
     rows = MockBackend.rows("foo", "bar")[
         [1, 2],
(The diff for the fourth changed file is not shown here.)