diff --git a/core/google/cloud/iterator.py b/core/google/cloud/iterator.py index 3581fed2601c3..b7652e647767e 100644 --- a/core/google/cloud/iterator.py +++ b/core/google/cloud/iterator.py @@ -45,40 +45,84 @@ def get_items_from_response(self, response): """ +import six + + class Iterator(object): """A generic class for iterating through Cloud JSON APIs list responses. :type client: :class:`google.cloud.client.Client` :param client: The client, which owns a connection to make requests. - :type path: string + :type path: str :param path: The path to query for the list of items. + :type page_token: str + :param page_token: (Optional) A token identifying a page in a result set. + + :type max_results: int + :param max_results: (Optional) The maximum number of results to fetch. + :type extra_params: dict or None :param extra_params: Extra query string parameters for the API call. """ PAGE_TOKEN = 'pageToken' - RESERVED_PARAMS = frozenset([PAGE_TOKEN]) + MAX_RESULTS = 'maxResults' + RESERVED_PARAMS = frozenset([PAGE_TOKEN, MAX_RESULTS]) - def __init__(self, client, path, extra_params=None): + def __init__(self, client, path, page_token=None, + max_results=None, extra_params=None): self.client = client self.path = path self.page_number = 0 - self.next_page_token = None + self.next_page_token = page_token + self.max_results = max_results + self.num_results = 0 self.extra_params = extra_params or {} reserved_in_use = self.RESERVED_PARAMS.intersection( self.extra_params) if reserved_in_use: raise ValueError(('Using a reserved parameter', reserved_in_use)) + self._curr_items = iter(()) def __iter__(self): - """Iterate through the list of items.""" - while self.has_next_page(): + """The :class:`Iterator` is an iterator.""" + return self + + def _update_items(self): + """Replace the current items iterator. + + Intended to be used when the current items iterator is exhausted. + + After replacing the iterator, consumes the first value to make sure + it is valid. + + :rtype: object + :returns: The first item in the next iterator. + :raises: :class:`~exceptions.StopIteration` if there is no next page. + """ + if self.has_next_page(): response = self.get_next_page_response() - for item in self.get_items_from_response(response): - yield item + items = self.get_items_from_response(response) + self._curr_items = iter(items) + return six.next(self._curr_items) + else: + raise StopIteration + + def next(self): + """Get the next value in the iterator.""" + try: + item = six.next(self._curr_items) + except StopIteration: + item = self._update_items() + + self.num_results += 1 + return item + + # Alias needed for Python 2/3 support. + __next__ = next def has_next_page(self): """Determines whether or not this iterator has more pages. @@ -89,6 +133,10 @@ def has_next_page(self): if self.page_number == 0: return True + if self.max_results is not None: + if self.num_results >= self.max_results: + return False + return self.next_page_token is not None def get_query_params(self): @@ -97,8 +145,11 @@ def get_query_params(self): :rtype: dict :returns: A dictionary of query parameters. """ - result = ({self.PAGE_TOKEN: self.next_page_token} - if self.next_page_token else {}) + result = {} + if self.next_page_token is not None: + result[self.PAGE_TOKEN] = self.next_page_token + if self.max_results is not None: + result[self.MAX_RESULTS] = self.max_results - self.num_results result.update(self.extra_params) return result @@ -123,6 +174,7 @@ def reset(self): """Resets the iterator to the beginning.""" self.page_number = 0 self.next_page_token = None + self.num_results = 0 def get_items_from_response(self, response): """Factory method called while iterating. This should be overridden. diff --git a/core/unit_tests/test_iterator.py b/core/unit_tests/test_iterator.py index ec823d9ccb222..44d02d30770e3 100644 --- a/core/unit_tests/test_iterator.py +++ b/core/unit_tests/test_iterator.py @@ -34,7 +34,21 @@ def test_ctor(self): self.assertEqual(iterator.page_number, 0) self.assertIsNone(iterator.next_page_token) + def test_constructor_w_extra_param_collision(self): + connection = _Connection() + client = _Client(connection) + PATH = '/foo' + extra_params = {'pageToken': 'val'} + self.assertRaises(ValueError, self._makeOne, client, PATH, + extra_params=extra_params) + def test___iter__(self): + iterator = self._makeOne(None, None) + self.assertIs(iter(iterator), iterator) + + def test_iterate(self): + import six + PATH = '/foo' KEY1 = 'key1' KEY2 = 'key2' @@ -42,13 +56,27 @@ def test___iter__(self): ITEMS = {KEY1: ITEM1, KEY2: ITEM2} def _get_items(response): - for item in response.get('items', []): - yield ITEMS[item['name']] - connection = _Connection({'items': [{'name': KEY1}, {'name': KEY2}]}) + return [ITEMS[item['name']] + for item in response.get('items', [])] + + connection = _Connection( + {'items': [{'name': KEY1}, {'name': KEY2}]}) client = _Client(connection) iterator = self._makeOne(client, PATH) iterator.get_items_from_response = _get_items - self.assertEqual(list(iterator), [ITEM1, ITEM2]) + self.assertEqual(iterator.num_results, 0) + + val1 = six.next(iterator) + self.assertEqual(val1, ITEM1) + self.assertEqual(iterator.num_results, 1) + + val2 = six.next(iterator) + self.assertEqual(val2, ITEM2) + self.assertEqual(iterator.num_results, 2) + + with self.assertRaises(StopIteration): + six.next(iterator) + kw, = connection._requested self.assertEqual(kw['method'], 'GET') self.assertEqual(kw['path'], PATH) @@ -79,6 +107,19 @@ def test_has_next_page_w_number_w_token(self): iterator.next_page_token = TOKEN self.assertTrue(iterator.has_next_page()) + def test_has_next_page_w_max_results_not_done(self): + iterator = self._makeOne(None, None, max_results=3, + page_token='definitely-not-none') + iterator.page_number = 1 + self.assertLess(iterator.num_results, iterator.max_results) + self.assertTrue(iterator.has_next_page()) + + def test_has_next_page_w_max_results_done(self): + iterator = self._makeOne(None, None, max_results=3) + iterator.page_number = 1 + iterator.num_results = iterator.max_results + self.assertFalse(iterator.has_next_page()) + def test_get_query_params_no_token(self): connection = _Connection() client = _Client(connection) @@ -96,6 +137,18 @@ def test_get_query_params_w_token(self): self.assertEqual(iterator.get_query_params(), {'pageToken': TOKEN}) + def test_get_query_params_w_max_results(self): + connection = _Connection() + client = _Client(connection) + path = '/foo' + max_results = 3 + iterator = self._makeOne(client, path, + max_results=max_results) + iterator.num_results = 1 + local_max = max_results - iterator.num_results + self.assertEqual(iterator.get_query_params(), + {'maxResults': local_max}) + def test_get_query_params_extra_params(self): connection = _Connection() client = _Client(connection) @@ -117,14 +170,6 @@ def test_get_query_params_w_token_and_extra_params(self): expected_query.update({'pageToken': TOKEN}) self.assertEqual(iterator.get_query_params(), expected_query) - def test_get_query_params_w_token_collision(self): - connection = _Connection() - client = _Client(connection) - PATH = '/foo' - extra_params = {'pageToken': 'val'} - self.assertRaises(ValueError, self._makeOne, client, PATH, - extra_params=extra_params) - def test_get_next_page_response_new_no_token_in_response(self): PATH = '/foo' TOKEN = 'token' diff --git a/resource_manager/google/cloud/resource_manager/client.py b/resource_manager/google/cloud/resource_manager/client.py index f34cf6eb208c1..80d7392bb9f66 100644 --- a/resource_manager/google/cloud/resource_manager/client.py +++ b/resource_manager/google/cloud/resource_manager/client.py @@ -168,14 +168,22 @@ class _ProjectIterator(Iterator): :type client: :class:`~google.cloud.resource_manager.client.Client` :param client: The client to use for making connections. + :type page_token: str + :param page_token: (Optional) A token identifying a page in a result set. + + :type max_results: int + :param max_results: (Optional) The maximum number of results to fetch. + :type extra_params: dict :param extra_params: (Optional) Extra query string parameters for the API call. """ - def __init__(self, client, extra_params=None): - super(_ProjectIterator, self).__init__(client=client, path='/projects', - extra_params=extra_params) + def __init__(self, client, page_token=None, + max_results=None, extra_params=None): + super(_ProjectIterator, self).__init__( + client=client, path='/projects', page_token=page_token, + max_results=max_results, extra_params=extra_params) def get_items_from_response(self, response): """Yield projects from response. diff --git a/storage/google/cloud/storage/bucket.py b/storage/google/cloud/storage/bucket.py index d1b83f8ce5a24..77f86106f4e7d 100644 --- a/storage/google/cloud/storage/bucket.py +++ b/storage/google/cloud/storage/bucket.py @@ -37,6 +37,12 @@ class _BlobIterator(Iterator): :type bucket: :class:`google.cloud.storage.bucket.Bucket` :param bucket: The bucket from which to list blobs. + :type page_token: str + :param page_token: (Optional) A token identifying a page in a result set. + + :type max_results: int + :param max_results: (Optional) The maximum number of results to fetch. + :type extra_params: dict or None :param extra_params: Extra query string parameters for the API call. @@ -44,7 +50,8 @@ class _BlobIterator(Iterator): :param client: Optional. The client to use for making connections. Defaults to the bucket's client. """ - def __init__(self, bucket, extra_params=None, client=None): + def __init__(self, bucket, page_token=None, max_results=None, + extra_params=None, client=None): if client is None: client = bucket.client self.bucket = bucket @@ -52,6 +59,7 @@ def __init__(self, bucket, extra_params=None, client=None): self._current_prefixes = None super(_BlobIterator, self).__init__( client=client, path=bucket.path + '/o', + page_token=page_token, max_results=max_results, extra_params=extra_params) def get_items_from_response(self, response): @@ -285,9 +293,6 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None, """ extra_params = {} - if max_results is not None: - extra_params['maxResults'] = max_results - if prefix is not None: extra_params['prefix'] = prefix @@ -303,13 +308,8 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None, extra_params['fields'] = fields result = self._iterator_class( - self, extra_params=extra_params, client=client) - # Page token must be handled specially since the base `Iterator` - # class has it as a reserved property. - if page_token is not None: - # pylint: disable=attribute-defined-outside-init - result.next_page_token = page_token - # pylint: enable=attribute-defined-outside-init + self, page_token=page_token, max_results=max_results, + extra_params=extra_params, client=client) return result def delete(self, force=False, client=None): diff --git a/storage/google/cloud/storage/client.py b/storage/google/cloud/storage/client.py index a18aa378e1d61..c5eb22158e072 100644 --- a/storage/google/cloud/storage/client.py +++ b/storage/google/cloud/storage/client.py @@ -256,9 +256,6 @@ def list_buckets(self, max_results=None, page_token=None, prefix=None, """ extra_params = {'project': self.project} - if max_results is not None: - extra_params['maxResults'] = max_results - if prefix is not None: extra_params['prefix'] = prefix @@ -267,14 +264,10 @@ def list_buckets(self, max_results=None, page_token=None, prefix=None, if fields is not None: extra_params['fields'] = fields - result = _BucketIterator(client=self, - extra_params=extra_params) - # Page token must be handled specially since the base `Iterator` - # class has it as a reserved property. - if page_token is not None: - # pylint: disable=attribute-defined-outside-init - result.next_page_token = page_token - # pylint: enable=attribute-defined-outside-init + result = _BucketIterator( + client=self, page_token=page_token, + max_results=max_results, extra_params=extra_params) + return result @@ -288,13 +281,22 @@ class _BucketIterator(Iterator): :type client: :class:`google.cloud.storage.client.Client` :param client: The client to use for making connections. + :type page_token: str + :param page_token: (Optional) A token identifying a page in a result set. + + :type max_results: int + :param max_results: (Optional) The maximum number of results to fetch. + :type extra_params: dict or ``NoneType`` :param extra_params: Extra query string parameters for the API call. """ - def __init__(self, client, extra_params=None): - super(_BucketIterator, self).__init__(client=client, path='/b', - extra_params=extra_params) + def __init__(self, client, page_token=None, + max_results=None, extra_params=None): + super(_BucketIterator, self).__init__( + client=client, path='/b', + page_token=page_token, max_results=max_results, + extra_params=extra_params) def get_items_from_response(self, response): """Factory method which yields :class:`.Bucket` items from a response.