diff --git a/gcloud/datastore/client.py b/gcloud/datastore/client.py index ae07af1d2746..3e8a27b64efa 100644 --- a/gcloud/datastore/client.py +++ b/gcloud/datastore/client.py @@ -189,7 +189,7 @@ def _push_batch(self, batch): :type batch: :class:`gcloud.datastore.batch.Batch`, or an object implementing its API. - :param batch: newly-active batch/batch/transaction. + :param batch: newly-active batch/transaction. """ self._batch_stack.push(batch) @@ -211,7 +211,7 @@ def current_batch(self): :rtype: :class:`gcloud.datastore.batch.Batch`, or an object implementing its API, or ``NoneType`` (if no batch is active). - :returns: The batch/transaction at the toop of the batch stack. + :returns: The batch/transaction at the top of the batch stack. """ return self._batch_stack.top @@ -222,7 +222,7 @@ def current_transaction(self): :rtype: :class:`gcloud.datastore.transaction.Transaction`, or an object implementing its API, or ``NoneType`` (if no transaction is active). - :returns: The transaction at the toop of the batch stack. + :returns: The transaction at the top of the batch stack. """ transaction = self.current_batch if isinstance(transaction, Transaction): diff --git a/gcloud/storage/batch.py b/gcloud/storage/batch.py index 8404bef1c587..94e3269c9de4 100644 --- a/gcloud/storage/batch.py +++ b/gcloud/storage/batch.py @@ -26,14 +26,10 @@ import six -from gcloud._helpers import _LocalStack from gcloud.exceptions import make_exception from gcloud.storage.connection import Connection -_BATCHES = _LocalStack() - - class MIMEApplicationHTTP(MIMEApplication): """MIME type for ``application/http``. @@ -244,19 +240,21 @@ def finish(self): url = '%s/batch' % self.API_BASE_URL - response, content = self._client.connection._make_request( + # Use the private ``_connection`` rather than the public + # ``.connection``, since the public connection may be this + # current batch. + response, content = self._client._connection._make_request( 'POST', url, data=body, headers=headers) responses = list(_unpack_batch_response(response, content)) self._finish_futures(responses) return responses - @staticmethod - def current(): + def current(self): """Return the topmost batch, or None.""" - return _BATCHES.top + return self._client.current_batch def __enter__(self): - _BATCHES.push(self) + self._client._push_batch(self) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -264,7 +262,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: self.finish() finally: - _BATCHES.pop() + self._client._pop_batch() def _generate_faux_mime_message(parser, response, content): diff --git a/gcloud/storage/blob.py b/gcloud/storage/blob.py index d1a6b33dc3c2..6eef1af093e0 100644 --- a/gcloud/storage/blob.py +++ b/gcloud/storage/blob.py @@ -193,7 +193,7 @@ def generate_signed_url(self, expiration, method='GET', if credentials is None: client = self._require_client(client) - credentials = client.connection.credentials + credentials = client._connection.credentials return generate_signed_url( credentials, resource=resource, @@ -291,7 +291,13 @@ def download_to_file(self, file_obj, client=None): headers['Range'] = 'bytes=0-%d' % (self.chunk_size - 1,) request = http_wrapper.Request(download_url, 'GET', headers) - download.InitializeDownload(request, client.connection.http) + # Use the private ``_connection`` rather than the public + # ``.connection``, since the public connection may be a batch. A + # batch wraps a client's connection, but does not store the `http` + # object. The rest (API_BASE_URL and build_api_url) are also defined + # on the Batch class, but we just use the wrapped connection since + # it has all three (http, API_BASE_URL and build_api_url). + download.InitializeDownload(request, client._connection.http) # Should we be passing callbacks through from caller? We can't # pass them as None, because apitools wants to print to the console @@ -379,7 +385,13 @@ def upload_from_file(self, file_obj, rewind=False, size=None, determined """ client = self._require_client(client) - connection = client.connection + # Use the private ``_connection`` rather than the public + # ``.connection``, since the public connection may be a batch. A + # batch wraps a client's connection, but does not store the `http` + # object. The rest (API_BASE_URL and build_api_url) are also defined + # on the Batch class, but we just use the wrapped connection since + # it has all three (http, API_BASE_URL and build_api_url). + connection = client._connection content_type = (content_type or self._properties.get('contentType') or 'application/octet-stream') diff --git a/gcloud/storage/client.py b/gcloud/storage/client.py index 9cf3629f7a0a..45ea498f4a4c 100644 --- a/gcloud/storage/client.py +++ b/gcloud/storage/client.py @@ -15,6 +15,7 @@ """gcloud storage client for interacting with API.""" +from gcloud._helpers import _LocalStack from gcloud.client import JSONClient from gcloud.exceptions import NotFound from gcloud.iterator import Iterator @@ -45,6 +46,73 @@ class Client(JSONClient): _connection_class = Connection + def __init__(self, project=None, credentials=None, http=None): + self._connection = None + super(Client, self).__init__(project=project, credentials=credentials, + http=http) + self._batch_stack = _LocalStack() + + @property + def connection(self): + """Get connection or batch on the client. + + :rtype: :class:`gcloud.storage.connection.Connection` + :returns: The connection set on the client, or the batch + if one is set. + """ + if self.current_batch is not None: + return self.current_batch + else: + return self._connection + + @connection.setter + def connection(self, value): + """Set connection on the client. + + Intended to be used by constructor (since the base class calls) + self.connection = connection + Will raise if the connection is set more than once. + + :type value: :class:`gcloud.storage.connection.Connection` + :param value: The connection set on the client. + + :raises: :class:`ValueError` if connection has already been set. + """ + if self._connection is not None: + raise ValueError('Connection already set on client') + self._connection = value + + def _push_batch(self, batch): + """Push a batch onto our stack. + + "Protected", intended for use by batch context mgrs. + + :type batch: :class:`gcloud.storage.batch.Batch` + :param batch: newly-active batch + """ + self._batch_stack.push(batch) + + def _pop_batch(self): + """Pop a batch from our stack. + + "Protected", intended for use by batch context mgrs. + + :raises: IndexError if the stack is empty. + :rtype: :class:`gcloud.storage.batch.Batch` + :returns: the top-most batch/transaction, after removing it. + """ + return self._batch_stack.pop() + + @property + def current_batch(self): + """Currently-active batch. + + :rtype: :class:`gcloud.storage.batch.Batch` or ``NoneType`` (if + no batch is active). + :returns: The batch at the top of the batch stack. + """ + return self._batch_stack.top + def get_bucket(self, bucket_name): """Get a bucket by name. diff --git a/gcloud/storage/test_batch.py b/gcloud/storage/test_batch.py index 41a989727104..b08cffab91d3 100644 --- a/gcloud/storage/test_batch.py +++ b/gcloud/storage/test_batch.py @@ -84,6 +84,21 @@ def test_ctor_w_explicit_connection(self): self.assertEqual(len(batch._requests), 0) self.assertEqual(len(batch._target_objects), 0) + def test_current(self): + from gcloud.storage.client import Client + project = 'PROJECT' + credentials = _Credentials() + client = Client(project=project, credentials=credentials) + batch1 = self._makeOne(client) + self.assertTrue(batch1.current() is None) + + client._push_batch(batch1) + self.assertTrue(batch1.current() is batch1) + + batch2 = self._makeOne(client) + client._push_batch(batch2) + self.assertTrue(batch1.current() is batch2) + def test__make_request_GET_normal(self): from gcloud.storage.batch import _FutureDict URL = 'http://example.com/api' @@ -354,39 +369,31 @@ def test_finish_nonempty_non_multipart_response(self): batch._requests.append(('DELETE', URL, {}, None)) self.assertRaises(ValueError, batch.finish) - def test_current(self): - from gcloud.storage.batch import _BATCHES - klass = self._getTargetClass() - batch_top = object() - self.assertEqual(list(_BATCHES), []) - _BATCHES.push(batch_top) - self.assertTrue(klass.current() is batch_top) - _BATCHES.pop() - self.assertEqual(list(_BATCHES), []) - def test_as_context_mgr_wo_error(self): - from gcloud.storage.batch import _BATCHES + from gcloud.storage.client import Client URL = 'http://example.com/api' expected = _Response() expected['content-type'] = 'multipart/mixed; boundary="DEADBEEF="' http = _HTTP((expected, _THREE_PART_MIME_RESPONSE)) - connection = _Connection(http=http) - client = _Client(connection) + project = 'PROJECT' + credentials = _Credentials() + client = Client(project=project, credentials=credentials) + client._connection._http = http - self.assertEqual(list(_BATCHES), []) + self.assertEqual(list(client._batch_stack), []) target1 = _MockObject() target2 = _MockObject() target3 = _MockObject() with self._makeOne(client) as batch: - self.assertEqual(list(_BATCHES), [batch]) + self.assertEqual(list(client._batch_stack), [batch]) batch._make_request('POST', URL, {'foo': 1, 'bar': 2}, target_object=target1) batch._make_request('PATCH', URL, {'bar': 3}, target_object=target2) batch._make_request('DELETE', URL, target_object=target3) - self.assertEqual(list(_BATCHES), []) + self.assertEqual(list(client._batch_stack), []) self.assertEqual(len(batch._requests), 3) self.assertEqual(batch._requests[0][0], 'POST') self.assertEqual(batch._requests[1][0], 'PATCH') @@ -400,19 +407,23 @@ def test_as_context_mgr_wo_error(self): def test_as_context_mgr_w_error(self): from gcloud.storage.batch import _FutureDict - from gcloud.storage.batch import _BATCHES + from gcloud.storage.client import Client URL = 'http://example.com/api' http = _HTTP() connection = _Connection(http=http) + project = 'PROJECT' + credentials = _Credentials() + client = Client(project=project, credentials=credentials) + client._connection = connection - self.assertEqual(list(_BATCHES), []) + self.assertEqual(list(client._batch_stack), []) target1 = _MockObject() target2 = _MockObject() target3 = _MockObject() try: - with self._makeOne(connection) as batch: - self.assertEqual(list(_BATCHES), [batch]) + with self._makeOne(client) as batch: + self.assertEqual(list(client._batch_stack), [batch]) batch._make_request('POST', URL, {'foo': 1, 'bar': 2}, target_object=target1) batch._make_request('PATCH', URL, {'bar': 3}, @@ -422,7 +433,7 @@ def test_as_context_mgr_w_error(self): except ValueError: pass - self.assertEqual(list(_BATCHES), []) + self.assertEqual(list(client._batch_stack), []) self.assertEqual(len(http._requests), 0) self.assertEqual(len(batch._requests), 3) self.assertEqual(batch._target_objects, [target1, target2, target3]) @@ -597,4 +608,17 @@ class _MockObject(object): class _Client(object): def __init__(self, connection): - self.connection = connection + self._connection = connection + + +class _Credentials(object): + + _scopes = None + + @staticmethod + def create_scoped_required(): + return True + + def create_scoped(self, scope): + self._scopes = scope + return self diff --git a/gcloud/storage/test_blob.py b/gcloud/storage/test_blob.py index babbb036fe60..c85a62e4f94d 100644 --- a/gcloud/storage/test_blob.py +++ b/gcloud/storage/test_blob.py @@ -1082,4 +1082,8 @@ def __call__(self, *args, **kwargs): class _Client(object): def __init__(self, connection): - self.connection = connection + self._connection = connection + + @property + def connection(self): + return self._connection diff --git a/gcloud/storage/test_client.py b/gcloud/storage/test_client.py index d668f66c8410..7cf6caf9f096 100644 --- a/gcloud/storage/test_client.py +++ b/gcloud/storage/test_client.py @@ -33,6 +33,61 @@ def test_ctor_connection_type(self): client = self._makeOne(project=PROJECT, credentials=CREDENTIALS) self.assertTrue(isinstance(client.connection, Connection)) self.assertTrue(client.connection.credentials is CREDENTIALS) + self.assertTrue(client.current_batch is None) + self.assertEqual(list(client._batch_stack), []) + + def test__push_batch_and__pop_batch(self): + from gcloud.storage.batch import Batch + + PROJECT = object() + CREDENTIALS = _Credentials() + + client = self._makeOne(project=PROJECT, credentials=CREDENTIALS) + batch1 = Batch(client) + batch2 = Batch(client) + client._push_batch(batch1) + self.assertEqual(list(client._batch_stack), [batch1]) + self.assertTrue(client.current_batch is batch1) + client._push_batch(batch2) + self.assertTrue(client.current_batch is batch2) + # list(_LocalStack) returns in reverse order. + self.assertEqual(list(client._batch_stack), [batch2, batch1]) + self.assertTrue(client._pop_batch() is batch2) + self.assertEqual(list(client._batch_stack), [batch1]) + self.assertTrue(client._pop_batch() is batch1) + self.assertEqual(list(client._batch_stack), []) + + def test_connection_setter(self): + PROJECT = object() + CREDENTIALS = _Credentials() + client = self._makeOne(project=PROJECT, credentials=CREDENTIALS) + client._connection = None # Unset the value from the constructor + client.connection = connection = object() + self.assertTrue(client._connection is connection) + + def test_connection_setter_when_set(self): + PROJECT = object() + CREDENTIALS = _Credentials() + client = self._makeOne(project=PROJECT, credentials=CREDENTIALS) + self.assertRaises(ValueError, setattr, client, 'connection', None) + + def test_connection_getter_no_batch(self): + PROJECT = object() + CREDENTIALS = _Credentials() + client = self._makeOne(project=PROJECT, credentials=CREDENTIALS) + self.assertTrue(client.connection is client._connection) + self.assertTrue(client.current_batch is None) + + def test_connection_getter_with_batch(self): + from gcloud.storage.batch import Batch + PROJECT = object() + CREDENTIALS = _Credentials() + client = self._makeOne(project=PROJECT, credentials=CREDENTIALS) + batch = Batch(client) + client._push_batch(batch) + self.assertTrue(client.connection is not client._connection) + self.assertTrue(client.connection is batch) + self.assertTrue(client.current_batch is batch) def test_get_bucket_miss(self): from gcloud.exceptions import NotFound diff --git a/system_tests/storage.py b/system_tests/storage.py index 67fb2fc59a09..045c5cb4653a 100644 --- a/system_tests/storage.py +++ b/system_tests/storage.py @@ -51,15 +51,9 @@ def setUp(self): self.case_buckets_to_delete = [] def tearDown(self): - with storage.Batch(CLIENT) as batch: - # Stop-gap measure to support batches during transation to - # to clients from implicit behavior. - batch_client = storage.Client( - project=CLIENT.project, - credentials=CLIENT.connection.credentials) - batch_client.connection = batch + with storage.Batch(CLIENT): for bucket_name in self.case_buckets_to_delete: - storage.Bucket(batch_client, name=bucket_name).delete() + storage.Bucket(CLIENT, name=bucket_name).delete() def test_create_bucket(self): new_bucket_name = 'a-new-bucket'