diff --git a/tests/unit/test_atomic_counter.py b/tests/unit/test_atomic_counter.py index 92d10cac79..e8d8b6b7ce 100644 --- a/tests/unit/test_atomic_counter.py +++ b/tests/unit/test_atomic_counter.py @@ -15,6 +15,7 @@ import random import threading import unittest + from google.cloud.spanner_v1._helpers import AtomicCounter diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 8657b77d15..e8bed4a6bf 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -259,6 +259,10 @@ def test_commit_ok(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", + ), ], ) self.assertEqual(request_options, RequestOptions()) @@ -577,6 +581,10 @@ def _test_batch_write_with_request_options( [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.1.1.1", + ), ], ) if request_options is None: diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 86f6ebee19..af6d500dad 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1204,7 +1204,7 @@ def _execute_partitioned_dml_helper( ("x-goog-spanner-route-to-leader", "true"), ( "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1", ), ], ) @@ -3217,6 +3217,7 @@ def __init__( self.directed_read_options = directed_read_options self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() + self.credentials = {} @property def _next_nth_request(self): diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py index 280b8b24cf..a49d0521ce 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/unit/test_request_id_header.py @@ -15,25 +15,26 @@ import random import threading -from tests.mockserver_tests.mock_server_test_base import ( - MockServerTestBase, - add_select1_result, -) +from google.api_core.exceptions import Aborted +from google.rpc import code_pb2 + from google.cloud.spanner_v1 import ( BatchCreateSessionsRequest, BeginTransactionRequest, ExecuteSqlRequest, ) -from google.api_core.exceptions import Aborted -from google.rpc import code_pb2 from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_select1_result, +) class TestRequestIDHeader(MockServerTestBase): def tearDown(self): self.database._x_goog_request_id_interceptor.reset() - def test_snapshot_read(self): + def test_snapshot_execute_sql(self): add_select1_result() if not getattr(self.database, "_interceptors", None): self.database._interceptors = MockServerTestBase._interceptors @@ -253,118 +254,6 @@ def test_database_execute_partitioned_dml_request_id(self): assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments - def test_snapshot_read_with_request_ids(self): - add_select1_result() - if not getattr(self.database, "_interceptors", None): - self.database._interceptors = MockServerTestBase._interceptors - with self.database.snapshot() as snapshot: - results = snapshot.read("select 1") - result_list = [] - for row in results: - result_list.append(row) - self.assertEqual(1, row[0]) - self.assertEqual(1, len(result_list)) - - requests = self.spanner_service.requests - self.assertEqual(2, len(requests), msg=requests) - self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) - self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) - - # requests = self.spanner_service.requests - # self.assertEqual(n * 2, len(requests), msg=requests) - - client_id = self.database._nth_client_id - channel_id = self.database._channel_id - got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() - - want_unary_segments = [ - ( - "/google.spanner.v1.Spanner/BatchCreateSessions", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1), - ), - ( - "/google.spanner.v1.Spanner/GetSession", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1), - ), - ] - assert got_unary_segments == want_unary_segments - - want_stream_segments = [ - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1), - ), - ( - "/google.spanner.v1.Spanner/ExecuteStreamingSql", - (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1), - ), - ] - assert got_stream_segments == want_stream_segments - def canonicalize_request_id_headers(self): src = self.database._x_goog_request_id_interceptor return src._stream_req_segments, src._unary_req_segments