Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for save_table(..., mode="overwrite") to StatementExecutionBackend #74

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/databricks/labs/lsql/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,14 @@ def fetch(self, sql: str) -> Iterator[Row]:
return self._sql.fetch_all(sql)

def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"):
if mode == "overwrite":
msg = "Overwrite mode is not yet supported"
raise NotImplementedError(msg)
rows = self._filter_none_rows(rows, klass)
self.create_table(full_name, klass)
if len(rows) == 0:
return
fields = dataclasses.fields(klass)
field_names = [f.name for f in fields]
if mode == "overwrite":
self.execute(f"TRUNCATE TABLE {full_name}")
for i in range(0, len(rows), self._max_records_per_batch):
Comment on lines +155 to 157
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd need to properly support overwrites in the raw SQL backend to keep roughly the same semantics as Spark:

  1. INSERT INTO {full_name}_tmp ...
  2. CREATE OR REPLACE TABLE {full_name} AS SELECT * FROM {full_name}_tmp
  3. DROP TABLE {full_name}_tmp

Otherwise, a failure partway through the overwrite (after the TRUNCATE but before every INSERT has succeeded) will leave the table in a corrupt state.

batch = rows[i : i + self._max_records_per_batch]
vals = "), (".join(self._row_to_sql(r, fields) for r in batch)
Expand Down Expand Up @@ -283,10 +282,9 @@ def fetch(self, sql) -> Iterator[Row]:
return iter(rows)

def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"):
if mode == "overwrite":
msg = "Overwrite mode is not yet supported"
raise NotImplementedError(msg)
rows = self._filter_none_rows(rows, klass)
if mode == "overwrite":
self._save_table = []
if klass.__class__ == type:
row_factory = self._row_factory(klass)
rows = [row_factory(*dataclasses.astuple(r)) for r in rows]
Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@ def test_deploys_database(ws, env_or_skip, make_random):
rows = list(sql_backend.fetch(f"SELECT * FROM {schema}.some"))

assert rows == [Row(name="abc", id=1)]


def test_overwrite(ws, env_or_skip, make_random):
    """Check that ``mode="overwrite"`` replaces rows written by an earlier append.

    Appends one ``views.Foo`` record to ``default.foo``, overwrites the table
    with a different record, and verifies that only the second record is left.
    Needs a live warehouse: the id is read via ``env_or_skip``.
    """
    schema = "default"
    table = f"{schema}.foo"
    sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))

    sql_backend.save_table(table, [views.Foo("abc", True)], views.Foo, "append")
    sql_backend.save_table(table, [views.Foo("xyz", True)], views.Foo, "overwrite")
    # Read back the table that was actually written above; querying any other
    # table would make the assertion independent of the overwrite under test.
    rows = list(sql_backend.fetch(f"SELECT * FROM {table}"))

    # NOTE(review): assumes views.Foo declares its fields as `first` and
    # `second`, in that order — confirm against the views module.
    assert rows == [Row(first="xyz", second=True)]
58 changes: 54 additions & 4 deletions tests/unit/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,47 @@ def test_statement_execution_backend_fetch_happy():
assert [Row(id=1), Row(id=2), Row(id=3)] == result


def test_statement_execution_backend_save_table_overwrite(mocker):
    """Expect ``save_table`` to reject ``mode="overwrite"`` with NotImplementedError."""
    # NOTE(review): this expectation matches a backend that does not support
    # overwrite yet; presumably it is the pre-change side of the diff — confirm.
    workspace_client = mocker.Mock()
    backend = StatementExecutionBackend(workspace_client, "abc")
    with pytest.raises(NotImplementedError):
        backend.save_table("a.b.c", [1, 2, 3], Bar, mode="overwrite")
def test_statement_execution_backend_save_table_overwrite_empty_table():
    """Verify the SQL statements issued by ``save_table`` in overwrite mode.

    With a workspace client whose ``execute_statement`` always reports
    SUCCEEDED, saving one ``Baz`` row with ``mode="overwrite"`` must issue a
    CREATE TABLE IF NOT EXISTS, then a TRUNCATE TABLE, then the INSERT.
    """
    workspace = create_autospec(WorkspaceClient)
    workspace.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
        status=StatementStatus(state=StatementState.SUCCEEDED)
    )
    backend = StatementExecutionBackend(workspace, "abc")

    backend.save_table("a.b.c", [Baz("1")], Baz, mode="overwrite")

    # Statements in the order they are expected to reach the warehouse.
    expected_statements = (
        "CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second STRING) USING DELTA",
        "TRUNCATE TABLE a.b.c",
        "INSERT INTO a.b.c (first, second) VALUES ('1', NULL)",
    )
    # Every call shares the same keyword arguments apart from the statement text.
    expected_calls = [
        mock.call(
            warehouse_id="abc",
            statement=statement,
            catalog=None,
            schema=None,
            disposition=None,
            format=Format.JSON_ARRAY,
            byte_limit=None,
            wait_timeout=None,
        )
        for statement in expected_statements
    ]
    workspace.statement_execution.execute_statement.assert_has_calls(expected_calls)


def test_statement_execution_backend_save_table_empty_records():
Expand Down Expand Up @@ -357,3 +394,16 @@ def test_mock_backend_rows_dsl():
Row(foo=1, bar=2),
Row(foo=3, bar=4),
]


def test_mock_backend_overwrite():
    """Verify that overwriting in ``MockBackend`` discards previously saved rows.

    After one append and two overwrites of ``a.b.c``, nothing is reported for
    the "append" mode and only the rows of the last overwrite are reported
    for the "overwrite" mode.
    """
    backend = MockBackend()
    # (mode, records) pairs, applied in order to the same table.
    writes = (
        ("append", [Foo("a1", True), Foo("c2", False)]),
        ("overwrite", [Foo("aa", True), Foo("bb", False)]),
        ("overwrite", [Foo("aaa", True), Foo("bbb", False)]),
    )
    for mode, records in writes:
        backend.save_table("a.b.c", records, Foo, mode)

    appended = backend.rows_written_for("a.b.c", "append")
    overwritten = backend.rows_written_for("a.b.c", "overwrite")

    assert appended == []
    assert overwritten == [
        Row(first="aaa", second=True),
        Row(first="bbb", second=False),
    ]
Loading