Skip to content

Commit

Permalink
Merge pull request #1787 from dhermes/fix-1763
Browse files Browse the repository at this point in the history
Reducing limit or the life of a datastore query iterator.
  • Loading branch information
dhermes committed May 18, 2016
2 parents b449451 + 74026f4 commit d956136
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 33 deletions.
5 changes: 5 additions & 0 deletions gcloud/datastore/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ def run_query(self, project, query_pb, namespace=None,
:param transaction_id: If passed, make the request in the scope of
the given transaction. Incompatible with
``eventual==True``.
:rtype: tuple
:returns: Four-tuple containing the entities returned,
the end cursor of the query, a ``more_results``
enum and a count of the number of skipped results.
"""
request = _datastore_pb2.RunQueryRequest()
_set_read_options(request, eventual, transaction_id)
Expand Down
32 changes: 16 additions & 16 deletions gcloud/datastore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,7 @@ class Iterator(object):
:param limit: (Optional) Limit the number of results returned.
:type offset: integer
:param offset: (Optional) Defaults to 0. Offset used to begin
a query.
:param offset: (Optional) Offset used to begin a query.
:type start_cursor: bytes
:param start_cursor: (Optional) Cursor to begin paging through
Expand All @@ -380,9 +379,10 @@ class Iterator(object):
_FINISHED = (
_query_pb2.QueryResultBatch.NO_MORE_RESULTS,
_query_pb2.QueryResultBatch.MORE_RESULTS_AFTER_LIMIT,
_query_pb2.QueryResultBatch.MORE_RESULTS_AFTER_CURSOR,
)

def __init__(self, query, client, limit=None, offset=0,
def __init__(self, query, client, limit=None, offset=None,
start_cursor=None, end_cursor=None):
self._query = query
self._client = client
Expand All @@ -391,6 +391,7 @@ def __init__(self, query, client, limit=None, offset=0,
self._start_cursor = start_cursor
self._end_cursor = end_cursor
self._page = self._more_results = None
self._skipped_results = None

def next_page(self):
"""Fetch a single "page" of query results.
Expand All @@ -413,7 +414,8 @@ def next_page(self):
if self._limit is not None:
pb.limit.value = self._limit

pb.offset = self._offset
if self._offset is not None:
pb.offset = self._offset

transaction = self._client.current_transaction

Expand All @@ -423,16 +425,8 @@ def next_page(self):
namespace=self._query.namespace,
transaction_id=transaction and transaction.id,
)
# NOTE: `query_results` contains an extra value that we don't use,
# namely `skipped_results`.
#
# NOTE: The value of `more_results` is not currently useful because
# the back-end always returns an enum
# value of MORE_RESULTS_AFTER_LIMIT even if there are no more
# results. See
# https://github.com/GoogleCloudPlatform/gcloud-python/issues/280
# for discussion.
entity_pbs, cursor_as_bytes, more_results_enum = query_results[:3]
(entity_pbs, cursor_as_bytes,
more_results_enum, self._skipped_results) = query_results

if cursor_as_bytes == b'':
self._start_cursor = None
Expand All @@ -457,13 +451,19 @@ def __iter__(self):
:rtype: sequence of :class:`gcloud.datastore.entity.Entity`
"""
self.next_page()
while True:
self.next_page()
for entity in self._page:
yield entity
if not self._more_results:
break
self.next_page()
num_results = len(self._page)
if self._limit is not None:
self._limit -= num_results
if self._offset is not None and self._skipped_results is not None:
# NOTE: The offset goes down relative to the location
# because we are updating the cursor each time.
self._offset -= self._skipped_results


def _pb_from_query(query):
Expand Down
101 changes: 84 additions & 17 deletions gcloud/datastore/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,23 +345,31 @@ def _getTargetClass(self):
def _makeOne(self, *args, **kw):
return self._getTargetClass()(*args, **kw)

def _addQueryResults(self, connection, cursor=_END, more=False):
def _addQueryResults(self, connection, cursor=_END, more=False,
skipped_results=None, no_entity=False):
from gcloud.datastore._generated import entity_pb2
from gcloud.datastore._generated import query_pb2
from gcloud.datastore.helpers import _new_value_pb

MORE = query_pb2.QueryResultBatch.NOT_FINISHED
NO_MORE = query_pb2.QueryResultBatch.MORE_RESULTS_AFTER_LIMIT
if more:
more_enum = query_pb2.QueryResultBatch.NOT_FINISHED
else:
more_enum = query_pb2.QueryResultBatch.MORE_RESULTS_AFTER_LIMIT
_ID = 123
entity_pb = entity_pb2.Entity()
entity_pb.key.partition_id.project_id = self._PROJECT
path_element = entity_pb.key.path.add()
path_element.kind = self._KIND
path_element.id = _ID
value_pb = _new_value_pb(entity_pb, 'foo')
value_pb.string_value = u'Foo'
if no_entity:
entities = []
else:
entity_pb = entity_pb2.Entity()
entity_pb.key.partition_id.project_id = self._PROJECT
path_element = entity_pb.key.path.add()
path_element.kind = self._KIND
path_element.id = _ID
value_pb = _new_value_pb(entity_pb, 'foo')
value_pb.string_value = u'Foo'
entities = [entity_pb]

connection._results.append(
([entity_pb], cursor, MORE if more else NO_MORE))
(entities, cursor, more_enum, skipped_results))

def _makeClient(self, connection=None):
if connection is None:
Expand All @@ -374,7 +382,8 @@ def test_ctor_defaults(self):
iterator = self._makeOne(query, connection)
self.assertTrue(iterator._query is query)
self.assertEqual(iterator._limit, None)
self.assertEqual(iterator._offset, 0)
self.assertEqual(iterator._offset, None)
self.assertEqual(iterator._skipped_results, None)

def test_ctor_explicit(self):
client = self._makeClient()
Expand All @@ -392,6 +401,7 @@ def test_next_page_no_cursors_no_more(self):
self._addQueryResults(connection, cursor=b'')
iterator = self._makeOne(query, client)
entities, more_results, cursor = iterator.next_page()
self.assertEqual(iterator._skipped_results, None)

self.assertEqual(cursor, None)
self.assertFalse(more_results)
Expand All @@ -415,13 +425,16 @@ def test_next_page_no_cursors_no_more_w_offset_and_limit(self):
connection = _Connection()
client = self._makeClient(connection)
query = _Query(client, self._KIND, self._PROJECT, self._NAMESPACE)
self._addQueryResults(connection, cursor=b'')
skipped_results = object()
self._addQueryResults(connection, cursor=b'',
skipped_results=skipped_results)
iterator = self._makeOne(query, client, 13, 29)
entities, more_results, cursor = iterator.next_page()

self.assertEqual(cursor, None)
self.assertFalse(more_results)
self.assertFalse(iterator._more_results)
self.assertEqual(iterator._skipped_results, skipped_results)
self.assertEqual(len(entities), 1)
self.assertEqual(entities[0].key.path,
[{'kind': self._KIND, 'id': self._ID}])
Expand Down Expand Up @@ -453,6 +466,7 @@ def test_next_page_w_cursors_w_more(self):
self.assertEqual(cursor, urlsafe_b64encode(self._END))
self.assertTrue(more_results)
self.assertTrue(iterator._more_results)
self.assertEqual(iterator._skipped_results, None)
self.assertEqual(iterator._end_cursor, None)
self.assertEqual(urlsafe_b64decode(iterator._start_cursor), self._END)
self.assertEqual(len(entities), 1)
Expand All @@ -476,8 +490,8 @@ def test_next_page_w_cursors_w_bogus_more(self):
client = self._makeClient(connection)
query = _Query(client, self._KIND, self._PROJECT, self._NAMESPACE)
self._addQueryResults(connection, cursor=self._END, more=True)
epb, cursor, _ = connection._results.pop()
connection._results.append((epb, cursor, 4)) # invalid enum
epb, cursor, _, _ = connection._results.pop()
connection._results.append((epb, cursor, 5, None)) # invalid enum
iterator = self._makeOne(query, client)
self.assertRaises(ValueError, iterator.next_page)

Expand Down Expand Up @@ -523,9 +537,7 @@ def test___iter___w_more(self):
[{'kind': self._KIND, 'id': self._ID}])
self.assertEqual(entities[1]['foo'], u'Foo')
qpb1 = _pb_from_query(query)
qpb1.offset = 0
qpb2 = _pb_from_query(query)
qpb2.offset = 0
qpb2.start_cursor = self._END
EXPECTED1 = {
'project': self._PROJECT,
Expand All @@ -543,6 +555,61 @@ def test___iter___w_more(self):
self.assertEqual(connection._called_with[0], EXPECTED1)
self.assertEqual(connection._called_with[1], EXPECTED2)

def test___iter___w_limit(self):
from gcloud.datastore.query import _pb_from_query

connection = _Connection()
client = self._makeClient(connection)
query = _Query(client, self._KIND, self._PROJECT, self._NAMESPACE)
skip1 = 4
skip2 = 9
self._addQueryResults(connection, more=True, skipped_results=skip1,
no_entity=True)
self._addQueryResults(connection, more=True, skipped_results=skip2)
self._addQueryResults(connection)
offset = skip1 + skip2
iterator = self._makeOne(query, client, limit=2, offset=offset)
entities = list(iterator)

self.assertFalse(iterator._more_results)
self.assertEqual(len(entities), 2)
for entity in entities:
self.assertEqual(
entity.key.path,
[{'kind': self._KIND, 'id': self._ID}])
qpb1 = _pb_from_query(query)
qpb1.limit.value = 2
qpb1.offset = offset
qpb2 = _pb_from_query(query)
qpb2.start_cursor = self._END
qpb2.limit.value = 2
qpb2.offset = offset - skip1
qpb3 = _pb_from_query(query)
qpb3.start_cursor = self._END
qpb3.limit.value = 1
EXPECTED1 = {
'project': self._PROJECT,
'query_pb': qpb1,
'namespace': self._NAMESPACE,
'transaction_id': None,
}
EXPECTED2 = {
'project': self._PROJECT,
'query_pb': qpb2,
'namespace': self._NAMESPACE,
'transaction_id': None,
}
EXPECTED3 = {
'project': self._PROJECT,
'query_pb': qpb3,
'namespace': self._NAMESPACE,
'transaction_id': None,
}
self.assertEqual(len(connection._called_with), 3)
self.assertEqual(connection._called_with[0], EXPECTED1)
self.assertEqual(connection._called_with[1], EXPECTED2)
self.assertEqual(connection._called_with[2], EXPECTED3)


class Test__pb_from_query(unittest2.TestCase):

Expand Down

0 comments on commit d956136

Please sign in to comment.