From 9d9a9fddd819e7f9756bb0776f81995eacb5145d Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Fri, 3 Jan 2025 22:04:58 -0800
Subject: [PATCH] Fix api.batch_create_sessions automatically retrying
 assumed-idempotent requests without refreshing x-goog-spanner-request-id
 headers

---
 google/cloud/spanner_v1/pool.py               | 67 ++++++++++++++-----
 google/cloud/spanner_v1/snapshot.py           |  5 +-
 .../cloud/spanner_v1/testing/database_test.py |  4 ++
 .../cloud/spanner_v1/testing/interceptors.py  |  1 -
 .../mockserver_tests/mock_server_test_base.py |  1 +
 .../test_request_id_header.py                 | 19 +++--
 6 files changed, 75 insertions(+), 22 deletions(-)

diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index 9aebedc901..446ab00ba2 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -15,10 +15,12 @@
 """Pools managing shared Session objects."""
 
 import datetime
+import random
 import queue
 import time
 
 from google.cloud.exceptions import NotFound
+from google.api_core.exceptions import ServiceUnavailable
 from google.cloud.spanner_v1 import BatchCreateSessionsRequest
 from google.cloud.spanner_v1 import Session
 from google.cloud.spanner_v1._helpers import (
@@ -251,13 +253,23 @@ def bind(self, database):
                 f"Creating {request.session_count} sessions",
                 span_event_attributes,
             )
-            all_metadata = database.metadata_with_request_id(
-                database._next_nth_request, 1, metadata
-            )
-            resp = api.batch_create_sessions(
-                request=request,
-                metadata=all_metadata,
-            )
+            nth_req = database._next_nth_request
+
+            def create_sessions(attempt):
+                all_metadata = database.metadata_with_request_id(
+                    nth_req, attempt, metadata
+                )
+                return api.batch_create_sessions(
+                    request=request,
+                    metadata=all_metadata,
+                    # Pass retry=None: the default retry would re-send the RPC
+                    # on UNAVAILABLE without refreshing the metadata, so let
+                    # retry_on_unavailable drive the retries and rebuild the
+                    # metadata for every attempt.
+                    retry=None,
+                )
+
+            resp = retry_on_unavailable(create_sessions)
 
             add_span_event(
                 span,
@@ -573,12 +585,24 @@ def bind(self, database):
         ) as span:
             returned_session_count = 0
             while created_session_count < self.size:
-                resp = api.batch_create_sessions(
-                    request=request,
-                    metadata=database.metadata_with_request_id(
-                        database._next_nth_request, 1, metadata
-                    ),
-                )
+                nth_req = database._next_nth_request
+
+                def create_sessions(attempt):
+                    all_metadata = database.metadata_with_request_id(
+                        nth_req, attempt, metadata
+                    )
+                    return api.batch_create_sessions(
+                        request=request,
+                        metadata=all_metadata,
+                        # Pass retry=None: the default retry would re-send the RPC
+                        # on UNAVAILABLE without refreshing the metadata, so let
+                        # retry_on_unavailable drive the retries and rebuild the
+                        # metadata for every attempt.
+                        retry=None,
+                    )
+
+                resp = retry_on_unavailable(create_sessions)
+
                 for session_pb in resp.session:
                     session = self._new_session()
                     session._session_id = session_pb.name.split("/")[-1]
@@ -812,3 +836,20 @@ def __enter__(self):
 
     def __exit__(self, *ignored):
         self._pool.put(self._session)
+
+
+def retry_on_unavailable(fn, max_attempts=6):
+    """
+    Calls `fn` up to `max_attempts` times, passing each call the attempt's ordinal
+    number (starting at 1). Only UNAVAILABLE errors are retried, with an increasing
+    delay plus jitter between attempts; any other exception propagates immediately.
+    """
+    last_exc = None
+    for i in range(max_attempts):
+        try:
+            return fn(i + 1)
+        except ServiceUnavailable as exc:
+            last_exc = exc
+            time.sleep(i**2 + random.random())
+
+    raise last_exc
diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py
index 5d8f70e3e7..e1fbe59679 100644
--- a/google/cloud/spanner_v1/snapshot.py
+++ b/google/cloud/spanner_v1/snapshot.py
@@ -357,6 +357,7 @@ def wrapped_restart(*args, **kwargs):
                     trace_attributes,
                     transaction=self,
                     observability_options=observability_options,
+                    attempt=attempt,
                 )
                 self._read_request_count += 1
                 if self._multi_use:
@@ -379,6 +380,7 @@ def wrapped_restart(*args, **kwargs):
                     trace_attributes,
                     transaction=self,
                     observability_options=observability_options,
+                    attempt=attempt,
                 )
                 self._read_request_count += 1
 
@@ -556,12 +558,11 @@ def execute_sql(
         attempt = AtomicCounter(0)
 
         def wrapped_restart(*args, **kwargs):
-            attempt.increment()
             restart = functools.partial(
                 api.execute_streaming_sql,
                 request=request,
                 metadata=database.metadata_with_request_id(
-                    nth_request, attempt.value, metadata
+                    nth_request, attempt.increment(), metadata
                 ),
                 retry=retry,
                 timeout=timeout,
diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py
index 80f040d7e0..4a6e94c88b 100644
--- a/google/cloud/spanner_v1/testing/database_test.py
+++ b/google/cloud/spanner_v1/testing/database_test.py
@@ -115,3 +115,7 @@ def _create_spanner_client_for_tests(self, client_options, credentials):
             client_options=client_options,
             transport=transport,
         )
+
+    def reset(self):
+        if self._x_goog_request_id_interceptor:
+            self._x_goog_request_id_interceptor.reset()
diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py
index 4cd3abd306..a1ab53af40 100644
--- a/google/cloud/spanner_v1/testing/interceptors.py
+++ b/google/cloud/spanner_v1/testing/interceptors.py
@@ -114,7 +114,6 @@ def stream_request_ids(self):
     def reset(self):
         self._stream_req_segments.clear()
         self._unary_req_segments.clear()
-        pass
 
 
 def parse_request_id(request_id_str):
diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py
index 24bbac0861..2f89415b55 100644
--- a/tests/mockserver_tests/mock_server_test_base.py
+++ b/tests/mockserver_tests/mock_server_test_base.py
@@ -57,6 +57,7 @@ def aborted_status() -> _Status:
     )
     return status
 
+
 # Creates an UNAVAILABLE status with the smallest possible retry delay.
 def unavailable_status() -> _Status:
     error = status_pb2.Status(
diff --git a/tests/mockserver_tests/test_request_id_header.py b/tests/mockserver_tests/test_request_id_header.py
index 494a4f3879..306d2cbf93 100644
--- a/tests/mockserver_tests/test_request_id_header.py
+++ b/tests/mockserver_tests/test_request_id_header.py
@@ -24,7 +24,10 @@
 from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
 from tests.mockserver_tests.mock_server_test_base import (
     MockServerTestBase,
-    add_select1_result, aborted_status, add_error, unavailable_status,
+    add_select1_result,
+    aborted_status,
+    add_error,
+    unavailable_status,
 )
 
 
@@ -70,6 +73,13 @@ def test_snapshot_execute_sql(self):
         assert got_stream_segments == want_stream_segments
 
     def test_snapshot_read_concurrent(self):
+        # Trigger BatchCreateSessions first.
+        with self.database.snapshot() as snapshot:
+            rows = snapshot.execute_sql("select 1")
+            for row in rows:
+                _ = row
+
+        # The other requests can then proceed.
         def select1():
             with self.database.snapshot() as snapshot:
                 rows = snapshot.execute_sql("select 1")
@@ -100,7 +110,7 @@ def select1():
                 break
 
         requests = self.spanner_service.requests
-        self.assertEqual(n + 1, len(requests), msg=requests)
+        self.assertEqual(2 + n * 2, len(requests), msg=requests)
 
         client_id = self.database._nth_client_id
         channel_id = self.database._channel_id
@@ -254,15 +264,12 @@ def test_unary_retryable_error(self):
             )
         ]
 
         assert got_unary_segments == want_unary_segments
         assert got_stream_segments == want_stream_segments
 
     def test_streaming_retryable_error(self):
         add_select1_result()
-        # TODO: UNAVAILABLE errors are not correctly handled by the client lib.
-        # This is probably the reason behind
-        # https://github.com/googleapis/python-spanner/issues/1150.
-        # The fix
+
         add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status())
         add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status())
         if not getattr(self.database, "_interceptors", None):
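
The sketch below is an illustration, not part of the patch: it shows why the attempt number has to be threaded through retry_on_unavailable. GAPIC's default retry re-sends BatchCreateSessions (treated as idempotent) with the metadata captured on the first attempt, so the attempt component of the x-goog-spanner-request-id header never advances; disabling that retry and rebuilding the metadata per attempt is the pattern the patch adopts. The helper build_request_id_metadata and the dotted segment layout are assumptions standing in for database.metadata_with_request_id().

# Illustrative sketch only. `build_request_id_metadata` is a hypothetical stand-in
# for database.metadata_with_request_id(); the dotted segment layout is an assumed
# encoding of (version, process id, client id, channel id, nth request, attempt).
import random
import time

from google.api_core.exceptions import ServiceUnavailable


def retry_on_unavailable(fn, max_attempts=6):
    """Call fn(attempt) until it succeeds, retrying only UNAVAILABLE errors."""
    last_exc = None
    for i in range(max_attempts):
        try:
            return fn(i + 1)
        except ServiceUnavailable as exc:
            last_exc = exc
            time.sleep(i**2 + random.random())
    raise last_exc


def build_request_id_metadata(nth_request, attempt):
    # The attempt number is the final segment, which is why the metadata must be
    # rebuilt for every retry instead of being reused from the first attempt.
    return [("x-goog-spanner-request-id", f"1.PROCESS_ID.1.1.{nth_request}.{attempt}")]


if __name__ == "__main__":
    sent_headers = []

    def flaky_batch_create_sessions(attempt):
        # Stand-in for api.batch_create_sessions(..., retry=None): fails twice, then succeeds.
        sent_headers.append(build_request_id_metadata(nth_request=5, attempt=attempt))
        if attempt < 3:
            raise ServiceUnavailable("transient")
        return "sessions created"

    retry_on_unavailable(flaky_batch_create_sessions)
    # The attempt segment advances across retries: ...5.1, ...5.2, ...5.3
    for headers in sent_headers:
        print(headers)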