Skip to content

Commit

Permalink
test: add mock server tests
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite committed Oct 24, 2024
1 parent d3c6464 commit f6f6080
Show file tree
Hide file tree
Showing 7 changed files with 1,167 additions and 0 deletions.
14 changes: 14 additions & 0 deletions google/spanner/v1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Spanner Servicer

The Spanner server definition files were generated with these commands:

```shell
pip install grpcio-tools
git clone git@github.com:googleapis/googleapis.git
cd googleapis
python -m grpc_tools.protoc \
-I . \
--python_out=. --pyi_out=. --grpc_python_out=. \
./google/spanner/v1/*.proto
```

Empty file added google/spanner/v1/__init__.py
Empty file.
906 changes: 906 additions & 0 deletions google/spanner/v1/spanner_pb2_grpc.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions tests/unit/mockserver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
135 changes: 135 additions & 0 deletions tests/unit/mockserver/mock_spanner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.protobuf import empty_pb2 # type: ignore
from google.protobuf import struct_pb2 # type: ignore
import google.spanner.v1.spanner_pb2_grpc as spanner_grpc
import google.cloud.spanner_v1.types.result_set as result_set
import google.cloud.spanner_v1.types.transaction as transaction
import google.cloud.spanner_v1.types.commit_response as commit
import google.cloud.spanner_v1.types.spanner as spanner
from concurrent import futures
import grpc

class MockSpanner:
def __init__(self):
self.results = {}

def add_result(self, sql: str, result: result_set.ResultSet):
self.results[sql] = result

def get_result_as_partial_result_sets(self, sql: str) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.results.get(sql)
if result is None:
return []
partials = []
first = True
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata=result.metadata
partial.values.extend(row)
partials.append(partial)
return partials

class SpannerServicer(spanner_grpc.SpannerServicer):
def __init__(self):
self.requests = []
self.session_counter = 0
self.sessions = {}
self._mock_spanner = MockSpanner()

@property
def mock_spanner(self):
return self._mock_spanner

def CreateSession(self, request, context):
self.requests.append(request)
return self.__create_session(request.database, request.session)

def BatchCreateSessions(self, request, context):
self.requests.append(request)
sessions = []
for i in range(request.session_count):
sessions.append(self.__create_session(request.database, request.session_template))
return spanner.BatchCreateSessionsResponse(dict(session=sessions))

def __create_session(self, database: str, session_template: spanner.Session):
self.session_counter += 1
session = spanner.Session()
session.name = database + "/sessions/" + str(self.session_counter)
session.multiplexed = session_template.multiplexed
session.labels.MergeFrom(session_template.labels)
session.creator_role = session_template.creator_role
self.sessions[session.name] = session
return session

def GetSession(self, request, context):
return spanner.Session()

def ListSessions(self, request, context):
return [spanner.Session()]

def DeleteSession(self, request, context):
return empty_pb2.Empty()

def ExecuteSql(self, request, context):
return result_set.ResultSet()

def ExecuteStreamingSql(self, request, context):
self.requests.append(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
for result in partials:
yield result

def ExecuteBatchDml(self, request, context):
return spanner.ExecuteBatchDmlResponse()

def Read(self, request, context):
return result_set.ResultSet()

def StreamingRead(self, request, context):
for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]:
yield result

def BeginTransaction(self, request, context):
return transaction.Transaction()

def Commit(self, request, context):
return commit.CommitResponse()

def Rollback(self, request, context):
return empty_pb2.Empty()

def PartitionQuery(self, request, context):
return spanner.PartitionResponse()

def PartitionRead(self, request, context):
return spanner.PartitionResponse()

def BatchWrite(self, request, context):
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]:
yield result

def start_mock_server() -> (grpc.Server, SpannerServicer, int):
spanner_servicer = SpannerServicer()
spanner_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
spanner_grpc.add_SpannerServicer_to_server(spanner_servicer, spanner_server)
port = spanner_server.add_insecure_port("[::]:0")
spanner_server.start()
return spanner_server, spanner_servicer, port

if __name__ == "__main__":
server, _ = start_mock_server()
server.wait_for_termination()
Empty file.
97 changes: 97 additions & 0 deletions tests/unit/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from tests.unit.mockserver.mock_spanner import start_mock_server, \
SpannerServicer
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.spanner_v1 import Client, FixedSizePool
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
import grpc

class TestBasics(unittest.TestCase):
server: grpc.Server = None
service: SpannerServicer = None
port: int = None

def __init__(self, *args, **kwargs):
super(TestBasics, self).__init__(*args, **kwargs)
self._client = None
self._instance = None
self._database = None

@classmethod
def setUpClass(cls):
TestBasics.server, TestBasics.service, TestBasics.port = start_mock_server()

@classmethod
def tearDownClass(cls):
if TestBasics.server is not None:
TestBasics.server.stop(grace=None)
TestBasics.server = None

@property
def client(self) -> Client:
if self._client is None:
self._client = Client(
project="test-project",
credentials=AnonymousCredentials(),
client_options=ClientOptions(
api_endpoint="localhost:" + str(TestBasics.port),
)
)
return self._client

@property
def instance(self) -> Instance:
if self._instance is None:
self._instance = self.client.instance("test-instance")
return self._instance

@property
def database(self) -> Database:
if self._database is None:
self._database = self.instance.database(
"test-database",
pool=FixedSizePool(size=10)
)
return self._database

def test_select1(self):
result = result_set.ResultSet(dict(
metadata=result_set.ResultSetMetadata(dict(
row_type=spanner_type.StructType(dict(
fields=[spanner_type.StructType.Field(dict(
name="c",
type=spanner_type.Type(dict(
code=spanner_type.TypeCode.INT64)
))
)]
)))
),
))
result.rows.extend(["1"])
TestBasics.service.mock_spanner.add_result("select 1", result)
with self.database.snapshot() as snapshot:
results = snapshot.execute_sql("select 1")
result_list = []
for row in results:
result_list.append(row)
self.assertEqual(1, row[0])
self.assertEqual(1, len(result_list))

0 comments on commit f6f6080

Please sign in to comment.