From 17a58962302fbc515eeb7d66fbecade0f0bdf6b3 Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Tue, 14 Jan 2025 23:32:00 -0800
Subject: [PATCH] Implement interceptor to wrap and increase
 x-goog-spanner-request-id attempts per retry

This monkey patches SpannerClient methods to have an interceptor
that increases the attempts per retry. The prelude though is that
any callers to it must pass in the attempt value 0 so that
each pass through will correctly increase the attempt field's value.
---
 google/cloud/spanner_v1/_helpers.py | 33 +++++++++++++++++-
 google/cloud/spanner_v1/database.py | 53 ++++++++++++++++++++++-------
 google/cloud/spanner_v1/pool.py     | 12 -------
 3 files changed, 72 insertions(+), 26 deletions(-)

diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py
index 756ab13ab1..8209c89a06 100644
--- a/google/cloud/spanner_v1/_helpers.py
+++ b/google/cloud/spanner_v1/_helpers.py
@@ -32,10 +32,11 @@
 from google.cloud.spanner_v1 import TypeCode
 from google.cloud.spanner_v1 import ExecuteSqlRequest
 from google.cloud.spanner_v1 import JsonObject
-from google.cloud.spanner_v1.request_id_header import with_request_id
+from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
 from google.rpc.error_details_pb2 import RetryInfo
 
 import random
+from typing import Callable
 
 # Validation error messages
 NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -648,3 +649,33 @@ def reset(self):
 
 def _metadata_with_request_id(*args, **kwargs):
     return with_request_id(*args, **kwargs)
+
+
+class InterceptingHeaderInjector:
+    """Wrap a callable and bump the request-id attempt on each call.
+
+    Every ``metadata`` entry keyed by ``REQ_ID_HEADER_KEY`` has the last
+    dot-separated field of its value incremented by one before the call
+    is delegated, so callers are expected to pass in an attempt of 0.
+    """
+
+    def __init__(self, original_callable: Callable):
+        self._original_callable = original_callable
+
+    def __call__(self, *args, **kwargs):
+        # An explicit metadata=None is treated the same as no metadata.
+        metadata = kwargs.get("metadata") or []
+        # Copy the metadata, bumping the attempt of each request-id header.
+        all_metadata = []
+        for key, value in metadata:
+            # Compare by value: "is" tests identity and can miss equal keys.
+            if key == REQ_ID_HEADER_KEY:
+                # The attempt number is the last dot-separated field.
+                splits = value.split(".")
+                splits[-1] = str(int(splits[-1]) + 1)
+                value = ".".join(splits)
+
+            all_metadata.append((key, value))
+
+        kwargs["metadata"] = all_metadata
+        return self._original_callable(*args, **kwargs)
diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index e40c0d54a3..f1a1327b0f 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -54,6 +54,7 @@
     _metadata_with_prefix,
     _metadata_with_leader_aware_routing,
     _metadata_with_request_id,
+    InterceptingHeaderInjector,
 )
 from google.cloud.spanner_v1.batch import Batch
 from google.cloud.spanner_v1.batch import MutationGroups
@@ -427,6 +428,43 @@ def logger(self):
 
     @property
     def spanner_api(self):
+        """Spanner API client whose RPC methods bump the request-id attempt.
+
+        Every public bound method of the client returned by
+        :meth:`_generate_spanner_api` is replaced on the instance by an
+        :class:`InterceptingHeaderInjector` wrapping it.
+        """
+        api = self._generate_spanner_api()
+        if not api:
+            return api
+
+        # Now wrap each method's __call__ method with our wrapped one.
+        # This is how to deal with the fact that there are no proper gRPC
+        # interceptors for Python hence the remedy is to replace callables
+        # with our custom wrapper.
+        for attr_name in dir(api):
+            # Skip dunder and otherwise non-public attributes.
+            if attr_name.startswith("_"):
+                continue
+
+            attr = getattr(api, attr_name)
+            # callable() returns a bool and never None, so test truthiness.
+            if not callable(attr):
+                continue
+
+            # We should only be looking at bound methods to SpannerClient
+            # as those are the RPC invoking methods that need to be wrapped.
+            # Attributes already replaced by a wrapper are not of type
+            # "method", so repeated property access does not double-wrap.
+            is_method = type(attr).__name__ == "method"
+            if not is_method:
+                continue
+
+            setattr(api, attr_name, InterceptingHeaderInjector(attr))
+
+        return api
+
+    def _generate_spanner_api(self):
         """Helper for session-related API calls."""
         if self._spanner_api is None:
             client_info = self._instance._client._client_info
@@ -759,7 +797,8 @@ def execute_pdml():
                     add_span_event(span, "Starting BeginTransaction")
                     begin_txn_attempt.increment()
                     txn = api.begin_transaction(
-                        session=session.name, options=txn_options,
+                        session=session.name,
+                        options=txn_options,
                         metadata=self.metadata_with_request_id(
                             begin_txn_nth_request, begin_txn_attempt.value, metadata
                         ),
@@ -794,18 +833,6 @@ def wrapped_method(*args, **kwargs):
                         observability_options=self.observability_options,
                         attempt=begin_txn_attempt,
                     )
-<<<<<<< HEAD
-=======
-                        return method(*args, **kwargs)
-
-                    iterator = _restart_on_unavailable(
-                        method=wrapped_method,
-                        trace_name="CloudSpanner.ExecuteStreamingSql",
-                        request=request,
-                        transaction_selector=txn_selector,
-                        observability_options=self.observability_options,
-                    )
->>>>>>> 54df502... Update tests
                     result_set = StreamedResultSet(iterator)
                     list(result_set)  # consume all partials
 
diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index eab2dbe809..a6e0afd64c 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -262,11 +262,6 @@ def create_sessions(attempt):
                     return api.batch_create_sessions(
                         request=request,
                         metadata=all_metadata,
-                        # Manually passing retry=None because otherwise any
-                        # UNAVAILABLE retry will be retried without replenishing
-                        # the metadata, hence this allows us to manually update
-                        # the metadata using retry_on_unavailable.
-                        retry=None,
                     )
 
                 resp = retry_on_unavailable(create_sessions)
@@ -578,13 +573,6 @@ def create_sessions(attempt):
                     return api.batch_create_sessions(
                         request=request,
                         metadata=all_metadata,
-                        # Manually passing retry=None because otherwise any
-                        # UNAVAILABLE retry will be retried without replenishing
-                        # the metadata, hence this allows us to manually update
-                        # the metadata using retry_on_unavailable.
-                        # TODO: Figure out how to intercept and monkey patch the internals
-                        # of the gRPC transport.
-                        retry=None,
                     )
 
                 resp = retry_on_unavailable(create_sessions)