-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
1,167 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Spanner Servicer | ||
|
||
The Spanner server definition files were generated with these commands: | ||
|
||
```shell | ||
pip install grpcio-tools | ||
git clone git@github.com:googleapis/googleapis.git | ||
cd googleapis | ||
python -m grpc_tools.protoc \ | ||
-I . \ | ||
--python_out=. --pyi_out=. --grpc_python_out=. \ | ||
./google/spanner/v1/*.proto | ||
``` | ||
|
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright 2024 Google LLC All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from google.protobuf import empty_pb2 # type: ignore | ||
from google.protobuf import struct_pb2 # type: ignore | ||
import google.spanner.v1.spanner_pb2_grpc as spanner_grpc | ||
import google.cloud.spanner_v1.types.result_set as result_set | ||
import google.cloud.spanner_v1.types.transaction as transaction | ||
import google.cloud.spanner_v1.types.commit_response as commit | ||
import google.cloud.spanner_v1.types.spanner as spanner | ||
from concurrent import futures | ||
import grpc | ||
|
||
class MockSpanner: | ||
def __init__(self): | ||
self.results = {} | ||
|
||
def add_result(self, sql: str, result: result_set.ResultSet): | ||
self.results[sql] = result | ||
|
||
def get_result_as_partial_result_sets(self, sql: str) -> [result_set.PartialResultSet]: | ||
result: result_set.ResultSet = self.results.get(sql) | ||
if result is None: | ||
return [] | ||
partials = [] | ||
first = True | ||
for row in result.rows: | ||
partial = result_set.PartialResultSet() | ||
if first: | ||
partial.metadata=result.metadata | ||
partial.values.extend(row) | ||
partials.append(partial) | ||
return partials | ||
|
||
class SpannerServicer(spanner_grpc.SpannerServicer): | ||
def __init__(self): | ||
self.requests = [] | ||
self.session_counter = 0 | ||
self.sessions = {} | ||
self._mock_spanner = MockSpanner() | ||
|
||
@property | ||
def mock_spanner(self): | ||
return self._mock_spanner | ||
|
||
def CreateSession(self, request, context): | ||
self.requests.append(request) | ||
return self.__create_session(request.database, request.session) | ||
|
||
def BatchCreateSessions(self, request, context): | ||
self.requests.append(request) | ||
sessions = [] | ||
for i in range(request.session_count): | ||
sessions.append(self.__create_session(request.database, request.session_template)) | ||
return spanner.BatchCreateSessionsResponse(dict(session=sessions)) | ||
|
||
def __create_session(self, database: str, session_template: spanner.Session): | ||
self.session_counter += 1 | ||
session = spanner.Session() | ||
session.name = database + "/sessions/" + str(self.session_counter) | ||
session.multiplexed = session_template.multiplexed | ||
session.labels.MergeFrom(session_template.labels) | ||
session.creator_role = session_template.creator_role | ||
self.sessions[session.name] = session | ||
return session | ||
|
||
def GetSession(self, request, context): | ||
return spanner.Session() | ||
|
||
def ListSessions(self, request, context): | ||
return [spanner.Session()] | ||
|
||
def DeleteSession(self, request, context): | ||
return empty_pb2.Empty() | ||
|
||
def ExecuteSql(self, request, context): | ||
return result_set.ResultSet() | ||
|
||
def ExecuteStreamingSql(self, request, context): | ||
self.requests.append(request) | ||
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql) | ||
for result in partials: | ||
yield result | ||
|
||
def ExecuteBatchDml(self, request, context): | ||
return spanner.ExecuteBatchDmlResponse() | ||
|
||
def Read(self, request, context): | ||
return result_set.ResultSet() | ||
|
||
def StreamingRead(self, request, context): | ||
for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]: | ||
yield result | ||
|
||
def BeginTransaction(self, request, context): | ||
return transaction.Transaction() | ||
|
||
def Commit(self, request, context): | ||
return commit.CommitResponse() | ||
|
||
def Rollback(self, request, context): | ||
return empty_pb2.Empty() | ||
|
||
def PartitionQuery(self, request, context): | ||
return spanner.PartitionResponse() | ||
|
||
def PartitionRead(self, request, context): | ||
return spanner.PartitionResponse() | ||
|
||
def BatchWrite(self, request, context): | ||
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]: | ||
yield result | ||
|
||
def start_mock_server() -> (grpc.Server, SpannerServicer, int): | ||
spanner_servicer = SpannerServicer() | ||
spanner_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | ||
spanner_grpc.add_SpannerServicer_to_server(spanner_servicer, spanner_server) | ||
port = spanner_server.add_insecure_port("[::]:0") | ||
spanner_server.start() | ||
return spanner_server, spanner_servicer, port | ||
|
||
if __name__ == "__main__": | ||
server, _ = start_mock_server() | ||
server.wait_for_termination() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2024 Google LLC All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
from tests.unit.mockserver.mock_spanner import start_mock_server, \ | ||
SpannerServicer | ||
import google.cloud.spanner_v1.types.type as spanner_type | ||
import google.cloud.spanner_v1.types.result_set as result_set | ||
from google.api_core.client_options import ClientOptions | ||
from google.auth.credentials import AnonymousCredentials | ||
from google.cloud.spanner_v1 import Client, FixedSizePool | ||
from google.cloud.spanner_v1.database import Database | ||
from google.cloud.spanner_v1.instance import Instance | ||
import grpc | ||
|
||
class TestBasics(unittest.TestCase): | ||
server: grpc.Server = None | ||
service: SpannerServicer = None | ||
port: int = None | ||
|
||
def __init__(self, *args, **kwargs): | ||
super(TestBasics, self).__init__(*args, **kwargs) | ||
self._client = None | ||
self._instance = None | ||
self._database = None | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
TestBasics.server, TestBasics.service, TestBasics.port = start_mock_server() | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
if TestBasics.server is not None: | ||
TestBasics.server.stop(grace=None) | ||
TestBasics.server = None | ||
|
||
@property | ||
def client(self) -> Client: | ||
if self._client is None: | ||
self._client = Client( | ||
project="test-project", | ||
credentials=AnonymousCredentials(), | ||
client_options=ClientOptions( | ||
api_endpoint="localhost:" + str(TestBasics.port), | ||
) | ||
) | ||
return self._client | ||
|
||
@property | ||
def instance(self) -> Instance: | ||
if self._instance is None: | ||
self._instance = self.client.instance("test-instance") | ||
return self._instance | ||
|
||
@property | ||
def database(self) -> Database: | ||
if self._database is None: | ||
self._database = self.instance.database( | ||
"test-database", | ||
pool=FixedSizePool(size=10) | ||
) | ||
return self._database | ||
|
||
def test_select1(self): | ||
result = result_set.ResultSet(dict( | ||
metadata=result_set.ResultSetMetadata(dict( | ||
row_type=spanner_type.StructType(dict( | ||
fields=[spanner_type.StructType.Field(dict( | ||
name="c", | ||
type=spanner_type.Type(dict( | ||
code=spanner_type.TypeCode.INT64) | ||
)) | ||
)] | ||
))) | ||
), | ||
)) | ||
result.rows.extend(["1"]) | ||
TestBasics.service.mock_spanner.add_result("select 1", result) | ||
with self.database.snapshot() as snapshot: | ||
results = snapshot.execute_sql("select 1") | ||
result_list = [] | ||
for row in results: | ||
result_list.append(row) | ||
self.assertEqual(1, row[0]) | ||
self.assertEqual(1, len(result_list)) | ||
|