Skip to content

Commit

Permalink
Fix discrepancy with api.batch_create_sessions automatically retrying, assuming idempotency with headers too
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 4, 2025
1 parent 04304f1 commit 9d9a9fd
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 21 deletions.
70 changes: 58 additions & 12 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
"""Pools managing shared Session objects."""

import datetime
import random
import queue
import time

from google.cloud.exceptions import NotFound
from google.api_core.exceptions import ServiceUnavailable
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
from google.cloud.spanner_v1 import Session
from google.cloud.spanner_v1._helpers import (
Expand Down Expand Up @@ -251,13 +253,23 @@ def bind(self, database):
f"Creating {request.session_count} sessions",
span_event_attributes,
)
all_metadata = database.metadata_with_request_id(
database._next_nth_request, 1, metadata
)
resp = api.batch_create_sessions(
request=request,
metadata=all_metadata,
)
nth_req = database._next_nth_request

def create_sessions(attempt):
    """Issue a single BatchCreateSessions RPC, stamping the request-id
    metadata with the current (1-based) attempt number."""
    # retry=None disables the transport-level UNAVAILABLE retry, which
    # would replay the request without refreshing the attempt-numbered
    # metadata; retry_on_unavailable drives the retries instead so the
    # metadata can be rebuilt on every attempt.
    return api.batch_create_sessions(
        request=request,
        retry=None,
        metadata=database.metadata_with_request_id(nth_req, attempt, metadata),
    )

resp = retry_on_unavailable(create_sessions)

add_span_event(
span,
Expand Down Expand Up @@ -573,12 +585,27 @@ def bind(self, database):
) as span:
returned_session_count = 0
while created_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=database.metadata_with_request_id(
database._next_nth_request, 1, metadata
),
# Capture the request ordinal once so every retry attempt reuses the
# same request id (only the attempt counter varies between retries).
nth_req = database._next_nth_request

def create_sessions(attempt):
    """Issue a single BatchCreateSessions RPC with metadata stamped
    with the given (1-based) attempt number."""
    all_metadata = database.metadata_with_request_id(
        nth_req, attempt, metadata
    )
    return api.batch_create_sessions(
        request=request,
        metadata=all_metadata,
        # Manually passing retry=None because otherwise any
        # UNAVAILABLE retry will be retried without replenishing
        # the metadata, hence this allows us to manually update
        # the metadata using retry_on_unavailable.
        retry=None,
    )

resp = retry_on_unavailable(create_sessions)

for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
Expand Down Expand Up @@ -812,3 +839,22 @@ def __enter__(self):

def __exit__(self, *ignored):
self._pool.put(self._session)


def retry_on_unavailable(fn, max=6):
    """Call ``fn`` until it succeeds or ``max`` attempts are exhausted.

    ``fn`` receives the 1-based attempt ordinal so it can refresh
    per-attempt state (e.g. x-goog-spanner-request-id metadata). Only
    :class:`ServiceUnavailable` (UNAVAILABLE) errors are retried; any
    other exception propagates immediately. Between attempts it sleeps
    with a quadratically growing backoff (``i**2`` seconds) plus random
    jitter in ``[0, 1)``.

    Args:
        fn: callable taking the attempt ordinal (1..``max``).
        max: maximum number of attempts; must be positive. (Parameter
            name kept for backward compatibility even though it shadows
            the builtin ``max``.)

    Returns:
        Whatever ``fn`` returns on its first successful attempt.

    Raises:
        ValueError: if ``max`` is not positive.
        ServiceUnavailable: the last error, when every attempt failed.
    """
    if max <= 0:
        # Previously this fell through to ``raise None`` (a TypeError);
        # reject the nonsensical argument explicitly instead.
        raise ValueError("max must be a positive number of attempts")

    last_exc = None
    for i in range(max):
        try:
            return fn(i + 1)
        except ServiceUnavailable as exc:
            last_exc = exc
            # Don't sleep after the final attempt — we are about to
            # re-raise, so the extra delay would be pure waste.
            if i + 1 < max:
                time.sleep(i**2 + random.random())

    raise last_exc
5 changes: 3 additions & 2 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def wrapped_restart(*args, **kwargs):
trace_attributes,
transaction=self,
observability_options=observability_options,
attempt=attempt,
)
self._read_request_count += 1
if self._multi_use:
Expand All @@ -379,6 +380,7 @@ def wrapped_restart(*args, **kwargs):
trace_attributes,
transaction=self,
observability_options=observability_options,
attempt=attempt,
)

self._read_request_count += 1
Expand Down Expand Up @@ -556,12 +558,11 @@ def execute_sql(
attempt = AtomicCounter(0)

def wrapped_restart(*args, **kwargs):
attempt.increment()
restart = functools.partial(
api.execute_streaming_sql,
request=request,
metadata=database.metadata_with_request_id(
nth_request, attempt.value, metadata
nth_request, attempt.increment(), metadata
),
retry=retry,
timeout=timeout,
Expand Down
5 changes: 5 additions & 0 deletions google/cloud/spanner_v1/testing/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def spanner_api(self):
channel = grpc.insecure_channel(self._instance.emulator_host)
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
self._interceptors.append(self._x_goog_request_id_interceptor)
# print("self._interceptors", self._interceptors)
channel = grpc.intercept_channel(channel, *self._interceptors)
transport = SpannerGrpcTransport(channel=channel)
self._spanner_api = SpannerClient(
Expand Down Expand Up @@ -115,3 +116,7 @@ def _create_spanner_client_for_tests(self, client_options, credentials):
client_options=client_options,
transport=transport,
)

def reset(self):
if self._x_goog_request_id_interceptor:
self._x_goog_request_id_interceptor.reset()
3 changes: 2 additions & 1 deletion google/cloud/spanner_v1/testing/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def intercept(self, method, request_or_iterator, call_details):
)

response_or_iterator = method(request_or_iterator, call_details)
print("call_details", call_details, "\n", response_or_iterator)
streaming = getattr(response_or_iterator, "__iter__", None) is not None
print("x_append", call_details.method, x_goog_request_id)
with self.__lock:
if streaming:
self._stream_req_segments.append(
Expand All @@ -114,7 +116,6 @@ def stream_request_ids(self):
def reset(self):
self._stream_req_segments.clear()
self._unary_req_segments.clear()
pass


def parse_request_id(request_id_str):
Expand Down
1 change: 1 addition & 0 deletions tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def aborted_status() -> _Status:
)
return status


# Creates an UNAVAILABLE status with the smallest possible retry delay.
def unavailable_status() -> _Status:
error = status_pb2.Status(
Expand Down
21 changes: 15 additions & 6 deletions tests/mockserver_tests/test_request_id_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_select1_result, aborted_status, add_error, unavailable_status,
add_select1_result,
aborted_status,
add_error,
unavailable_status,
)


Expand Down Expand Up @@ -70,6 +73,13 @@ def test_snapshot_execute_sql(self):
assert got_stream_segments == want_stream_segments

def test_snapshot_read_concurrent(self):
# Trigger BatchCreateSessions firstly.
with self.database.snapshot() as snapshot:
rows = snapshot.execute_sql("select 1")
for row in rows:
_ = row

# The other requests can then proceed.
def select1():
with self.database.snapshot() as snapshot:
rows = snapshot.execute_sql("select 1")
Expand Down Expand Up @@ -100,7 +110,7 @@ def select1():
break

requests = self.spanner_service.requests
self.assertEqual(n + 1, len(requests), msg=requests)
self.assertEqual(2 + n * 2, len(requests), msg=requests)

client_id = self.database._nth_client_id
channel_id = self.database._channel_id
Expand All @@ -112,6 +122,7 @@ def select1():
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
),
]
print("got_unary", got_unary_segments)
assert got_unary_segments == want_unary_segments

want_stream_segments = [
Expand Down Expand Up @@ -254,15 +265,13 @@ def test_unary_retryable_error(self):
)
]

print("got_unary_segments", got_unary_segments)
assert got_unary_segments == want_unary_segments
assert got_stream_segments == want_stream_segments

def test_streaming_retryable_error(self):
add_select1_result()
# TODO: UNAVAILABLE errors are not correctly handled by the client lib.
# This is probably the reason behind
# https://github.com/googleapis/python-spanner/issues/1150.
# The fix
add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status())
add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status())

if not getattr(self.database, "_interceptors", None):
Expand Down

0 comments on commit 9d9a9fd

Please sign in to comment.