From b62c95fb941a0c78d5e7007e92ef17ace57abcb7 Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Mon, 24 Oct 2016 09:45:12 -0700 Subject: [PATCH] Switching _grpc_catch_rendezvous to a context manager. --- .../google/cloud/datastore/connection.py | 45 +++++++------------ datastore/unit_tests/test_connection.py | 19 +++++--- 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/datastore/google/cloud/datastore/connection.py b/datastore/google/cloud/datastore/connection.py index 2d038eff28c2..aac5c85e0a88 100644 --- a/datastore/google/cloud/datastore/connection.py +++ b/datastore/google/cloud/datastore/connection.py @@ -14,6 +14,7 @@ """Connections to Google Cloud Datastore API servers.""" +import contextlib import os from google.rpc import status_pb2 @@ -237,8 +238,9 @@ def allocate_ids(self, project, request_pb): _datastore_pb2.AllocateIdsResponse) -def _grpc_catch_rendezvous(to_call, *args, **kwargs): - """Call a method/function and re-map gRPC exceptions. +@contextlib.contextmanager +def _grpc_catch_rendezvous(): + """Re-map gRPC exceptions that happen in context. .. _code.proto: https://github.com/googleapis/googleapis/blob/\ master/google/rpc/code.proto @@ -246,26 +248,9 @@ def _grpc_catch_rendezvous(to_call, *args, **kwargs): Remaps gRPC exceptions to the classes defined in :mod:`~google.cloud.exceptions` (according to the description in `code.proto`_). - - :type to_call: callable - :param to_call: Callable that makes a request which may raise a - :class:`~google.cloud.exceptions.GrpcRendezvous`. - - :type args: tuple - :param args: Positional arugments to the callable. - - :type kwargs: dict - :param kwargs: Keyword arguments to the callable. - - :rtype: object - :returns: The value returned from ``to_call``. - :raises: :class:`~google.cloud.exceptions.GrpcRendezvous` if one - is encountered that can't be re-mapped, otherwise maps - to a :class:`~google.cloud.exceptions.GoogleCloudError` - subclass. """ try: - return to_call(*args, **kwargs) + yield except exceptions.GrpcRendezvous as exc: error_code = exc.code() error_class = _GRPC_ERROR_MAPPING.get(error_code) @@ -331,8 +316,8 @@ def run_query(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return _grpc_catch_rendezvous( - self._stub.RunQuery, request_pb) + with _grpc_catch_rendezvous(): + return self._stub.RunQuery(request_pb) def begin_transaction(self, project, request_pb): """Perform a ``beginTransaction`` request. @@ -349,8 +334,8 @@ def begin_transaction(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return _grpc_catch_rendezvous( - self._stub.BeginTransaction, request_pb) + with _grpc_catch_rendezvous(): + return self._stub.BeginTransaction(request_pb) def commit(self, project, request_pb): """Perform a ``commit`` request. @@ -366,8 +351,8 @@ def commit(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return _grpc_catch_rendezvous( - self._stub.Commit, request_pb) + with _grpc_catch_rendezvous(): + return self._stub.Commit(request_pb) def rollback(self, project, request_pb): """Perform a ``rollback`` request. @@ -383,8 +368,8 @@ def rollback(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return _grpc_catch_rendezvous( - self._stub.Rollback, request_pb) + with _grpc_catch_rendezvous(): + return self._stub.Rollback(request_pb) def allocate_ids(self, project, request_pb): """Perform an ``allocateIds`` request. @@ -400,8 +385,8 @@ def allocate_ids(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return _grpc_catch_rendezvous( - self._stub.AllocateIds, request_pb) + with _grpc_catch_rendezvous(): + return self._stub.AllocateIds(request_pb) class Connection(connection_module.Connection): diff --git a/datastore/unit_tests/test_connection.py b/datastore/unit_tests/test_connection.py index c7577cedb568..973a3241506e 100644 --- a/datastore/unit_tests/test_connection.py +++ b/datastore/unit_tests/test_connection.py @@ -109,9 +109,9 @@ def test__request_not_200(self): @unittest.skipUnless(_HAVE_GRPC, 'No gRPC') class Test__grpc_catch_rendezvous(unittest.TestCase): - def _callFUT(self, to_call, *args, **kwargs): + def _callFUT(self): from google.cloud.datastore.connection import _grpc_catch_rendezvous - return _grpc_catch_rendezvous(to_call, *args, **kwargs) + return _grpc_catch_rendezvous() @staticmethod def _fake_method(exc, result=None): @@ -122,7 +122,8 @@ def _fake_method(exc, result=None): def test_success(self): expected = object() - result = self._callFUT(self._fake_method, None, expected) + with self._callFUT(): + result = self._fake_method(None, expected) self.assertIs(result, expected) def test_failure_aborted(self): @@ -135,7 +136,8 @@ def test_failure_aborted(self): exc_state = _RPCState((), None, None, StatusCode.ABORTED, details) exc = GrpcRendezvous(exc_state, None, None, None) with self.assertRaises(Conflict): - self._callFUT(self._fake_method, exc) + with self._callFUT(): + self._fake_method(exc) def test_failure_invalid_argument(self): from grpc import StatusCode @@ -149,7 +151,8 @@ def test_failure_invalid_argument(self): StatusCode.INVALID_ARGUMENT, details) exc = GrpcRendezvous(exc_state, None, None, None) with self.assertRaises(BadRequest): - self._callFUT(self._fake_method, exc) + with self._callFUT(): + self._fake_method(exc) def test_failure_cancelled(self): from grpc import StatusCode @@ -159,12 +162,14 @@ def test_failure_cancelled(self): exc_state = _RPCState((), None, None, StatusCode.CANCELLED, None) exc = GrpcRendezvous(exc_state, None, None, None) with self.assertRaises(GrpcRendezvous): - self._callFUT(self._fake_method, exc) + with self._callFUT(): + self._fake_method(exc) def test_commit_failure_non_grpc_err(self): exc = RuntimeError('Not a gRPC error') with self.assertRaises(RuntimeError): - self._callFUT(self._fake_method, exc) + with self._callFUT(): + self._fake_method(exc) class Test_DatastoreAPIOverGRPC(unittest.TestCase):