-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* test: add mock server tests * chore: move to testing folder + fix formatting * refactor: move mock server tests to separate directory * feat: add database admin service Adds a DatabaseAdminService to the mock server and sets up a basic test case for this. Also removes the generated stubs in the grpc files, as these are not needed. * test: add DDL test * test: add async client tests * chore: remove async + add transaction handling * chore: cleanup * chore: run code formatter
- Loading branch information
Showing
10 changed files
with
2,605 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
name: Run Spanner tests against an in-mem mock server | ||
jobs: | ||
system-tests: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v4 | ||
- name: Setup Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: 3.12 | ||
- name: Install nox | ||
run: python -m pip install nox | ||
- name: Run mock server tests | ||
run: nox -s mockserver |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright 2024 Google LLC All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from google.longrunning import operations_pb2 as operations_pb2 | ||
from google.protobuf import empty_pb2 | ||
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc | ||
|
||
|
||
# An in-memory mock DatabaseAdmin server that can be used for testing. | ||
class DatabaseAdminServicer(database_admin_grpc.DatabaseAdminServicer): | ||
def __init__(self): | ||
self._requests = [] | ||
|
||
@property | ||
def requests(self): | ||
return self._requests | ||
|
||
def clear_requests(self): | ||
self._requests = [] | ||
|
||
def UpdateDatabaseDdl(self, request, context): | ||
self._requests.append(request) | ||
operation = operations_pb2.Operation() | ||
operation.done = True | ||
operation.name = "projects/test-project/operations/test-operation" | ||
operation.response.Pack(empty_pb2.Empty()) | ||
return operation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
# Copyright 2024 Google LLC All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import base64 | ||
import grpc | ||
from concurrent import futures | ||
|
||
from google.protobuf import empty_pb2 | ||
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer | ||
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc | ||
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc | ||
import google.cloud.spanner_v1.types.commit_response as commit | ||
import google.cloud.spanner_v1.types.result_set as result_set | ||
import google.cloud.spanner_v1.types.spanner as spanner | ||
import google.cloud.spanner_v1.types.transaction as transaction | ||
|
||
|
||
class MockSpanner: | ||
def __init__(self): | ||
self.results = {} | ||
|
||
def add_result(self, sql: str, result: result_set.ResultSet): | ||
self.results[sql.lower().strip()] = result | ||
|
||
def get_result(self, sql: str) -> result_set.ResultSet: | ||
result = self.results.get(sql.lower().strip()) | ||
if result is None: | ||
raise ValueError(f"No result found for {sql}") | ||
return result | ||
|
||
def get_result_as_partial_result_sets( | ||
self, sql: str | ||
) -> [result_set.PartialResultSet]: | ||
result: result_set.ResultSet = self.get_result(sql) | ||
partials = [] | ||
first = True | ||
if len(result.rows) == 0: | ||
partial = result_set.PartialResultSet() | ||
partial.metadata = result.metadata | ||
partials.append(partial) | ||
else: | ||
for row in result.rows: | ||
partial = result_set.PartialResultSet() | ||
if first: | ||
partial.metadata = result.metadata | ||
partial.values.extend(row) | ||
partials.append(partial) | ||
partials[len(partials) - 1].stats = result.stats | ||
return partials | ||
|
||
|
||
# An in-memory mock Spanner server that can be used for testing. | ||
class SpannerServicer(spanner_grpc.SpannerServicer): | ||
def __init__(self): | ||
self._requests = [] | ||
self.session_counter = 0 | ||
self.sessions = {} | ||
self.transaction_counter = 0 | ||
self.transactions = {} | ||
self._mock_spanner = MockSpanner() | ||
|
||
@property | ||
def mock_spanner(self): | ||
return self._mock_spanner | ||
|
||
@property | ||
def requests(self): | ||
return self._requests | ||
|
||
def clear_requests(self): | ||
self._requests = [] | ||
|
||
def CreateSession(self, request, context): | ||
self._requests.append(request) | ||
return self.__create_session(request.database, request.session) | ||
|
||
def BatchCreateSessions(self, request, context): | ||
self._requests.append(request) | ||
sessions = [] | ||
for i in range(request.session_count): | ||
sessions.append( | ||
self.__create_session(request.database, request.session_template) | ||
) | ||
return spanner.BatchCreateSessionsResponse(dict(session=sessions)) | ||
|
||
def __create_session(self, database: str, session_template: spanner.Session): | ||
self.session_counter += 1 | ||
session = spanner.Session() | ||
session.name = database + "/sessions/" + str(self.session_counter) | ||
session.multiplexed = session_template.multiplexed | ||
session.labels.MergeFrom(session_template.labels) | ||
session.creator_role = session_template.creator_role | ||
self.sessions[session.name] = session | ||
return session | ||
|
||
def GetSession(self, request, context): | ||
self._requests.append(request) | ||
return spanner.Session() | ||
|
||
def ListSessions(self, request, context): | ||
self._requests.append(request) | ||
return [spanner.Session()] | ||
|
||
def DeleteSession(self, request, context): | ||
self._requests.append(request) | ||
return empty_pb2.Empty() | ||
|
||
def ExecuteSql(self, request, context): | ||
self._requests.append(request) | ||
return result_set.ResultSet() | ||
|
||
def ExecuteStreamingSql(self, request, context): | ||
self._requests.append(request) | ||
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql) | ||
for result in partials: | ||
yield result | ||
|
||
def ExecuteBatchDml(self, request, context): | ||
self._requests.append(request) | ||
response = spanner.ExecuteBatchDmlResponse() | ||
started_transaction = None | ||
if not request.transaction.begin == transaction.TransactionOptions(): | ||
started_transaction = self.__create_transaction( | ||
request.session, request.transaction.begin | ||
) | ||
first = True | ||
for statement in request.statements: | ||
result = self.mock_spanner.get_result(statement.sql) | ||
if first and started_transaction is not None: | ||
result = result_set.ResultSet( | ||
self.mock_spanner.get_result(statement.sql) | ||
) | ||
result.metadata = result_set.ResultSetMetadata(result.metadata) | ||
result.metadata.transaction = started_transaction | ||
response.result_sets.append(result) | ||
return response | ||
|
||
def Read(self, request, context): | ||
self._requests.append(request) | ||
return result_set.ResultSet() | ||
|
||
def StreamingRead(self, request, context): | ||
self._requests.append(request) | ||
for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]: | ||
yield result | ||
|
||
def BeginTransaction(self, request, context): | ||
self._requests.append(request) | ||
return self.__create_transaction(request.session, request.options) | ||
|
||
def __create_transaction( | ||
self, session: str, options: transaction.TransactionOptions | ||
) -> transaction.Transaction: | ||
session = self.sessions[session] | ||
if session is None: | ||
raise ValueError(f"Session not found: {session}") | ||
self.transaction_counter += 1 | ||
id_bytes = bytes( | ||
f"{session.name}/transactions/{self.transaction_counter}", "UTF-8" | ||
) | ||
transaction_id = base64.urlsafe_b64encode(id_bytes) | ||
self.transactions[transaction_id] = options | ||
return transaction.Transaction(dict(id=transaction_id)) | ||
|
||
def Commit(self, request, context): | ||
self._requests.append(request) | ||
tx = self.transactions[request.transaction_id] | ||
if tx is None: | ||
raise ValueError(f"Transaction not found: {request.transaction_id}") | ||
del self.transactions[request.transaction_id] | ||
return commit.CommitResponse() | ||
|
||
def Rollback(self, request, context): | ||
self._requests.append(request) | ||
return empty_pb2.Empty() | ||
|
||
def PartitionQuery(self, request, context): | ||
self._requests.append(request) | ||
return spanner.PartitionResponse() | ||
|
||
def PartitionRead(self, request, context): | ||
self._requests.append(request) | ||
return spanner.PartitionResponse() | ||
|
||
def BatchWrite(self, request, context): | ||
self._requests.append(request) | ||
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]: | ||
yield result | ||
|
||
|
||
def start_mock_server() -> (grpc.Server, SpannerServicer, DatabaseAdminServicer, int): | ||
# Create a gRPC server. | ||
spanner_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | ||
|
||
# Add the Spanner services to the gRPC server. | ||
spanner_servicer = SpannerServicer() | ||
spanner_grpc.add_SpannerServicer_to_server(spanner_servicer, spanner_server) | ||
database_admin_servicer = DatabaseAdminServicer() | ||
database_admin_grpc.add_DatabaseAdminServicer_to_server( | ||
database_admin_servicer, spanner_server | ||
) | ||
|
||
# Start the server on a random port. | ||
port = spanner_server.add_insecure_port("[::]:0") | ||
spanner_server.start() | ||
return spanner_server, spanner_servicer, database_admin_servicer, port |
Oops, something went wrong.