diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 5c858320c4..1e10e1df73 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -25,7 +25,6 @@ import google.auth.credentials from google.api_core.retry import Retry from google.api_core.retry import if_exception_type -from google.auth.aio.credentials import AnonymousCredentials from google.cloud.exceptions import NotFound from google.api_core.exceptions import Aborted from google.api_core import gapic_v1 @@ -42,7 +41,7 @@ from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1.transaction import BatchTransactionId -from google.cloud.spanner_v1 import ExecuteSqlRequest, SpannerAsyncClient +from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode from google.cloud.spanner_v1 import TransactionSelector @@ -144,7 +143,6 @@ class Database(object): """ _spanner_api: SpannerClient = None - _spanner_async_api: SpannerAsyncClient = None def __init__( self, @@ -440,28 +438,6 @@ def spanner_api(self): ) return self._spanner_api - @property - def spanner_async_api(self): - if self._spanner_async_api is None: - client_info = self._instance._client._client_info - client_options = self._instance._client._client_options - if self._instance.emulator_host is not None: - channel=grpc.aio.insecure_channel(target=self._instance.emulator_host) - transport = SpannerGrpcTransport(channel=channel) - self._spanner_async_api = SpannerAsyncClient( - client_info=client_info, transport=transport - ) - return self._spanner_async_api - credentials = self._instance._client.credentials - if isinstance(credentials, google.auth.credentials.Scoped): - credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,)) - self._spanner_async_api = SpannerAsyncClient( - credentials=credentials, - client_info=client_info, - client_options=client_options, - ) - return self._spanner_async_api - def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented diff --git a/google/cloud/spanner_v1/services/spanner/transports/grpc.py b/google/cloud/spanner_v1/services/spanner/transports/grpc.py index fce1002942..a2afa32174 100644 --- a/google/cloud/spanner_v1/services/spanner/transports/grpc.py +++ b/google/cloud/spanner_v1/services/spanner/transports/grpc.py @@ -127,7 +127,7 @@ def __init__( if client_cert_source: warnings.warn("client_cert_source is deprecated", DeprecationWarning) - if isinstance(channel, grpc.Channel) or isinstance(channel, grpc.aio.Channel): + if isinstance(channel, grpc.Channel): # Ignore credentials if a channel was passed. credentials = None self._ignore_credentials = True diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index 5484615dd2..86b4a96b67 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import base64 from concurrent import futures from google.protobuf import empty_pb2 @@ -31,22 +31,32 @@ def __init__(self): self.results = {} def add_result(self, sql: str, result: result_set.ResultSet): - self.results[sql.lower()] = result + self.results[sql.lower().strip()] = result + + def get_result(self, sql: str) -> result_set.ResultSet: + result = self.results.get(sql.lower().strip()) + if result is None: + raise ValueError(f"No result found for {sql}") + return result def get_result_as_partial_result_sets( self, sql: str ) -> [result_set.PartialResultSet]: - result: result_set.ResultSet = self.results.get(sql.lower()) - if result is None: - return [] + result: result_set.ResultSet = self.get_result(sql) partials = [] first = True - for row in result.rows: + if len(result.rows) == 0: partial = result_set.PartialResultSet() - if first: - partial.metadata = result.metadata - partial.values.extend(row) + partial.metadata = result.metadata partials.append(partial) + else: + for row in result.rows: + partial = result_set.PartialResultSet() + if first: + partial.metadata = result.metadata + partial.values.extend(row) + partials.append(partial) + partials[len(partials) - 1].stats = result.stats return partials @@ -56,6 +66,8 @@ def __init__(self): self._requests = [] self.session_counter = 0 self.sessions = {} + self.transaction_counter = 0 + self.transactions = {} self._mock_spanner = MockSpanner() @property @@ -93,18 +105,20 @@ def __create_session(self, database: str, session_template: spanner.Session): return session def GetSession(self, request, context): + self._requests.append(request) return spanner.Session() def ListSessions(self, request, context): + self._requests.append(request) return [spanner.Session()] def DeleteSession(self, request, context): + self._requests.append(request) return empty_pb2.Empty() def ExecuteSql(self, request, context): self._requests.append(request) - result: result_set.ResultSet = self.mock_spanner.results.get(request.sql.lower()) - return result + return result_set.ResultSet() def ExecuteStreamingSql(self, request, context): self._requests.append(request) @@ -113,31 +127,74 @@ def ExecuteStreamingSql(self, request, context): yield result def ExecuteBatchDml(self, request, context): - return spanner.ExecuteBatchDmlResponse() + self._requests.append(request) + response = spanner.ExecuteBatchDmlResponse() + started_transaction = None + if not request.transaction.begin == transaction.TransactionOptions(): + started_transaction = self.__create_transaction( + request.session, request.transaction.begin + ) + first = True + for statement in request.statements: + result = self.mock_spanner.get_result(statement.sql) + if first and started_transaction is not None: + result = result_set.ResultSet( + self.mock_spanner.get_result(statement.sql) + ) + result.metadata = result_set.ResultSetMetadata(result.metadata) + result.metadata.transaction = started_transaction + response.result_sets.append(result) + return response def Read(self, request, context): + self._requests.append(request) return result_set.ResultSet() def StreamingRead(self, request, context): + self._requests.append(request) for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]: yield result def BeginTransaction(self, request, context): - return transaction.Transaction() + self._requests.append(request) + return self.__create_transaction(request.session, request.options) + + def __create_transaction( + self, session: str, options: transaction.TransactionOptions + ) -> transaction.Transaction: + session = self.sessions[session] + if session is None: + raise ValueError(f"Session not found: {session}") + self.transaction_counter += 1 + id_bytes = bytes( + f"{session.name}/transactions/{self.transaction_counter}", "UTF-8" + ) + transaction_id = base64.urlsafe_b64encode(id_bytes) + self.transactions[transaction_id] = options + return transaction.Transaction(dict(id=transaction_id)) def Commit(self, request, context): + self._requests.append(request) + tx = self.transactions[request.transaction_id] + if tx is None: + raise ValueError(f"Transaction not found: {request.transaction_id}") + del self.transactions[request.transaction_id] return commit.CommitResponse() def Rollback(self, request, context): + self._requests.append(request) return empty_pb2.Empty() def PartitionQuery(self, request, context): + self._requests.append(request) return spanner.PartitionResponse() def PartitionRead(self, request, context): + self._requests.append(request) return spanner.PartitionResponse() def BatchWrite(self, request, context): + self._requests.append(request) for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]: yield result diff --git a/tests/mockserver_tests/test_basics.py b/tests/mockserver_tests/test_basics.py index ae23533df7..f2dab9af06 100644 --- a/tests/mockserver_tests/test_basics.py +++ b/tests/mockserver_tests/test_basics.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import asyncio + import unittest from google.cloud.spanner_admin_database_v1.types import spanner_database_admin @@ -28,7 +28,8 @@ Client, FixedSizePool, BatchCreateSessionsRequest, - ExecuteSqlRequest, CreateSessionRequest, + ExecuteSqlRequest, + GetSessionRequest, ) from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance @@ -124,9 +125,12 @@ def test_select1(self): self.assertEqual(1, row[0]) self.assertEqual(1, len(result_list)) requests = self.spanner_service.requests - self.assertEqual(2, len(requests)) + self.assertEqual(3, len(requests)) self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) - self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + # TODO: Optimize FixedSizePool so this GetSessionRequest is not executed + # every time a session is fetched. + self.assertTrue(isinstance(requests[1], GetSessionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) def test_create_table(self): database_admin_api = self.client.database_admin_api @@ -145,28 +149,3 @@ def test_create_table(self): ) operation = database_admin_api.update_database_ddl(request) operation.result(1) - - - def test_async_select1(self): - self._add_select1_result() - results = asyncio.run(self._async_select1()) - result_list = [] - for row in results.rows: - result_list.append(row) - self.assertEqual("1", row[0]) - self.assertEqual(1, len(result_list)) - requests = self.spanner_service.requests - self.assertEqual(3, len(requests)) - self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) - self.assertTrue(isinstance(requests[1], CreateSessionRequest)) - self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) - - async def _async_select1(self): - client = self.database.spanner_async_api - create_session_request = CreateSessionRequest(database=self._database.name) - session = await client.create_session(create_session_request) - execute_request = ExecuteSqlRequest(dict( - session=session.name, - sql="select 1", - )) - return await client.execute_sql(execute_request)