Skip to content

Commit

Permalink
Incorporated comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Dec 5, 2023
1 parent 8b63b9c commit f5b704b
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 55 deletions.
26 changes: 10 additions & 16 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,25 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
if statement_type == ClientSideStatementType.COMMIT:
connection.commit()
return None
if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN:
if statement_type == ClientSideStatementType.BEGIN:
connection.begin()
return None
if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK:
if statement_type == ClientSideStatementType.ROLLBACK:
connection.rollback()
return None
if (
parsed_statement.client_side_statement_type
== ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
):
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP:
return _get_streamed_result_set(
ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name,
TypeCode.TIMESTAMP,
connection._transaction.committed,
)
if (
parsed_statement.client_side_statement_type
== ClientSideStatementType.SHOW_READ_TIMESTAMP
):
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP:
return _get_streamed_result_set(
ClientSideStatementType.SHOW_READ_TIMESTAMP.name,
TypeCode.TIMESTAMP,
Expand All @@ -85,5 +78,6 @@ def _get_streamed_result_set(column_name, type_code, column_value):
)

result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb))
result_set.values.extend([_make_value_pb(column_value)])
if column_value is not None:
result_set.values.extend([_make_value_pb(column_value)])
return StreamedResultSet(iter([result_set]))
4 changes: 2 additions & 2 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE)
RE_SHOW_COMMIT_TIMESTAMP = re.compile(
r"^\s*(SHOW VARIABLE COMMIT_TIMESTAMP)", re.IGNORECASE
r"^\s*(SHOW)\s*(VARIABLE)\s*(COMMIT_TIMESTAMP)", re.IGNORECASE
)
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW VARIABLE READ_TIMESTAMP)", re.IGNORECASE
r"^\s*(SHOW)\s*(VARIABLE)\s*(READ_TIMESTAMP)", re.IGNORECASE
)


Expand Down
9 changes: 5 additions & 4 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
from deprecated import deprecated

from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
Expand Down Expand Up @@ -143,9 +144,10 @@ def database(self):
return self._database

@property
@deprecated(
reason="This method is deprecated. Use spanner_transaction_started method"
)
def inside_transaction(self):
"""Deprecated property which won't be supported in future versions.
Please use spanner_transaction_started property instead."""
return (
self._transaction
and not self._transaction.committed
Expand Down Expand Up @@ -310,7 +312,6 @@ def _rerun_previous_statements(self):
status, res = transaction.batch_update(statements)

if status.code == ABORTED:
self._spanner_transaction_started = False
raise Aborted(status.details)

retried_checksum = ResultsChecksum()
Expand Down Expand Up @@ -398,7 +399,7 @@ def begin(self):
:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if there is an existing transaction
that has begin or is running
that has been started
"""
if self._transaction_begin_marked:
raise OperationalError("A transaction has already started")
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,10 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
# Unfortunately, Spanner doesn't seem to send back
# information about the number of rows available.
self._row_count = _UNSET_COUNT
if self._result_set.metadata.transaction.read_timestamp is not None:
snapshot._transaction_read_timestamp = (
self._result_set.metadata.transaction.read_timestamp
)

def _handle_DQL(self, sql, params):
if self.connection.database is None:
Expand Down
29 changes: 4 additions & 25 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,9 @@
"""Model a set of read-only queries to a database as a snapshot."""

import functools
import itertools
import threading
from google.protobuf.struct_pb2 import Struct
from google.cloud.spanner_v1 import (
ExecuteSqlRequest,
PartialResultSet,
ResultSetMetadata,
)
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import ReadRequest
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1 import TransactionSelector
Expand Down Expand Up @@ -452,17 +447,11 @@ def execute_sql(
if self._transaction_id is None:
# lock is added to handle the inline begin for first rpc
with self._lock:
return self._get_streamed_result_set(
restart, request, trace_attributes, False
)
return self._get_streamed_result_set(restart, request, trace_attributes)
else:
return self._get_streamed_result_set(
restart, request, trace_attributes, True
)
return self._get_streamed_result_set(restart, request, trace_attributes)

def _get_streamed_result_set(
self, restart, request, trace_attributes, transaction_id_set
):
def _get_streamed_result_set(self, restart, request, trace_attributes):
iterator = _restart_on_unavailable(
restart,
request,
Expand All @@ -474,16 +463,6 @@ def _get_streamed_result_set(
self._read_request_count += 1
self._execute_sql_count += 1

if self._read_only and not transaction_id_set:
peek = next(iterator)
response_pb = PartialResultSet.pb(peek)
response_metadata = ResultSetMetadata.wrap(response_pb.metadata)
if response_metadata.transaction.read_timestamp is not None:
self._transaction_read_timestamp = (
response_metadata.transaction.read_timestamp
)
iterator = itertools.chain([peek], iterator)

if self._multi_use:
return StreamedResultSet(iterator, source=self)
else:
Expand Down
46 changes: 39 additions & 7 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,25 @@ def test_begin_client_side(self, shared_instance, dbapi_database):
conn3.close()
assert got_rows == [updated_row]

def test_commit_timestamp_client_side(self):
def test_commit_timestamp_client_side_transaction(self):
"""Test executing SHOW_COMMIT_TIMESTAMP client side statement in a
transaction."""

self._cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru')
"""
)
self._cursor.execute("SHOW VARIABLE COMMIT_TIMESTAMP")
got_rows = self._cursor.fetchall()
# As the connection is not committed we will get 0 rows
assert len(got_rows) == 0
assert len(self._cursor.description) == 1

self._cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru')
"""
)
Expand Down Expand Up @@ -198,18 +210,33 @@ def test_read_timestamp_client_side(self):
transaction."""

self._conn.read_only = True

self._cursor.execute("SELECT * FROM contacts")
assert self._cursor.fetchall() == []

self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP")
read_timestamp_query_result_1 = self._cursor.fetchall()

self._cursor.execute("SELECT * FROM contacts")
self._conn.commit()
assert self._cursor.fetchall() == []

self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP")
read_timestamp_query_result_2 = self._cursor.fetchall()

got_rows = self._cursor.fetchall()
assert len(got_rows) == 1
assert len(got_rows[0]) == 1
self._conn.commit()

self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP")
read_timestamp_query_result_3 = self._cursor.fetchall()
assert len(self._cursor.description) == 1
assert self._cursor.description[0].name == "SHOW_READ_TIMESTAMP"
assert isinstance(got_rows[0][0], DatetimeWithNanoseconds)

assert (
read_timestamp_query_result_1
== read_timestamp_query_result_2
== read_timestamp_query_result_3
)
assert len(read_timestamp_query_result_1) == 1
assert len(read_timestamp_query_result_1[0]) == 1
assert isinstance(read_timestamp_query_result_1[0][0], DatetimeWithNanoseconds)

def test_read_timestamp_client_side_autocommit(self):
"""Test executing SHOW_READ_TIMESTAMP client side statement in a
Expand All @@ -225,6 +252,9 @@ def test_read_timestamp_client_side_autocommit(self):
)
self._conn.read_only = True
self._cursor.execute("SELECT * FROM contacts")
assert self._cursor.fetchall() == [
(2, "first-name", "last-name", "test.email@domen.ru")
]
self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP")

got_rows = self._cursor.fetchall()
Expand Down Expand Up @@ -725,8 +755,10 @@ def test_read_only(self):
ReadOnly transactions.
"""

self._conn.read_only = True
self._cursor.execute("SELECT * FROM contacts")
self._conn.commit()
assert self._cursor.fetchall() == []

def test_read_only_dml(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@ def test_classify_stmt(self):
),
("CREATE ROLE parent", StatementType.DDL),
("commit", StatementType.CLIENT_SIDE),
(" commit TRANSACTION ", StatementType.CLIENT_SIDE),
(" commit TRANSACTION ", StatementType.CLIENT_SIDE),
("begin", StatementType.CLIENT_SIDE),
("start", StatementType.CLIENT_SIDE),
("begin transaction", StatementType.CLIENT_SIDE),
("start transaction", StatementType.CLIENT_SIDE),
(" SHOW VARIABLE COMMIT_TIMESTAMP ", StatementType.CLIENT_SIDE),
("SHOW VARIABLE READ_TIMESTAMP", StatementType.CLIENT_SIDE),
("rollback", StatementType.CLIENT_SIDE),
(" rollback TRANSACTION ", StatementType.CLIENT_SIDE),
("GRANT SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL),
Expand Down

0 comments on commit f5b704b

Please sign in to comment.