Skip to content

Commit

Permalink
test: add mock server tests (#1217)
Browse files Browse the repository at this point in the history
* test: add mock server tests

* chore: move to testing folder + fix formatting

* refactor: move mock server tests to separate directory

* feat: add database admin service

Adds a DatabaseAdminService to the mock server and sets up a basic
test case for this.

Also removes the generated stubs in the grpc files, as these are
not needed.

* test: add DDL test

* test: add async client tests

* chore: remove async + add transaction handling

* chore: cleanup

* chore: run code formatter
  • Loading branch information
olavloite authored Dec 5, 2024
1 parent 829b799 commit 7df93ca
Show file tree
Hide file tree
Showing 10 changed files with 2,605 additions and 1 deletion.
21 changes: 21 additions & 0 deletions .github/workflows/mock_server_tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
on:
push:
branches:
- main
pull_request:
name: Run Spanner tests against an in-mem mock server
jobs:
system-tests:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.12
- name: Install nox
run: python -m pip install nox
- name: Run mock server tests
run: nox -s mockserver
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class Database(object):
statements in 'ddl_statements' above.
"""

_spanner_api = None
_spanner_api: SpannerClient = None

def __init__(
self,
Expand Down
Empty file.
38 changes: 38 additions & 0 deletions google/cloud/spanner_v1/testing/mock_database_admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.longrunning import operations_pb2 as operations_pb2
from google.protobuf import empty_pb2
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc


# An in-memory mock DatabaseAdmin server that can be used for testing.
class DatabaseAdminServicer(database_admin_grpc.DatabaseAdminServicer):
def __init__(self):
self._requests = []

@property
def requests(self):
return self._requests

def clear_requests(self):
self._requests = []

def UpdateDatabaseDdl(self, request, context):
self._requests.append(request)
operation = operations_pb2.Operation()
operation.done = True
operation.name = "projects/test-project/operations/test-operation"
operation.response.Pack(empty_pb2.Empty())
return operation
216 changes: 216 additions & 0 deletions google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import grpc
from concurrent import futures

from google.protobuf import empty_pb2
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc
import google.cloud.spanner_v1.types.commit_response as commit
import google.cloud.spanner_v1.types.result_set as result_set
import google.cloud.spanner_v1.types.spanner as spanner
import google.cloud.spanner_v1.types.transaction as transaction


class MockSpanner:
def __init__(self):
self.results = {}

def add_result(self, sql: str, result: result_set.ResultSet):
self.results[sql.lower().strip()] = result

def get_result(self, sql: str) -> result_set.ResultSet:
result = self.results.get(sql.lower().strip())
if result is None:
raise ValueError(f"No result found for {sql}")
return result

def get_result_as_partial_result_sets(
self, sql: str
) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.get_result(sql)
partials = []
first = True
if len(result.rows) == 0:
partial = result_set.PartialResultSet()
partial.metadata = result.metadata
partials.append(partial)
else:
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata = result.metadata
partial.values.extend(row)
partials.append(partial)
partials[len(partials) - 1].stats = result.stats
return partials


# An in-memory mock Spanner server that can be used for testing.
class SpannerServicer(spanner_grpc.SpannerServicer):
def __init__(self):
self._requests = []
self.session_counter = 0
self.sessions = {}
self.transaction_counter = 0
self.transactions = {}
self._mock_spanner = MockSpanner()

@property
def mock_spanner(self):
return self._mock_spanner

@property
def requests(self):
return self._requests

def clear_requests(self):
self._requests = []

def CreateSession(self, request, context):
self._requests.append(request)
return self.__create_session(request.database, request.session)

def BatchCreateSessions(self, request, context):
self._requests.append(request)
sessions = []
for i in range(request.session_count):
sessions.append(
self.__create_session(request.database, request.session_template)
)
return spanner.BatchCreateSessionsResponse(dict(session=sessions))

def __create_session(self, database: str, session_template: spanner.Session):
self.session_counter += 1
session = spanner.Session()
session.name = database + "/sessions/" + str(self.session_counter)
session.multiplexed = session_template.multiplexed
session.labels.MergeFrom(session_template.labels)
session.creator_role = session_template.creator_role
self.sessions[session.name] = session
return session

def GetSession(self, request, context):
self._requests.append(request)
return spanner.Session()

def ListSessions(self, request, context):
self._requests.append(request)
return [spanner.Session()]

def DeleteSession(self, request, context):
self._requests.append(request)
return empty_pb2.Empty()

def ExecuteSql(self, request, context):
self._requests.append(request)
return result_set.ResultSet()

def ExecuteStreamingSql(self, request, context):
self._requests.append(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
for result in partials:
yield result

def ExecuteBatchDml(self, request, context):
self._requests.append(request)
response = spanner.ExecuteBatchDmlResponse()
started_transaction = None
if not request.transaction.begin == transaction.TransactionOptions():
started_transaction = self.__create_transaction(
request.session, request.transaction.begin
)
first = True
for statement in request.statements:
result = self.mock_spanner.get_result(statement.sql)
if first and started_transaction is not None:
result = result_set.ResultSet(
self.mock_spanner.get_result(statement.sql)
)
result.metadata = result_set.ResultSetMetadata(result.metadata)
result.metadata.transaction = started_transaction
response.result_sets.append(result)
return response

def Read(self, request, context):
self._requests.append(request)
return result_set.ResultSet()

def StreamingRead(self, request, context):
self._requests.append(request)
for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]:
yield result

def BeginTransaction(self, request, context):
self._requests.append(request)
return self.__create_transaction(request.session, request.options)

def __create_transaction(
self, session: str, options: transaction.TransactionOptions
) -> transaction.Transaction:
session = self.sessions[session]
if session is None:
raise ValueError(f"Session not found: {session}")
self.transaction_counter += 1
id_bytes = bytes(
f"{session.name}/transactions/{self.transaction_counter}", "UTF-8"
)
transaction_id = base64.urlsafe_b64encode(id_bytes)
self.transactions[transaction_id] = options
return transaction.Transaction(dict(id=transaction_id))

def Commit(self, request, context):
self._requests.append(request)
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
del self.transactions[request.transaction_id]
return commit.CommitResponse()

def Rollback(self, request, context):
self._requests.append(request)
return empty_pb2.Empty()

def PartitionQuery(self, request, context):
self._requests.append(request)
return spanner.PartitionResponse()

def PartitionRead(self, request, context):
self._requests.append(request)
return spanner.PartitionResponse()

def BatchWrite(self, request, context):
self._requests.append(request)
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]:
yield result


def start_mock_server() -> (grpc.Server, SpannerServicer, DatabaseAdminServicer, int):
# Create a gRPC server.
spanner_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))

# Add the Spanner services to the gRPC server.
spanner_servicer = SpannerServicer()
spanner_grpc.add_SpannerServicer_to_server(spanner_servicer, spanner_server)
database_admin_servicer = DatabaseAdminServicer()
database_admin_grpc.add_DatabaseAdminServicer_to_server(
database_admin_servicer, spanner_server
)

# Start the server on a random port.
port = spanner_server.add_insecure_port("[::]:0")
spanner_server.start()
return spanner_server, spanner_servicer, database_admin_servicer, port
Loading

0 comments on commit 7df93ca

Please sign in to comment.