Skip to content

Commit

Permalink
fix: Fixing and refactoring transaction retry logic in dbapi
Browse files Browse the repository at this point in the history
  • Loading branch information
ankiaga committed Dec 15, 2023
1 parent 7a92315 commit 22633c7
Show file tree
Hide file tree
Showing 9 changed files with 560 additions and 563 deletions.
22 changes: 10 additions & 12 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from enum import Enum
from typing import TYPE_CHECKING, List
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
StatementType,
Expand All @@ -25,6 +24,9 @@
from google.rpc.code_pb2 import ABORTED, OK
from google.api_core.exceptions import Aborted

from google.cloud.spanner_dbapi.transaction_helper import (
_get_batch_statements_result_checksum,
)
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

if TYPE_CHECKING:
Expand Down Expand Up @@ -81,6 +83,7 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
from google.cloud.spanner_dbapi import OperationalError

connection = cursor.connection
transaction_helper = connection._transaction_helper
many_result_set = StreamedManyResultSets()
statements_tuple = []
for statement in statements:
Expand All @@ -90,28 +93,23 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
many_result_set.add_iter(res)
cursor._row_count = sum([max(val, 0) for val in res])
else:
retried = False
while True:
try:
transaction = connection.transaction_checkout()
status, res = transaction.batch_update(statements_tuple)
many_result_set.add_iter(res)
res_checksum = ResultsChecksum()
res_checksum.consume_result(res)
res_checksum.consume_result(status.code)
if not retried:
connection._statements.append((statements, res_checksum))
cursor._row_count = sum([max(val, 0) for val in res])

if status.code == ABORTED:
connection._transaction = None
raise Aborted(status.message)
elif status.code != OK:
raise OperationalError(status.message)

checksum = _get_batch_statements_result_checksum(res, status.code)
many_result_set.add_iter(res)
transaction_helper._batch_statements_list.append((statements, checksum))
cursor._row_count = sum([max(val, 0) for val in res])
return many_result_set
except Aborted:
connection.retry_transaction()
retried = True
transaction_helper.retry_transaction()


def _do_batch_update(transaction, statements):
Expand Down
108 changes: 13 additions & 95 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,18 @@
# limitations under the License.

"""DB-API Connection for the Google Cloud Spanner."""
import time
import warnings

from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
from google.cloud.spanner_dbapi.transaction_helper import TransactionHelper
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1.snapshot import Snapshot
from deprecated import deprecated

from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import (
InterfaceError,
Expand All @@ -37,13 +34,10 @@
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

from google.rpc.code_pb2 import ABORTED


CLIENT_TRANSACTION_NOT_STARTED_WARNING = (
"This method is non-operational as a transaction has not been started."
)
MAX_INTERNAL_RETRIES = 50


def check_not_closed(function):
Expand Down Expand Up @@ -99,9 +93,6 @@ def __init__(self, instance, database=None, read_only=False):
self._transaction = None
self._session = None
self._snapshot = None
# SQL statements, which were executed
# within the current transaction
self._statements = []

self.is_closed = False
self._autocommit = False
Expand All @@ -118,6 +109,7 @@ def __init__(self, instance, database=None, read_only=False):
self._spanner_transaction_started = False
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None
self._transaction_helper = TransactionHelper(self)

@property
def autocommit(self):
Expand Down Expand Up @@ -281,76 +273,6 @@ def _release_session(self):
self.database._pool.put(self._session)
self._session = None

def retry_transaction(self):
    """Replay the aborted transaction, retrying while it keeps aborting.

    Each attempt re-executes every remembered statement in a new
    transaction via ``_rerun_previous_statements``; that helper compares
    the checksums of the retried results with the original ones.

    At most ``MAX_INTERNAL_RETRIES`` attempts are made. Between failed
    attempts the method sleeps for the delay reported by
    ``_get_retry_delay`` (when that delay is truthy).

    :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
        If results checksum of the retried statement is
        not equal to the checksum of the original one.
    """
    for attempt in range(1, MAX_INTERNAL_RETRIES + 1):
        # Each replay must begin a fresh Spanner transaction.
        self._spanner_transaction_started = False
        try:
            self._rerun_previous_statements()
        except Aborted as exc:
            # NOTE(review): assumes ``exc.errors`` is non-empty - confirm
            # for Aborted instances built from a message only.
            pause = _get_retry_delay(exc.errors[0], attempt)
            if pause:
                time.sleep(pause)
        else:
            # Replay succeeded; nothing more to do.
            return

    # Retries exhausted. The flag is reset once more, then the bare
    # ``raise`` re-raises the exception the *caller* is currently handling
    # (callers invoke this method from inside an ``except Aborted`` block).
    # Outside of any handler it would raise ``RuntimeError`` instead.
    self._spanner_transaction_started = False
    raise

def _rerun_previous_statements(self):
    """Re-execute every statement remembered for the last transaction.

    ``self._statements`` holds two kinds of entries:

    * a ``(statements, checksum)`` pair recorded for a batch DML call;
    * a single statement object carrying its own ``checksum``.

    Each entry is executed again and the checksum of the retried results
    is compared with the original one through ``_compare_checksums``.

    :raises: :class:`google.api_core.exceptions.Aborted`
        If the replayed batch update reports the ``ABORTED`` status code.
    """
    for statement in self._statements:
        # Batch entries are recorded as ``(statements, checksum)`` tuples,
        # so accept tuples as well as lists; a list-only check sends them
        # down the single-statement path, where they fail.
        if isinstance(statement, (list, tuple)):
            statements, checksum = statement

            transaction = self.transaction_checkout()
            statements_tuple = [
                single_statement.get_tuple() for single_statement in statements
            ]
            status, res = transaction.batch_update(statements_tuple)

            if status.code == ABORTED:
                # Use the human-readable ``message`` field, consistent with
                # how the batch DML executor raises Aborted.
                raise Aborted(status.message)

            retried_checksum = ResultsChecksum()
            retried_checksum.consume_result(res)
            retried_checksum.consume_result(status.code)

            _compare_checksums(checksum, retried_checksum)
            continue

        res_iter, retried_checksum = self.run_statement(statement, retried=True)
        if statement != self._statements[-1]:
            # A statement that had completed: consume all of its results.
            for res in res_iter:
                retried_checksum.consume_result(res)
        else:
            # The statement that failed: stream only up to the result that
            # had been reached originally, or to the end of the stream.
            # The iterator is created once; calling ``iter()`` on every
            # pass could restart or drop rows for iterables whose
            # ``__iter__`` returns a fresh iterator.
            rows = iter(res_iter)
            while len(retried_checksum) < len(statement.checksum):
                try:
                    res = next(rows)
                except StopIteration:
                    break
                retried_checksum.consume_result(res)

        _compare_checksums(statement.checksum, retried_checksum)

def transaction_checkout(self):
"""Get a Cloud Spanner transaction.
Expand Down Expand Up @@ -443,11 +365,12 @@ def commit(self):
if self._spanner_transaction_started and not self._read_only:
self._transaction.commit()
except Aborted:
self.retry_transaction()
self._transaction_helper.retry_transaction()
self.commit()
finally:
self._release_session()
self._statements = []
self._transaction_helper._single_statements = []
self._transaction_helper._batch_statements_list = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

Expand All @@ -467,7 +390,8 @@ def rollback(self):
self._transaction.rollback()
finally:
self._release_session()
self._statements = []
self._transaction_helper._single_statements = []
self._transaction_helper._batch_statements_list = []
self._transaction_begin_marked = False
self._spanner_transaction_started = False

Expand All @@ -486,7 +410,7 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement: Statement, retried=False):
def run_statement(self, statement: Statement):
"""Run single SQL statement in begun transaction.
This method is never used in autocommit mode. In
Expand All @@ -506,17 +430,11 @@ def run_statement(self, statement: Statement, retried=False):
checksum of this statement results.
"""
transaction = self.transaction_checkout()
if not retried:
self._statements.append(statement)

return (
transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
),
ResultsChecksum() if retried else statement.checksum,
return transaction.execute_sql(
statement.sql,
statement.params,
param_types=statement.param_types,
request_options=self.request_options,
)

@check_not_closed
Expand Down
51 changes: 17 additions & 34 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

"""Database cursor for Google Cloud Spanner DB API."""

import itertools
from collections import namedtuple

import sqlparse
Expand Down Expand Up @@ -47,6 +47,9 @@
Statement,
ParsedStatement,
)
from google.cloud.spanner_dbapi.transaction_helper import (
_get_single_statement_result_checksum,
)
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

Expand Down Expand Up @@ -90,9 +93,8 @@ def __init__(self, connection):
self._row_count = _UNSET_COUNT
self.lastrowid = None
self.connection = connection
self.transaction_helper = self.connection._transaction_helper
self._is_closed = False
# the currently running SQL statement results checksum
self._checksum = None
# the number of rows to fetch at a time with fetchmany()
self.arraysize = 1

Expand Down Expand Up @@ -275,26 +277,22 @@ def _execute_in_rw_transaction(self, parsed_statement: ParsedStatement):
# For every other operation, we've got to ensure that
# any prior DDL statements were run.
self.connection.run_prior_DDL_statements()
statement = parsed_statement.statement
if self.connection._client_transaction_started:
(
self._result_set,
self._checksum,
) = self.connection.run_statement(parsed_statement.statement)

while True:
try:
self._itr = PeekIterator(self._result_set)
break
self._result_set = self.connection.run_statement(statement)
itr, self._itr = itertools.tee(PeekIterator(self._result_set), 2)
statement.checksum = _get_single_statement_result_checksum(itr)
self.transaction_helper._single_statements.append(statement)
return
except Aborted:
self.connection.retry_transaction()
except Exception as ex:
self.connection._statements.remove(parsed_statement.statement)
raise ex
self.transaction_helper.retry_transaction()
else:
self.connection.database.run_in_transaction(
self._do_execute_update_in_autocommit,
parsed_statement.statement.sql,
parsed_statement.statement.params or None,
statement.sql,
statement.params or None,
)

@check_not_closed
Expand Down Expand Up @@ -357,17 +355,12 @@ def fetchone(self):
sequence, or None when no more data is available."""
try:
res = next(self)
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
return res
except StopIteration:
return
except Aborted:
if not self.connection.read_only:
self.connection.retry_transaction()
self.transaction_helper.retry_transaction()
return self.fetchone()

@check_not_closed
Expand All @@ -378,15 +371,10 @@ def fetchall(self):
res = []
try:
for row in self:
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(row)
res.append(row)
except Aborted:
if not self.connection.read_only:
self.connection.retry_transaction()
self.transaction_helper.retry_transaction()
return self.fetchall()

return res
Expand All @@ -410,17 +398,12 @@ def fetchmany(self, size=None):
for _ in range(size):
try:
res = next(self)
if (
self.connection._client_transaction_started
and not self.connection.read_only
):
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
break
except Aborted:
if not self.connection.read_only:
self.connection.retry_transaction()
self.transaction_helper.retry_transaction()
return self.fetchmany(size)

return items
Expand Down
2 changes: 0 additions & 2 deletions google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from . import client_side_statement_parser
from deprecated import deprecated

from .checksum import ResultsChecksum
from .exceptions import Error
from .parsed_statement import ParsedStatement, StatementType, Statement
from .types import DateStr, TimestampStr
Expand Down Expand Up @@ -230,7 +229,6 @@ def classify_statement(query, args=None):
query,
args,
get_param_types(args or None),
ResultsChecksum(),
)
if RE_DDL.match(query):
return ParsedStatement(StatementType.DDL, statement)
Expand Down
Loading

0 comments on commit 22633c7

Please sign in to comment.