Skip to content

Commit

Permalink
More test plumbing
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Dec 27, 2024
1 parent 3e47760 commit c4745fc
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 120 deletions.
1 change: 1 addition & 0 deletions tests/unit/test_atomic_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import random
import threading
import unittest

from google.cloud.spanner_v1._helpers import AtomicCounter


Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ def test_commit_ok(self):
[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
(
"x-goog-spanner-request-id",
f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
),
],
)
self.assertEqual(request_options, RequestOptions())
Expand Down Expand Up @@ -577,6 +581,10 @@ def _test_batch_write_with_request_options(
[
("google-cloud-resource-prefix", database.name),
("x-goog-spanner-route-to-leader", "true"),
(
"x-goog-spanner-request-id",
f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1",
),
],
)
if request_options is None:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,7 @@ def _execute_partitioned_dml_helper(
("x-goog-spanner-route-to-leader", "true"),
(
"x-goog-spanner-request-id",
f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1",
f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1",
),
],
)
Expand Down Expand Up @@ -3217,6 +3217,7 @@ def __init__(
self.directed_read_options = directed_read_options
self._nth_client_id = _Client.NTH_CLIENT.increment()
self._nth_request = AtomicCounter()
self.credentials = {}

@property
def _next_nth_request(self):
Expand Down
127 changes: 8 additions & 119 deletions tests/unit/test_request_id_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,26 @@
import random
import threading

from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_select1_result,
)
from google.api_core.exceptions import Aborted
from google.rpc import code_pb2

from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
BeginTransactionRequest,
ExecuteSqlRequest,
)
from google.api_core.exceptions import Aborted
from google.rpc import code_pb2
from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_select1_result,
)


class TestRequestIDHeader(MockServerTestBase):
def tearDown(self):
    """Reset the request-id interceptor after each test so ordinal counters start fresh."""
    interceptor = self.database._x_goog_request_id_interceptor
    interceptor.reset()

def test_snapshot_read(self):
def test_snapshot_execute_sql(self):
add_select1_result()
if not getattr(self.database, "_interceptors", None):
self.database._interceptors = MockServerTestBase._interceptors
Expand Down Expand Up @@ -253,118 +254,6 @@ def test_database_execute_partitioned_dml_request_id(self):
assert got_unary_segments == want_unary_segments
assert got_stream_segments == want_stream_segments

def test_snapshot_read_with_request_ids(self):
    """Verify x-goog-spanner-request-id segments for a snapshot read.

    Unary RPCs (BatchCreateSessions, GetSession) and streaming RPCs
    (ExecuteStreamingSql) share one monotonically increasing nth-request
    counter, so the expected sequences interleave: the session batch-create
    takes ordinal 1, streaming calls take the even ordinals 2..20, and the
    unary GetSession calls take the odd ordinals 3..19.
    """
    add_select1_result()
    if not getattr(self.database, "_interceptors", None):
        self.database._interceptors = MockServerTestBase._interceptors

    with self.database.snapshot() as snapshot:
        results = snapshot.read("select 1")
        result_list = []
        for row in results:
            result_list.append(row)
            self.assertEqual(1, row[0])
        self.assertEqual(1, len(result_list))

    requests = self.spanner_service.requests
    self.assertEqual(2, len(requests), msg=requests)
    self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
    self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))

    client_id = self.database._nth_client_id
    channel_id = self.database._channel_id
    got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()

    # BatchCreateSessions consumes ordinal 1; the remaining unary
    # GetSession calls occupy the odd ordinals 3, 5, ..., 19.
    want_unary_segments = [
        (
            "/google.spanner.v1.Spanner/BatchCreateSessions",
            (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
        )
    ] + [
        (
            "/google.spanner.v1.Spanner/GetSession",
            (1, REQ_RAND_PROCESS_ID, client_id, channel_id, nth, 1),
        )
        for nth in range(3, 21, 2)
    ]
    assert got_unary_segments == want_unary_segments

    # Streaming ExecuteStreamingSql calls occupy the even ordinals 2..20.
    want_stream_segments = [
        (
            "/google.spanner.v1.Spanner/ExecuteStreamingSql",
            (1, REQ_RAND_PROCESS_ID, client_id, channel_id, nth, 1),
        )
        for nth in range(2, 21, 2)
    ]
    assert got_stream_segments == want_stream_segments

def canonicalize_request_id_headers(self):
    """Return the (streaming, unary) request-id segments recorded by the
    database's x-goog-request-id interceptor."""
    interceptor = self.database._x_goog_request_id_interceptor
    return interceptor._stream_req_segments, interceptor._unary_req_segments
Expand Down

0 comments on commit c4745fc

Please sign in to comment.