diff --git a/.gitignore b/.gitignore index 4797754726..d083ea1ddc 100644 --- a/.gitignore +++ b/.gitignore @@ -62,7 +62,3 @@ system_tests/local_test_setup # Make sure a generated file isn't accidentally committed. pylintrc pylintrc.test - - -# Ignore coverage files -.coverage* diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index cec6c64dac..c2aa385d2a 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -113,7 +113,7 @@ def __init__(self, instance, database=None, read_only=False, **kwargs): self.request_priority = None self._transaction_begin_marked = False # whether transaction started at Spanner. This means that we had - # made atleast one call to Spanner. + # made at least one call to Spanner. self._spanner_transaction_started = False self._batch_mode = BatchMode.NONE self._batch_dml_executor: BatchDmlExecutor = None @@ -261,6 +261,28 @@ def request_options(self): self.request_priority = None return req_opts + @property + def transaction_tag(self): + """The transaction tag that will be applied to the next read/write + transaction on this `Connection`. This property is automatically cleared + when a new transaction is started. + + Returns: + str: The transaction tag that will be applied to the next read/write transaction. + """ + return self._connection_variables.get("transaction_tag", None) + + @transaction_tag.setter + def transaction_tag(self, value): + """Sets the transaction tag for the next read/write transaction on this + `Connection`. This property is automatically cleared when a new transaction + is started. + + Args: + value (str): The transaction tag for the next read/write transaction. + """ + self._connection_variables["transaction_tag"] = value + @property def staleness(self): """Current read staleness option value of this `Connection`. @@ -340,6 +362,8 @@ def transaction_checkout(self): if not self.read_only and self._client_transaction_started: if not self._spanner_transaction_started: self._transaction = self._session_checkout().transaction() + self._transaction.transaction_tag = self.transaction_tag + self.transaction_tag = None self._snapshot = None self._spanner_transaction_started = True self._transaction.begin() @@ -458,7 +482,9 @@ def run_prior_DDL_statements(self): return self.database.update_ddl(ddl_statements).result() - def run_statement(self, statement: Statement): + def run_statement( + self, statement: Statement, request_options: RequestOptions = None + ): """Run single SQL statement in begun transaction. This method is never used in autocommit mode. In @@ -472,6 +498,9 @@ def run_statement(self, statement: Statement): :param retried: (Optional) Retry the SQL statement if statement execution failed. Defaults to false. + :type request_options: :class:`RequestOptions` + :param request_options: Request options to use for this statement. + :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`, :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` :returns: Streamed result set of the statement and a @@ -482,7 +511,7 @@ def run_statement(self, statement: Statement): statement.sql, statement.params, param_types=statement.param_types, - request_options=self.request_options, + request_options=request_options or self.request_options, ) @check_not_closed diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 8b4170e3f2..a72a8e9de1 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -50,6 +50,7 @@ from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets +from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.merged_result_set import MergedResultSet ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) @@ -97,6 +98,39 @@ def __init__(self, connection): self._parsed_statement: ParsedStatement = None self._in_retry_mode = False self._batch_dml_rows_count = None + self._request_tag = None + + @property + def request_tag(self): + """The request tag that will be applied to the next statement on this + cursor. This property is automatically cleared when a statement is + executed. + + Returns: + str: The request tag that will be applied to the next statement on + this cursor. + """ + return self._request_tag + + @request_tag.setter + def request_tag(self, value): + """Sets the request tag for the next statement on this cursor. This + property is automatically cleared when a statement is executed. + + Args: + value (str): The request tag for the statement. + """ + self._request_tag = value + + @property + def request_options(self): + options = self.connection.request_options + if self._request_tag: + if not options: + options = RequestOptions() + options.request_tag = self._request_tag + self._request_tag = None + return options @property def is_closed(self): @@ -284,7 +318,7 @@ def _execute(self, sql, args=None, call_from_execute_many=False): sql, params=args, param_types=self._parsed_statement.statement.param_types, - request_options=self.connection.request_options, + request_options=self.request_options, ) self._result_set = None else: @@ -318,7 +352,9 @@ def _execute_in_rw_transaction(self): if self.connection._client_transaction_started: while True: try: - self._result_set = self.connection.run_statement(statement) + self._result_set = self.connection.run_statement( + statement, self.request_options + ) self._itr = PeekIterator(self._result_set) return except Aborted: @@ -478,7 +514,7 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params): sql, params, get_param_types(params), - request_options=self.connection.request_options, + request_options=self.request_options, ) # Read the first element so that the StreamedResultSet can # return the metadata after a DQL statement. diff --git a/tests/mockserver_tests/test_tags.py b/tests/mockserver_tests/test_tags.py new file mode 100644 index 0000000000..c84d69b7bd --- /dev/null +++ b/tests/mockserver_tests/test_tags.py @@ -0,0 +1,206 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_v1 import ( + BatchCreateSessionsRequest, + ExecuteSqlRequest, + BeginTransactionRequest, + TypeCode, + CommitRequest, +) +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_single_result, +) + + +class TestTags(MockServerTestBase): + @classmethod + def setup_class(cls): + super().setup_class() + add_single_result( + "select name from singers", "name", TypeCode.STRING, [("Some Singer",)] + ) + + def test_select_autocommit_no_tags(self): + connection = Connection(self.instance, self.database) + connection.autocommit = True + request = self._execute_and_verify_select_singers(connection) + self.assertEqual("", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_autocommit_with_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = True + request = self._execute_and_verify_select_singers( + connection, request_tag="my_tag" + ) + self.assertEqual("my_tag", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_only_transaction_no_tags(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + request = self._execute_and_verify_select_singers(connection) + self.assertEqual("", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_only_transaction_with_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + request = self._execute_and_verify_select_singers( + connection, request_tag="my_tag" + ) + self.assertEqual("my_tag", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_only_transaction_with_transaction_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + connection.transaction_tag = "my_transaction_tag" + self._execute_and_verify_select_singers(connection) + self._execute_and_verify_select_singers(connection) + + # Read-only transactions do not support tags, so the transaction_tag is + # also not cleared from the connection when a read-only transaction is + # executed. + self.assertEqual("my_transaction_tag", connection.transaction_tag) + + # Read-only transactions do not need to be committed or rolled back on + # Spanner, but dbapi requires this to end the transaction. + connection.commit() + requests = self.spanner_service.requests + self.assertEqual(4, len(requests)) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], ExecuteSqlRequest)) + # Transaction tags are not supported for read-only transactions. + self.assertEqual("", requests[2].request_options.transaction_tag) + self.assertEqual("", requests[3].request_options.transaction_tag) + + def test_select_read_write_transaction_no_tags(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + request = self._execute_and_verify_select_singers(connection) + self.assertEqual("", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_write_transaction_with_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + request = self._execute_and_verify_select_singers( + connection, request_tag="my_tag" + ) + self.assertEqual("my_tag", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_write_transaction_with_transaction_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.transaction_tag = "my_transaction_tag" + # The transaction tag should be included for all statements in the transaction. + self._execute_and_verify_select_singers(connection) + self._execute_and_verify_select_singers(connection) + + # The transaction tag was cleared from the connection when the transaction + # was started. + self.assertIsNone(connection.transaction_tag) + # The commit call should also include a transaction tag. + connection.commit() + requests = self.spanner_service.requests + self.assertEqual(5, len(requests)) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[4], CommitRequest)) + self.assertEqual( + "my_transaction_tag", requests[2].request_options.transaction_tag + ) + self.assertEqual( + "my_transaction_tag", requests[3].request_options.transaction_tag + ) + self.assertEqual( + "my_transaction_tag", requests[4].request_options.transaction_tag + ) + + def test_select_read_write_transaction_with_transaction_and_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.transaction_tag = "my_transaction_tag" + # The transaction tag should be included for all statements in the transaction. + self._execute_and_verify_select_singers(connection, request_tag="my_tag1") + self._execute_and_verify_select_singers(connection, request_tag="my_tag2") + + # The transaction tag was cleared from the connection when the transaction + # was started. + self.assertIsNone(connection.transaction_tag) + # The commit call should also include a transaction tag. + connection.commit() + requests = self.spanner_service.requests + self.assertEqual(5, len(requests)) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[4], CommitRequest)) + self.assertEqual( + "my_transaction_tag", requests[2].request_options.transaction_tag + ) + self.assertEqual("my_tag1", requests[2].request_options.request_tag) + self.assertEqual( + "my_transaction_tag", requests[3].request_options.transaction_tag + ) + self.assertEqual("my_tag2", requests[3].request_options.request_tag) + self.assertEqual( + "my_transaction_tag", requests[4].request_options.transaction_tag + ) + + def test_request_tag_is_cleared(self): + connection = Connection(self.instance, self.database) + connection.autocommit = True + with connection.cursor() as cursor: + cursor.request_tag = "my_tag" + cursor.execute("select name from singers") + # This query will not have a request tag. + cursor.execute("select name from singers") + requests = self.spanner_service.requests + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertEqual("my_tag", requests[1].request_options.request_tag) + self.assertEqual("", requests[2].request_options.request_tag) + + def _execute_and_verify_select_singers( + self, connection: Connection, request_tag: str = "", transaction_tag: str = "" + ) -> ExecuteSqlRequest: + with connection.cursor() as cursor: + if request_tag: + cursor.request_tag = request_tag + cursor.execute("select name from singers") + result_list = cursor.fetchall() + for row in result_list: + self.assertEqual("Some Singer", row[0]) + self.assertEqual(1, len(result_list)) + requests = self.spanner_service.requests + return next( + request + for request in requests + if isinstance(request, ExecuteSqlRequest) + and request.sql == "select name from singers" + )