diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py index 1c5147c48f0..2ac5e06d75d 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_executor.py +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -46,32 +46,25 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement): :type parsed_statement: ParsedStatement :param parsed_statement: parsed_statement based on the sql query """ - if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT: + if connection.is_closed: + raise ProgrammingError(CONNECTION_CLOSED_ERROR) + statement_type = parsed_statement.client_side_statement_type + if statement_type == ClientSideStatementType.COMMIT: connection.commit() return None - if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN: + if statement_type == ClientSideStatementType.BEGIN: connection.begin() return None - if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK: + if statement_type == ClientSideStatementType.ROLLBACK: connection.rollback() return None - if ( - parsed_statement.client_side_statement_type - == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP - ): - if connection.is_closed: - raise ProgrammingError(CONNECTION_CLOSED_ERROR) + if statement_type == ClientSideStatementType.SHOW_COMMIT_TIMESTAMP: return _get_streamed_result_set( ClientSideStatementType.SHOW_COMMIT_TIMESTAMP.name, TypeCode.TIMESTAMP, connection._transaction.committed, ) - if ( - parsed_statement.client_side_statement_type - == ClientSideStatementType.SHOW_READ_TIMESTAMP - ): - if connection.is_closed: - raise ProgrammingError(CONNECTION_CLOSED_ERROR) + if statement_type == ClientSideStatementType.SHOW_READ_TIMESTAMP: return _get_streamed_result_set( ClientSideStatementType.SHOW_READ_TIMESTAMP.name, TypeCode.TIMESTAMP, @@ -85,5 +78,6 @@ def _get_streamed_result_set(column_name, type_code, column_value): ) result_set = PartialResultSet(metadata=ResultSetMetadata(row_type=struct_type_pb)) - result_set.values.extend([_make_value_pb(column_value)]) + if column_value is not None: + result_set.values.extend([_make_value_pb(column_value)]) return StreamedResultSet(iter([result_set])) diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py index e68806d247f..b0828f87888 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_parser.py +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -24,10 +24,10 @@ RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE) RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE) RE_SHOW_COMMIT_TIMESTAMP = re.compile( - r"^\s*(SHOW VARIABLE COMMIT_TIMESTAMP)", re.IGNORECASE + r"^\s*(SHOW)\s*(VARIABLE)\s*(COMMIT_TIMESTAMP)", re.IGNORECASE ) RE_SHOW_READ_TIMESTAMP = re.compile( - r"^\s*(SHOW VARIABLE READ_TIMESTAMP)", re.IGNORECASE + r"^\s*(SHOW)\s*(VARIABLE)\s*(READ_TIMESTAMP)", re.IGNORECASE ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index cbcdf8f17c6..a77de993174 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -23,6 +23,7 @@ from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot +from deprecated import deprecated from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -143,9 +144,10 @@ def database(self): return self._database @property + @deprecated( + reason="This method is deprecated. Use spanner_transaction_started method" + ) def inside_transaction(self): - """Deprecated property which won't be supported in future versions. - Please use spanner_transaction_started property instead.""" return ( self._transaction and not self._transaction.committed @@ -310,7 +312,6 @@ def _rerun_previous_statements(self): status, res = transaction.batch_update(statements) if status.code == ABORTED: - self._spanner_transaction_started = False raise Aborted(status.details) retried_checksum = ResultsChecksum() @@ -398,7 +399,7 @@ def begin(self): :raises: :class:`InterfaceError`: if this connection is closed. :raises: :class:`OperationalError`: if there is an existing transaction - that has begin or is running + that has been started """ if self._transaction_begin_marked: raise OperationalError("A transaction has already started") diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index f04aae4b1d0..8189af419e0 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -487,6 +487,10 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params): # Unfortunately, Spanner doesn't seem to send back # information about the number of rows available. self._row_count = _UNSET_COUNT + if self._result_set.metadata.transaction.read_timestamp is not None: + snapshot._transaction_read_timestamp = ( + self._result_set.metadata.transaction.read_timestamp + ) def _handle_DQL(self, sql, params): if self.connection.database is None: diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index baae769d785..1e515bd8e69 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -15,14 +15,9 @@ """Model a set of read-only queries to a database as a snapshot.""" import functools -import itertools import threading from google.protobuf.struct_pb2 import Struct -from google.cloud.spanner_v1 import ( - ExecuteSqlRequest, - PartialResultSet, - ResultSetMetadata, -) +from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import ReadRequest from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1 import TransactionSelector @@ -452,17 +447,11 @@ def execute_sql( if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - return self._get_streamed_result_set( - restart, request, trace_attributes, False - ) + return self._get_streamed_result_set(restart, request, trace_attributes) else: - return self._get_streamed_result_set( - restart, request, trace_attributes, True - ) + return self._get_streamed_result_set(restart, request, trace_attributes) - def _get_streamed_result_set( - self, restart, request, trace_attributes, transaction_id_set - ): + def _get_streamed_result_set(self, restart, request, trace_attributes): iterator = _restart_on_unavailable( restart, request, @@ -474,16 +463,6 @@ def _get_streamed_result_set( self._read_request_count += 1 self._execute_sql_count += 1 - if self._read_only and not transaction_id_set: - peek = next(iterator) - response_pb = PartialResultSet.pb(peek) - response_metadata = ResultSetMetadata.wrap(response_pb.metadata) - if response_metadata.transaction.read_timestamp is not None: - self._transaction_read_timestamp = ( - response_metadata.transaction.read_timestamp - ) - iterator = itertools.chain([peek], iterator) - if self._multi_use: return StreamedResultSet(iterator, source=self) else: diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 9caacc6bf12..78c3a89f053 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -153,13 +153,25 @@ def test_begin_client_side(self, shared_instance, dbapi_database): conn3.close() assert got_rows == [updated_row] - def test_commit_timestamp_client_side(self): + def test_commit_timestamp_client_side_transaction(self): """Test executing SHOW_COMMIT_TIMESTAMP client side statement in a transaction.""" self._cursor.execute( """ INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._cursor.execute("SHOW VARIABLE COMMIT_TIMESTAMP") + got_rows = self._cursor.fetchall() + # As the connection is not committed we will get 0 rows + assert len(got_rows) == 0 + assert len(self._cursor.description) == 1 + + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') """ ) @@ -198,18 +210,33 @@ def test_read_timestamp_client_side(self): transaction.""" self._conn.read_only = True - self._cursor.execute("SELECT * FROM contacts") + assert self._cursor.fetchall() == [] + + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_1 = self._cursor.fetchall() + self._cursor.execute("SELECT * FROM contacts") - self._conn.commit() + assert self._cursor.fetchall() == [] + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_2 = self._cursor.fetchall() - got_rows = self._cursor.fetchall() - assert len(got_rows) == 1 - assert len(got_rows[0]) == 1 + self._conn.commit() + + self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") + read_timestamp_query_result_3 = self._cursor.fetchall() assert len(self._cursor.description) == 1 assert self._cursor.description[0].name == "SHOW_READ_TIMESTAMP" - assert isinstance(got_rows[0][0], DatetimeWithNanoseconds) + + assert ( + read_timestamp_query_result_1 + == read_timestamp_query_result_2 + == read_timestamp_query_result_3 + ) + assert len(read_timestamp_query_result_1) == 1 + assert len(read_timestamp_query_result_1[0]) == 1 + assert isinstance(read_timestamp_query_result_1[0][0], DatetimeWithNanoseconds) def test_read_timestamp_client_side_autocommit(self): """Test executing SHOW_READ_TIMESTAMP client side statement in a @@ -225,6 +252,9 @@ def test_read_timestamp_client_side_autocommit(self): ) self._conn.read_only = True self._cursor.execute("SELECT * FROM contacts") + assert self._cursor.fetchall() == [ + (2, "first-name", "last-name", "test.email@domen.ru") + ] self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") got_rows = self._cursor.fetchall() @@ -725,8 +755,10 @@ def test_read_only(self): ReadOnly transactions. """ + self._conn.read_only = True self._cursor.execute("SELECT * FROM contacts") self._conn.commit() + assert self._cursor.fetchall() == [] def test_read_only_dml(self): """ diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 06819c3a3d6..b03c871ca7d 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -52,11 +52,13 @@ def test_classify_stmt(self): ), ("CREATE ROLE parent", StatementType.DDL), ("commit", StatementType.CLIENT_SIDE), - (" commit TRANSACTION ", StatementType.CLIENT_SIDE), + (" commit TRANSACTION ", StatementType.CLIENT_SIDE), ("begin", StatementType.CLIENT_SIDE), ("start", StatementType.CLIENT_SIDE), ("begin transaction", StatementType.CLIENT_SIDE), ("start transaction", StatementType.CLIENT_SIDE), + (" SHOW VARIABLE COMMIT_TIMESTAMP ", StatementType.CLIENT_SIDE), + ("SHOW VARIABLE READ_TIMESTAMP", StatementType.CLIENT_SIDE), ("rollback", StatementType.CLIENT_SIDE), (" rollback TRANSACTION ", StatementType.CLIENT_SIDE), ("GRANT SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL),