From ee9662f57dbb730afb08b9b9829e4e19bda5e69a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 13 Jan 2025 14:09:23 +0100 Subject: [PATCH] feat: support transaction and request tags in dbapi (#1262) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support transaction and request tags in dbapi Adds support for setting transaction tags and request tags in dbapi. This makes these options available to frameworks that depend on dbapi, like SQLAlchemy and Django. Towards https://github.com/googleapis/python-spanner-sqlalchemy/issues/525 * test: add test for transaction_tag with read-only tx * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Owl Bot --- .gitignore | 4 - google/cloud/spanner_dbapi/connection.py | 35 +++- google/cloud/spanner_dbapi/cursor.py | 42 ++++- tests/mockserver_tests/test_tags.py | 206 +++++++++++++++++++++++ 4 files changed, 277 insertions(+), 10 deletions(-) create mode 100644 tests/mockserver_tests/test_tags.py diff --git a/.gitignore b/.gitignore index 4797754726..d083ea1ddc 100644 --- a/.gitignore +++ b/.gitignore @@ -62,7 +62,3 @@ system_tests/local_test_setup # Make sure a generated file isn't accidentally committed. pylintrc pylintrc.test - - -# Ignore coverage files -.coverage* diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index cec6c64dac..c2aa385d2a 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -113,7 +113,7 @@ def __init__(self, instance, database=None, read_only=False, **kwargs): self.request_priority = None self._transaction_begin_marked = False # whether transaction started at Spanner. This means that we had - # made atleast one call to Spanner. + # made at least one call to Spanner. self._spanner_transaction_started = False self._batch_mode = BatchMode.NONE self._batch_dml_executor: BatchDmlExecutor = None @@ -261,6 +261,28 @@ def request_options(self): self.request_priority = None return req_opts + @property + def transaction_tag(self): + """The transaction tag that will be applied to the next read/write + transaction on this `Connection`. This property is automatically cleared + when a new transaction is started. + + Returns: + str: The transaction tag that will be applied to the next read/write transaction. + """ + return self._connection_variables.get("transaction_tag", None) + + @transaction_tag.setter + def transaction_tag(self, value): + """Sets the transaction tag for the next read/write transaction on this + `Connection`. This property is automatically cleared when a new transaction + is started. + + Args: + value (str): The transaction tag for the next read/write transaction. + """ + self._connection_variables["transaction_tag"] = value + @property def staleness(self): """Current read staleness option value of this `Connection`. @@ -340,6 +362,8 @@ def transaction_checkout(self): if not self.read_only and self._client_transaction_started: if not self._spanner_transaction_started: self._transaction = self._session_checkout().transaction() + self._transaction.transaction_tag = self.transaction_tag + self.transaction_tag = None self._snapshot = None self._spanner_transaction_started = True self._transaction.begin() @@ -458,7 +482,9 @@ def run_prior_DDL_statements(self): return self.database.update_ddl(ddl_statements).result() - def run_statement(self, statement: Statement): + def run_statement( + self, statement: Statement, request_options: RequestOptions = None + ): """Run single SQL statement in begun transaction. This method is never used in autocommit mode. In @@ -472,6 +498,9 @@ def run_statement(self, statement: Statement): :param retried: (Optional) Retry the SQL statement if statement execution failed. Defaults to false. + :type request_options: :class:`RequestOptions` + :param request_options: Request options to use for this statement. + :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`, :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` :returns: Streamed result set of the statement and a @@ -482,7 +511,7 @@ def run_statement(self, statement: Statement): statement.sql, statement.params, param_types=statement.param_types, - request_options=self.request_options, + request_options=request_options or self.request_options, ) @check_not_closed diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 8b4170e3f2..a72a8e9de1 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -50,6 +50,7 @@ from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets +from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.merged_result_set import MergedResultSet ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) @@ -97,6 +98,39 @@ def __init__(self, connection): self._parsed_statement: ParsedStatement = None self._in_retry_mode = False self._batch_dml_rows_count = None + self._request_tag = None + + @property + def request_tag(self): + """The request tag that will be applied to the next statement on this + cursor. This property is automatically cleared when a statement is + executed. + + Returns: + str: The request tag that will be applied to the next statement on + this cursor. + """ + return self._request_tag + + @request_tag.setter + def request_tag(self, value): + """Sets the request tag for the next statement on this cursor. This + property is automatically cleared when a statement is executed. + + Args: + value (str): The request tag for the statement. + """ + self._request_tag = value + + @property + def request_options(self): + options = self.connection.request_options + if self._request_tag: + if not options: + options = RequestOptions() + options.request_tag = self._request_tag + self._request_tag = None + return options @property def is_closed(self): @@ -284,7 +318,7 @@ def _execute(self, sql, args=None, call_from_execute_many=False): sql, params=args, param_types=self._parsed_statement.statement.param_types, - request_options=self.connection.request_options, + request_options=self.request_options, ) self._result_set = None else: @@ -318,7 +352,9 @@ def _execute_in_rw_transaction(self): if self.connection._client_transaction_started: while True: try: - self._result_set = self.connection.run_statement(statement) + self._result_set = self.connection.run_statement( + statement, self.request_options + ) self._itr = PeekIterator(self._result_set) return except Aborted: @@ -478,7 +514,7 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params): sql, params, get_param_types(params), - request_options=self.connection.request_options, + request_options=self.request_options, ) # Read the first element so that the StreamedResultSet can # return the metadata after a DQL statement. diff --git a/tests/mockserver_tests/test_tags.py b/tests/mockserver_tests/test_tags.py new file mode 100644 index 0000000000..c84d69b7bd --- /dev/null +++ b/tests/mockserver_tests/test_tags.py @@ -0,0 +1,206 @@ +# Copyright 2024 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_v1 import ( + BatchCreateSessionsRequest, + ExecuteSqlRequest, + BeginTransactionRequest, + TypeCode, + CommitRequest, +) +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_single_result, +) + + +class TestTags(MockServerTestBase): + @classmethod + def setup_class(cls): + super().setup_class() + add_single_result( + "select name from singers", "name", TypeCode.STRING, [("Some Singer",)] + ) + + def test_select_autocommit_no_tags(self): + connection = Connection(self.instance, self.database) + connection.autocommit = True + request = self._execute_and_verify_select_singers(connection) + self.assertEqual("", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_autocommit_with_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = True + request = self._execute_and_verify_select_singers( + connection, request_tag="my_tag" + ) + self.assertEqual("my_tag", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_only_transaction_no_tags(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + request = self._execute_and_verify_select_singers(connection) + self.assertEqual("", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_only_transaction_with_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + request = self._execute_and_verify_select_singers( + connection, request_tag="my_tag" + ) + self.assertEqual("my_tag", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_only_transaction_with_transaction_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + connection.transaction_tag = "my_transaction_tag" + self._execute_and_verify_select_singers(connection) + self._execute_and_verify_select_singers(connection) + + # Read-only transactions do not support tags, so the transaction_tag is + # also not cleared from the connection when a read-only transaction is + # executed. + self.assertEqual("my_transaction_tag", connection.transaction_tag) + + # Read-only transactions do not need to be committed or rolled back on + # Spanner, but dbapi requires this to end the transaction. + connection.commit() + requests = self.spanner_service.requests + self.assertEqual(4, len(requests)) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], ExecuteSqlRequest)) + # Transaction tags are not supported for read-only transactions. + self.assertEqual("", requests[2].request_options.transaction_tag) + self.assertEqual("", requests[3].request_options.transaction_tag) + + def test_select_read_write_transaction_no_tags(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + request = self._execute_and_verify_select_singers(connection) + self.assertEqual("", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_write_transaction_with_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + request = self._execute_and_verify_select_singers( + connection, request_tag="my_tag" + ) + self.assertEqual("my_tag", request.request_options.request_tag) + self.assertEqual("", request.request_options.transaction_tag) + + def test_select_read_write_transaction_with_transaction_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.transaction_tag = "my_transaction_tag" + # The transaction tag should be included for all statements in the transaction. + self._execute_and_verify_select_singers(connection) + self._execute_and_verify_select_singers(connection) + + # The transaction tag was cleared from the connection when the transaction + # was started. + self.assertIsNone(connection.transaction_tag) + # The commit call should also include a transaction tag. + connection.commit() + requests = self.spanner_service.requests + self.assertEqual(5, len(requests)) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[4], CommitRequest)) + self.assertEqual( + "my_transaction_tag", requests[2].request_options.transaction_tag + ) + self.assertEqual( + "my_transaction_tag", requests[3].request_options.transaction_tag + ) + self.assertEqual( + "my_transaction_tag", requests[4].request_options.transaction_tag + ) + + def test_select_read_write_transaction_with_transaction_and_request_tag(self): + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.transaction_tag = "my_transaction_tag" + # The transaction tag should be included for all statements in the transaction. + self._execute_and_verify_select_singers(connection, request_tag="my_tag1") + self._execute_and_verify_select_singers(connection, request_tag="my_tag2") + + # The transaction tag was cleared from the connection when the transaction + # was started. + self.assertIsNone(connection.transaction_tag) + # The commit call should also include a transaction tag. + connection.commit() + requests = self.spanner_service.requests + self.assertEqual(5, len(requests)) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[3], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[4], CommitRequest)) + self.assertEqual( + "my_transaction_tag", requests[2].request_options.transaction_tag + ) + self.assertEqual("my_tag1", requests[2].request_options.request_tag) + self.assertEqual( + "my_transaction_tag", requests[3].request_options.transaction_tag + ) + self.assertEqual("my_tag2", requests[3].request_options.request_tag) + self.assertEqual( + "my_transaction_tag", requests[4].request_options.transaction_tag + ) + + def test_request_tag_is_cleared(self): + connection = Connection(self.instance, self.database) + connection.autocommit = True + with connection.cursor() as cursor: + cursor.request_tag = "my_tag" + cursor.execute("select name from singers") + # This query will not have a request tag. + cursor.execute("select name from singers") + requests = self.spanner_service.requests + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + self.assertEqual("my_tag", requests[1].request_options.request_tag) + self.assertEqual("", requests[2].request_options.request_tag) + + def _execute_and_verify_select_singers( + self, connection: Connection, request_tag: str = "", transaction_tag: str = "" + ) -> ExecuteSqlRequest: + with connection.cursor() as cursor: + if request_tag: + cursor.request_tag = request_tag + cursor.execute("select name from singers") + result_list = cursor.fetchall() + for row in result_list: + self.assertEqual("Some Singer", row[0]) + self.assertEqual(1, len(result_list)) + requests = self.spanner_service.requests + return next( + request + for request in requests + if isinstance(request, ExecuteSqlRequest) + and request.sql == "select name from singers" + )