Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add mock server tests #1217

Merged
merged 13 commits into from
Dec 5, 2024
Prev Previous commit
Next Next commit
chore: move to testing folder + fix formatting
  • Loading branch information
olavloite committed Oct 25, 2024
commit 10bff97c2f2d6cd74b2f3d6099e0e992144e0e90
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,25 @@

from google.protobuf import empty_pb2 # type: ignore
from google.protobuf import struct_pb2 # type: ignore
import google.spanner.v1.spanner_pb2_grpc as spanner_grpc
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc
import google.cloud.spanner_v1.types.result_set as result_set
import google.cloud.spanner_v1.types.transaction as transaction
import google.cloud.spanner_v1.types.commit_response as commit
import google.cloud.spanner_v1.types.spanner as spanner
from concurrent import futures
import grpc


class MockSpanner:
def __init__(self):
self.results = {}

def add_result(self, sql: str, result: result_set.ResultSet):
self.results[sql] = result

def get_result_as_partial_result_sets(self, sql: str) -> [result_set.PartialResultSet]:
def get_result_as_partial_result_sets(
self, sql: str
) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.results.get(sql)
if result is None:
return []
Expand All @@ -38,14 +41,16 @@ def get_result_as_partial_result_sets(self, sql: str) -> [result_set.PartialResu
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata=result.metadata
partial.metadata = result.metadata
partial.values.extend(row)
partials.append(partial)
return partials


# An in-memory mock Spanner server that can be used for testing.
class SpannerServicer(spanner_grpc.SpannerServicer):
def __init__(self):
self.requests = []
self._requests = []
self.session_counter = 0
self.sessions = {}
self._mock_spanner = MockSpanner()
Expand All @@ -54,15 +59,21 @@ def __init__(self):
def mock_spanner(self):
return self._mock_spanner

@property
def requests(self):
return self._requests

def CreateSession(self, request, context):
self.requests.append(request)
self._requests.append(request)
return self.__create_session(request.database, request.session)

def BatchCreateSessions(self, request, context):
self.requests.append(request)
self._requests.append(request)
sessions = []
for i in range(request.session_count):
sessions.append(self.__create_session(request.database, request.session_template))
sessions.append(
self.__create_session(request.database, request.session_template)
)
return spanner.BatchCreateSessionsResponse(dict(session=sessions))

def __create_session(self, database: str, session_template: spanner.Session):
Expand All @@ -88,7 +99,7 @@ def ExecuteSql(self, request, context):
return result_set.ResultSet()

def ExecuteStreamingSql(self, request, context):
self.requests.append(request)
self._requests.append(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
for result in partials:
yield result
Expand Down Expand Up @@ -122,6 +133,7 @@ def BatchWrite(self, request, context):
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]:
yield result


def start_mock_server() -> (grpc.Server, SpannerServicer, int):
spanner_servicer = SpannerServicer()
spanner_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
Expand All @@ -130,6 +142,7 @@ def start_mock_server() -> (grpc.Server, SpannerServicer, int):
spanner_server.start()
return spanner_server, spanner_servicer, port


if __name__ == "__main__":
server, _ = start_mock_server()
server.wait_for_termination()
Loading