Skip to content

Commit

Permalink
feat: support transaction and request tags in dbapi (#1262)
Browse files Browse the repository at this point in the history
* feat: support transaction and request tags in dbapi

Adds support for setting transaction tags and request tags in dbapi.
This makes these options available to frameworks that depend on
dbapi, like SQLAlchemy and Django.

Towards googleapis/python-spanner-sqlalchemy#525

* test: add test for transaction_tag with read-only tx

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
olavloite and gcf-owl-bot[bot] authored Jan 13, 2025
1 parent d9ee75a commit ee9662f
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 10 deletions.
4 changes: 0 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,3 @@ system_tests/local_test_setup
# Make sure a generated file isn't accidentally committed.
pylintrc
pylintrc.test


# Ignore coverage files
.coverage*
35 changes: 32 additions & 3 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, instance, database=None, read_only=False, **kwargs):
self.request_priority = None
self._transaction_begin_marked = False
# whether transaction started at Spanner. This means that we had
# made atleast one call to Spanner.
# made at least one call to Spanner.
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None
Expand Down Expand Up @@ -261,6 +261,28 @@ def request_options(self):
self.request_priority = None
return req_opts

@property
def transaction_tag(self):
"""The transaction tag that will be applied to the next read/write
transaction on this `Connection`. This property is automatically cleared
when a new transaction is started.
Returns:
str: The transaction tag that will be applied to the next read/write transaction.
"""
return self._connection_variables.get("transaction_tag", None)

@transaction_tag.setter
def transaction_tag(self, value):
"""Sets the transaction tag for the next read/write transaction on this
`Connection`. This property is automatically cleared when a new transaction
is started.
Args:
value (str): The transaction tag for the next read/write transaction.
"""
self._connection_variables["transaction_tag"] = value

@property
def staleness(self):
"""Current read staleness option value of this `Connection`.
Expand Down Expand Up @@ -340,6 +362,8 @@ def transaction_checkout(self):
if not self.read_only and self._client_transaction_started:
if not self._spanner_transaction_started:
self._transaction = self._session_checkout().transaction()
self._transaction.transaction_tag = self.transaction_tag
self.transaction_tag = None
self._snapshot = None
self._spanner_transaction_started = True
self._transaction.begin()
Expand Down Expand Up @@ -458,7 +482,9 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement: Statement):
def run_statement(
self, statement: Statement, request_options: RequestOptions = None
):
"""Run single SQL statement in begun transaction.
This method is never used in autocommit mode. In
Expand All @@ -472,6 +498,9 @@ def run_statement(self, statement: Statement):
:param retried: (Optional) Retry the SQL statement if statement
execution failed. Defaults to false.
:type request_options: :class:`RequestOptions`
:param request_options: Request options to use for this statement.
:rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
:class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
:returns: Streamed result set of the statement and a
Expand All @@ -482,7 +511,7 @@ def run_statement(self, statement: Statement):
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
request_options=request_options or self.request_options,
)

@check_not_closed
Expand Down
42 changes: 39 additions & 3 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.merged_result_set import MergedResultSet

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Expand Down Expand Up @@ -97,6 +98,39 @@ def __init__(self, connection):
self._parsed_statement: ParsedStatement = None
self._in_retry_mode = False
self._batch_dml_rows_count = None
self._request_tag = None

@property
def request_tag(self):
"""The request tag that will be applied to the next statement on this
cursor. This property is automatically cleared when a statement is
executed.
Returns:
str: The request tag that will be applied to the next statement on
this cursor.
"""
return self._request_tag

@request_tag.setter
def request_tag(self, value):
"""Sets the request tag for the next statement on this cursor. This
property is automatically cleared when a statement is executed.
Args:
value (str): The request tag for the statement.
"""
self._request_tag = value

@property
def request_options(self):
options = self.connection.request_options
if self._request_tag:
if not options:
options = RequestOptions()
options.request_tag = self._request_tag
self._request_tag = None
return options

@property
def is_closed(self):
Expand Down Expand Up @@ -284,7 +318,7 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
sql,
params=args,
param_types=self._parsed_statement.statement.param_types,
request_options=self.connection.request_options,
request_options=self.request_options,
)
self._result_set = None
else:
Expand Down Expand Up @@ -318,7 +352,9 @@ def _execute_in_rw_transaction(self):
if self.connection._client_transaction_started:
while True:
try:
self._result_set = self.connection.run_statement(statement)
self._result_set = self.connection.run_statement(
statement, self.request_options
)
self._itr = PeekIterator(self._result_set)
return
except Aborted:
Expand Down Expand Up @@ -478,7 +514,7 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
sql,
params,
get_param_types(params),
request_options=self.connection.request_options,
request_options=self.request_options,
)
# Read the first element so that the StreamedResultSet can
# return the metadata after a DQL statement.
Expand Down
206 changes: 206 additions & 0 deletions tests/mockserver_tests/test_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
ExecuteSqlRequest,
BeginTransactionRequest,
TypeCode,
CommitRequest,
)
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_single_result,
)


class TestTags(MockServerTestBase):
@classmethod
def setup_class(cls):
super().setup_class()
add_single_result(
"select name from singers", "name", TypeCode.STRING, [("Some Singer",)]
)

def test_select_autocommit_no_tags(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
request = self._execute_and_verify_select_singers(connection)
self.assertEqual("", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_autocommit_with_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
request = self._execute_and_verify_select_singers(
connection, request_tag="my_tag"
)
self.assertEqual("my_tag", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_only_transaction_no_tags(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.read_only = True
request = self._execute_and_verify_select_singers(connection)
self.assertEqual("", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_only_transaction_with_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.read_only = True
request = self._execute_and_verify_select_singers(
connection, request_tag="my_tag"
)
self.assertEqual("my_tag", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_only_transaction_with_transaction_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.read_only = True
connection.transaction_tag = "my_transaction_tag"
self._execute_and_verify_select_singers(connection)
self._execute_and_verify_select_singers(connection)

# Read-only transactions do not support tags, so the transaction_tag is
# also not cleared from the connection when a read-only transaction is
# executed.
self.assertEqual("my_transaction_tag", connection.transaction_tag)

# Read-only transactions do not need to be committed or rolled back on
# Spanner, but dbapi requires this to end the transaction.
connection.commit()
requests = self.spanner_service.requests
self.assertEqual(4, len(requests))
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
# Transaction tags are not supported for read-only transactions.
self.assertEqual("", requests[2].request_options.transaction_tag)
self.assertEqual("", requests[3].request_options.transaction_tag)

def test_select_read_write_transaction_no_tags(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
request = self._execute_and_verify_select_singers(connection)
self.assertEqual("", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_write_transaction_with_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
request = self._execute_and_verify_select_singers(
connection, request_tag="my_tag"
)
self.assertEqual("my_tag", request.request_options.request_tag)
self.assertEqual("", request.request_options.transaction_tag)

def test_select_read_write_transaction_with_transaction_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.transaction_tag = "my_transaction_tag"
# The transaction tag should be included for all statements in the transaction.
self._execute_and_verify_select_singers(connection)
self._execute_and_verify_select_singers(connection)

# The transaction tag was cleared from the connection when the transaction
# was started.
self.assertIsNone(connection.transaction_tag)
# The commit call should also include a transaction tag.
connection.commit()
requests = self.spanner_service.requests
self.assertEqual(5, len(requests))
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[4], CommitRequest))
self.assertEqual(
"my_transaction_tag", requests[2].request_options.transaction_tag
)
self.assertEqual(
"my_transaction_tag", requests[3].request_options.transaction_tag
)
self.assertEqual(
"my_transaction_tag", requests[4].request_options.transaction_tag
)

def test_select_read_write_transaction_with_transaction_and_request_tag(self):
connection = Connection(self.instance, self.database)
connection.autocommit = False
connection.transaction_tag = "my_transaction_tag"
# The transaction tag should be included for all statements in the transaction.
self._execute_and_verify_select_singers(connection, request_tag="my_tag1")
self._execute_and_verify_select_singers(connection, request_tag="my_tag2")

# The transaction tag was cleared from the connection when the transaction
# was started.
self.assertIsNone(connection.transaction_tag)
# The commit call should also include a transaction tag.
connection.commit()
requests = self.spanner_service.requests
self.assertEqual(5, len(requests))
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[4], CommitRequest))
self.assertEqual(
"my_transaction_tag", requests[2].request_options.transaction_tag
)
self.assertEqual("my_tag1", requests[2].request_options.request_tag)
self.assertEqual(
"my_transaction_tag", requests[3].request_options.transaction_tag
)
self.assertEqual("my_tag2", requests[3].request_options.request_tag)
self.assertEqual(
"my_transaction_tag", requests[4].request_options.transaction_tag
)

def test_request_tag_is_cleared(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
with connection.cursor() as cursor:
cursor.request_tag = "my_tag"
cursor.execute("select name from singers")
# This query will not have a request tag.
cursor.execute("select name from singers")
requests = self.spanner_service.requests
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
self.assertEqual("my_tag", requests[1].request_options.request_tag)
self.assertEqual("", requests[2].request_options.request_tag)

def _execute_and_verify_select_singers(
self, connection: Connection, request_tag: str = "", transaction_tag: str = ""
) -> ExecuteSqlRequest:
with connection.cursor() as cursor:
if request_tag:
cursor.request_tag = request_tag
cursor.execute("select name from singers")
result_list = cursor.fetchall()
for row in result_list:
self.assertEqual("Some Singer", row[0])
self.assertEqual(1, len(result_list))
requests = self.spanner_service.requests
return next(
request
for request in requests
if isinstance(request, ExecuteSqlRequest)
and request.sql == "select name from singers"
)

0 comments on commit ee9662f

Please sign in to comment.