Skip to content

Commit

Permalink
chore: remove async + add transaction handling
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite committed Dec 4, 2024
1 parent 95b6cd6 commit 39a11d0
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 68 deletions.
26 changes: 1 addition & 25 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import google.auth.credentials
from google.api_core.retry import Retry
from google.api_core.retry import if_exception_type
from google.auth.aio.credentials import AnonymousCredentials
from google.cloud.exceptions import NotFound
from google.api_core.exceptions import Aborted
from google.api_core import gapic_v1
Expand All @@ -42,7 +41,7 @@
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_v1.transaction import BatchTransactionId
from google.cloud.spanner_v1 import ExecuteSqlRequest, SpannerAsyncClient
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import Type
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import TransactionSelector
Expand Down Expand Up @@ -144,7 +143,6 @@ class Database(object):
"""

_spanner_api: SpannerClient = None
_spanner_async_api: SpannerAsyncClient = None

def __init__(
self,
Expand Down Expand Up @@ -440,28 +438,6 @@ def spanner_api(self):
)
return self._spanner_api

@property
def spanner_async_api(self):
if self._spanner_async_api is None:
client_info = self._instance._client._client_info
client_options = self._instance._client._client_options
if self._instance.emulator_host is not None:
channel=grpc.aio.insecure_channel(target=self._instance.emulator_host)
transport = SpannerGrpcTransport(channel=channel)
self._spanner_async_api = SpannerAsyncClient(
client_info=client_info, transport=transport
)
return self._spanner_async_api
credentials = self._instance._client.credentials
if isinstance(credentials, google.auth.credentials.Scoped):
credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,))
self._spanner_async_api = SpannerAsyncClient(
credentials=credentials,
client_info=client_info,
client_options=client_options,
)
return self._spanner_async_api

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if isinstance(channel, grpc.Channel) or isinstance(channel, grpc.aio.Channel):
if isinstance(channel, grpc.Channel):
# Ignore credentials if a channel was passed.
credentials = None
self._ignore_credentials = True
Expand Down
83 changes: 70 additions & 13 deletions google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
from concurrent import futures

from google.protobuf import empty_pb2
Expand All @@ -31,22 +31,32 @@ def __init__(self):
self.results = {}

def add_result(self, sql: str, result: result_set.ResultSet):
self.results[sql.lower()] = result
self.results[sql.lower().strip()] = result

def get_result(self, sql: str) -> result_set.ResultSet:
result = self.results.get(sql.lower().strip())
if result is None:
raise ValueError(f"No result found for {sql}")
return result

def get_result_as_partial_result_sets(
self, sql: str
) -> [result_set.PartialResultSet]:
result: result_set.ResultSet = self.results.get(sql.lower())
if result is None:
return []
result: result_set.ResultSet = self.get_result(sql)
partials = []
first = True
for row in result.rows:
if len(result.rows) == 0:
partial = result_set.PartialResultSet()
if first:
partial.metadata = result.metadata
partial.values.extend(row)
partial.metadata = result.metadata
partials.append(partial)
else:
for row in result.rows:
partial = result_set.PartialResultSet()
if first:
partial.metadata = result.metadata
partial.values.extend(row)
partials.append(partial)
partials[len(partials) - 1].stats = result.stats
return partials


Expand All @@ -56,6 +66,8 @@ def __init__(self):
self._requests = []
self.session_counter = 0
self.sessions = {}
self.transaction_counter = 0
self.transactions = {}
self._mock_spanner = MockSpanner()

@property
Expand Down Expand Up @@ -93,18 +105,20 @@ def __create_session(self, database: str, session_template: spanner.Session):
return session

def GetSession(self, request, context):
self._requests.append(request)
return spanner.Session()

def ListSessions(self, request, context):
self._requests.append(request)
return [spanner.Session()]

def DeleteSession(self, request, context):
self._requests.append(request)
return empty_pb2.Empty()

def ExecuteSql(self, request, context):
self._requests.append(request)
result: result_set.ResultSet = self.mock_spanner.results.get(request.sql.lower())
return result
return result_set.ResultSet()

def ExecuteStreamingSql(self, request, context):
self._requests.append(request)
Expand All @@ -113,31 +127,74 @@ def ExecuteStreamingSql(self, request, context):
yield result

def ExecuteBatchDml(self, request, context):
return spanner.ExecuteBatchDmlResponse()
self._requests.append(request)
response = spanner.ExecuteBatchDmlResponse()
started_transaction = None
if not request.transaction.begin == transaction.TransactionOptions():
started_transaction = self.__create_transaction(
request.session, request.transaction.begin
)
first = True
for statement in request.statements:
result = self.mock_spanner.get_result(statement.sql)
if first and started_transaction is not None:
result = result_set.ResultSet(
self.mock_spanner.get_result(statement.sql)
)
result.metadata = result_set.ResultSetMetadata(result.metadata)
result.metadata.transaction = started_transaction
response.result_sets.append(result)
return response

def Read(self, request, context):
self._requests.append(request)
return result_set.ResultSet()

def StreamingRead(self, request, context):
self._requests.append(request)
for result in [result_set.PartialResultSet(), result_set.PartialResultSet()]:
yield result

def BeginTransaction(self, request, context):
return transaction.Transaction()
self._requests.append(request)
return self.__create_transaction(request.session, request.options)

def __create_transaction(
self, session: str, options: transaction.TransactionOptions
) -> transaction.Transaction:
session = self.sessions[session]
if session is None:
raise ValueError(f"Session not found: {session}")
self.transaction_counter += 1
id_bytes = bytes(
f"{session.name}/transactions/{self.transaction_counter}", "UTF-8"
)
transaction_id = base64.urlsafe_b64encode(id_bytes)
self.transactions[transaction_id] = options
return transaction.Transaction(dict(id=transaction_id))

def Commit(self, request, context):
self._requests.append(request)
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
del self.transactions[request.transaction_id]
return commit.CommitResponse()

def Rollback(self, request, context):
self._requests.append(request)
return empty_pb2.Empty()

def PartitionQuery(self, request, context):
self._requests.append(request)
return spanner.PartitionResponse()

def PartitionRead(self, request, context):
self._requests.append(request)
return spanner.PartitionResponse()

def BatchWrite(self, request, context):
self._requests.append(request)
for result in [spanner.BatchWriteResponse(), spanner.BatchWriteResponse()]:
yield result

Expand Down
37 changes: 8 additions & 29 deletions tests/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio

import unittest

from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
Expand All @@ -28,7 +28,8 @@
Client,
FixedSizePool,
BatchCreateSessionsRequest,
ExecuteSqlRequest, CreateSessionRequest,
ExecuteSqlRequest,
GetSessionRequest,
)
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
Expand Down Expand Up @@ -124,9 +125,12 @@ def test_select1(self):
self.assertEqual(1, row[0])
self.assertEqual(1, len(result_list))
requests = self.spanner_service.requests
self.assertEqual(2, len(requests))
self.assertEqual(3, len(requests))
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
# TODO: Optimize FixedSizePool so this GetSessionRequest is not executed
# every time a session is fetched.
self.assertTrue(isinstance(requests[1], GetSessionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))

def test_create_table(self):
database_admin_api = self.client.database_admin_api
Expand All @@ -145,28 +149,3 @@ def test_create_table(self):
)
operation = database_admin_api.update_database_ddl(request)
operation.result(1)


def test_async_select1(self):
self._add_select1_result()
results = asyncio.run(self._async_select1())
result_list = []
for row in results.rows:
result_list.append(row)
self.assertEqual("1", row[0])
self.assertEqual(1, len(result_list))
requests = self.spanner_service.requests
self.assertEqual(3, len(requests))
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], CreateSessionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))

async def _async_select1(self):
client = self.database.spanner_async_api
create_session_request = CreateSessionRequest(database=self._database.name)
session = await client.create_session(create_session_request)
execute_request = ExecuteSqlRequest(dict(
session=session.name,
sql="select 1",
))
return await client.execute_sql(execute_request)

0 comments on commit 39a11d0

Please sign in to comment.