Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new-style wire protocol implementation for find*() methods #262

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from twisted.trial import unittest
from twisted.internet import defer

import txmongo
import txmongo.filter as qf
from pymongo.errors import OperationFailure
Expand Down Expand Up @@ -110,6 +111,9 @@ def test_Comment(self):

@defer.inlineCallbacks
def test_Snapshot(self):
    """The $snapshot query modifier was removed in MongoDB 4.0 (wire
    version 7), so this filter can only be exercised against older servers.
    """
    server_info = yield self.db.command('ismaster')
    wire_version = server_info['maxWireVersion']
    if wire_version >= 7:
        raise unittest.SkipTest('snapshot option is only for MongoDB <=3.6')
    yield self.__test_simple_filter(qf.snapshot(), "snapshot", True)

@defer.inlineCallbacks
Expand Down
6 changes: 4 additions & 2 deletions tests/test_replicaset.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_find_with_timeout(self):

yield conn.db.coll.insert({'x': 42}, safe=True)

yield self.__mongod[0].stop()
yield self.__mongod[0].kill(signal.SIGSTOP)

while True:
try:
Expand All @@ -200,6 +200,7 @@ def test_find_with_timeout(self):
pass

finally:
yield self.__mongod[0].kill(signal.SIGCONT)
yield conn.disconnect()
self.flushLoggedErrors(AutoReconnect)

Expand All @@ -213,7 +214,7 @@ def test_find_with_deadline(self):

yield conn.db.coll.insert({'x': 42}, safe=True)

yield self.__mongod[0].stop()
yield self.__mongod[0].kill(signal.SIGSTOP)

while True:
try:
Expand All @@ -225,6 +226,7 @@ def test_find_with_deadline(self):
pass

finally:
yield self.__mongod[0].kill(signal.SIGCONT)
yield conn.disconnect()
self.flushLoggedErrors(AutoReconnect)

Expand Down
211 changes: 172 additions & 39 deletions txmongo/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,54 +436,137 @@ def query():
new_kwargs = self._find_args_compat(*args, **kwargs)
return self.__real_find_with_cursor(**new_kwargs)

def __real_find_with_cursor(self, filter=None, projection=None, skip=0, limit=0, sort=None, batch_size=0,**kwargs):

if filter is None:
filter = SON()

if not isinstance(filter, dict):
raise TypeError("TxMongo: filter must be an instance of dict.")
if not isinstance(projection, (dict, list)) and projection is not None:
raise TypeError("TxMongo: projection must be an instance of dict or list.")
if not isinstance(skip, int):
raise TypeError("TxMongo: skip must be an instance of int.")
if not isinstance(limit, int):
raise TypeError("TxMongo: limit must be an instance of int.")
if not isinstance(batch_size, int):
raise TypeError("TxMongo: batch_size must be an instance of int.")

projection = self._normalize_fields_projection(projection)

filter = self.__apply_find_filter(filter, sort)
# Translation table from legacy OP_QUERY "$-modifier" keys to the
# corresponding option names of the modern "find" command.  Both
# '$showDiskLoc' and '$showRecordId' map to the same option because the
# modifier was renamed ('$showDiskLoc' is the pre-3.0 spelling).
_MODIFIERS = SON([
    ('$query', 'filter'),
    ('$orderby', 'sort'),
    ('$hint', 'hint'),
    ('$comment', 'comment'),
    ('$maxScan', 'maxScan'),
    ('$maxTimeMS', 'maxTimeMS'),
    ('$max', 'max'),
    ('$min', 'min'),
    ('$returnKey', 'returnKey'),
    ('$showRecordId', 'showRecordId'),
    ('$showDiskLoc', 'showRecordId'),  # <= MongoDB 3.0
    ('$snapshot', 'snapshot'),  # <= MongoDB 4.0
])

@classmethod
def _gen_find_command(cls, coll, filter_with_modifiers, projection, skip, limit, batch_size, max_wire_version):
    """Build a "find" command document (SON, order-preserving — the command
    name must be the first key) from a legacy-style query.

    :param coll: collection name the command targets.
    :param filter_with_modifiers: either a plain filter document, or a
        wrapped query containing '$query' plus legacy $-modifiers.
    :param skip/limit/batch_size: cursor options; a negative ``limit``
        requests a single batch of ``abs(limit)`` documents.
    :param max_wire_version: wire version reported by the server; used to
        drop options the server no longer accepts.
    """
    cmd = SON([("find", coll)])
    if "$query" in filter_with_modifiers:
        # Wrapped legacy query: rename each known $-modifier to its "find"
        # command option; unknown keys are passed through unchanged.
        cmd.update([(cls._MODIFIERS[key], val) if key in cls._MODIFIERS else (key, val)
                    for key, val in filter_with_modifiers.items()])
        if max_wire_version >= 7:  # MongoDB 4.0+
            # MongoDB 4.0 removed the snapshot option; sending it would fail.
            cmd.pop('snapshot', None)
    else:
        cmd["filter"] = filter_with_modifiers

    if projection:
        cmd["projection"] = projection
    if skip:
        cmd["skip"] = skip
    if limit:
        cmd["limit"] = abs(limit)
        if limit < 0:
            # Negative limit means "return at most |limit| docs in one batch".
            cmd["singleBatch"] = True
            cmd["batchSize"] = abs(limit)
    if batch_size:
        # An explicit batch_size takes precedence over the one derived from
        # a negative limit above.
        cmd["batchSize"] = batch_size

    if '$explain' in filter_with_modifiers:
        # Explain is expressed by wrapping the whole find command.
        cmd.pop('$explain')
        cmd = SON([('explain', cmd)])

    return cmd

def __send_find_command(self, protocol, filter, projection, skip, limit, batch_size, as_class, flags, deadline):
codec_options = self.codec_options
if as_class is not None:
codec_options = codec_options._replace(document_class=as_class)

def after_reply(result, this_func, fetched=0):
try:
check_deadline(deadline)
except Exception:
cursor_id = result.get("cursor", {}).get("id")
if cursor_id:
kill = SON([
("killCursors", self.name),
("cursors", [cursor_id])
])
self.database.command(kill)
raise

if "cursor" not in result:
return [result], defer.succeed(([], None))
cursor = result["cursor"]

docs_key = "firstBatch"
if "nextBatch" in cursor:
docs_key = "nextBatch"

docs_count = len(cursor[docs_key])
if limit > 0:
docs_count = min(docs_count, limit - fetched)
fetched += docs_count
out = cursor[docs_key][:docs_count]

as_class = kwargs.get("as_class")
proto = self._database.connection.getprotocol()
if cursor["id"]:
if limit == 0:
to_fetch = 0 # no limit
if batch_size:
to_fetch = batch_size
elif limit < 0:
# We won't actually get here because MongoDB won't
# create a cursor when limit < 0
to_fetch = None
else:
to_fetch = limit - fetched
if to_fetch <= 0:
to_fetch = None # close cursor
elif batch_size:
to_fetch = min(batch_size, to_fetch)

def after_connection(protocol):
flags = kwargs.get("flags", 0)
if to_fetch is None:
# FIXME: extract this to a function
kill = SON([
("killCursors", self.name),
("cursors", [cursor["id"]])
])
self.database.command(kill)
return out, defer.succeed(([], None))

check_deadline(kwargs.pop("_deadline", None))
# FIXME: extract this to a function
get_more = SON([
("getMore", cursor["id"]),
("collection", self.name),
])
if batch_size:
get_more["batchSize"] = batch_size
next_reply = self.database._send_command_to_proto(protocol, get_more, codec_options=codec_options, flags=flags)
next_reply.addCallback(this_func, this_func, fetched)
return out, next_reply

if batch_size and limit:
n_to_return = min(batch_size,limit)
elif batch_size:
n_to_return = batch_size
else:
n_to_return = limit
return out, defer.succeed(([], None))

query = Query(flags=flags, collection=str(self),
n_to_skip=skip, n_to_return=n_to_return,
query=filter, fields=projection)
cmd = self._gen_find_command(self.name, filter, projection, skip, limit, batch_size, protocol.max_wire_version)
return self.database._send_command_to_proto(protocol, cmd, codec_options=codec_options, flags=flags)\
.addCallback(after_reply, after_reply)

deferred_query = protocol.send_QUERY(query)
deferred_query.addCallback(after_reply, protocol, after_reply)
return deferred_query

def __send_legacy_find(self, protocol, filter, projection, skip, limit, batch_size, as_class, deadline, kwargs):
# this_func argument is just a reference to after_reply function itself.
# after_reply can reference to itself directly but this will create a circular
# reference between closure and function object which will add unnecessary
# work for GC.
def after_reply(reply, protocol, this_func, fetched=0):
try:
check_deadline(deadline)
except Exception:
if reply.cursor_id:
protocol.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id]))
raise

documents = reply.documents
docs_count = len(documents)
Expand All @@ -500,21 +583,21 @@ def after_reply(reply, protocol, this_func, fetched=0):
if reply.cursor_id:
# please note that this will not be the case if batch_size = 1
# it is documented (parameter numberToReturn for OP_QUERY)
# https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#wire-op-query
# https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#wire-op-query
if limit == 0:
to_fetch = 0 # no limit
if batch_size:
to_fetch = batch_size
elif limit < 0:
# We won't actually get here because MongoDB won't
# create cursor when limit < 0
# create a cursor when limit < 0
to_fetch = None
else:
to_fetch = limit - fetched
if to_fetch <= 0:
to_fetch = None # close cursor
elif batch_size:
to_fetch = min(batch_size,to_fetch)
to_fetch = min(batch_size, to_fetch)

if to_fetch is None:
protocol.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id]))
Expand All @@ -529,6 +612,56 @@ def after_reply(reply, protocol, this_func, fetched=0):

return out, defer.succeed(([], None))

flags = kwargs.get("flags", 0)

if batch_size and limit:
n_to_return = min(batch_size, limit)
elif batch_size:
n_to_return = batch_size
else:
n_to_return = limit

query = Query(flags=flags, collection=str(self),
n_to_skip=skip, n_to_return=n_to_return,
query=filter, fields=projection)

deferred_query = protocol.send_QUERY(query)
deferred_query.addCallback(after_reply, protocol, after_reply)
return deferred_query


def __real_find_with_cursor(self, filter=None, projection=None, skip=0, limit=0, sort=None, batch_size=0, **kwargs):
    """Validate find() arguments, then dispatch to the legacy OP_QUERY
    implementation or the "find"-command implementation depending on the
    connected server's wire version.

    Returns a deferred that fires with the first batch of documents and a
    deferred for the following batch.
    """
    if filter is None:
        filter = SON()

    if not isinstance(filter, dict):
        raise TypeError("TxMongo: filter must be an instance of dict.")
    if projection is not None and not isinstance(projection, (dict, list)):
        raise TypeError("TxMongo: projection must be an instance of dict or list.")
    for arg_name, arg_value in (("skip", skip), ("limit", limit), ("batch_size", batch_size)):
        if not isinstance(arg_value, int):
            raise TypeError("TxMongo: %s must be an instance of int." % arg_name)

    projection = self._normalize_fields_projection(projection)
    filter = self.__apply_find_filter(filter, sort)

    doc_class = kwargs.get("as_class")
    deadline = kwargs.pop("_deadline", None)

    def on_protocol(protocol):
        # Give up early if the caller's deadline already expired while we
        # were waiting for a connection.
        check_deadline(deadline)
        if protocol.max_wire_version < 4:
            return self.__send_legacy_find(protocol, filter, projection, skip,
                                           limit, batch_size, doc_class, deadline, kwargs)
        return self.__send_find_command(protocol, filter, projection, skip, limit,
                                        batch_size, doc_class, kwargs.get("flags", 0), deadline)

    connection_deferred = self._database.connection.getprotocol()
    connection_deferred.addCallback(on_protocol)
    return connection_deferred

Expand Down
20 changes: 10 additions & 10 deletions txmongo/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ def _initializeProto(self, proto):
slaveok = uri_options['readpreference'] not in _PRIMARY_READ_PREFERENCES

try:
if not slaveok:
# Update our server configuration. This may disconnect if the node
# is not a master.
yield self.configure(proto)
# Update our server configuration. This may disconnect if the node
# is not a master and slaveok is not set
yield self.configure(proto, slaveok)

yield self._auth_proto(proto)
self.setInstance(instance=proto)
Expand All @@ -77,7 +76,7 @@ def __send_ismaster(proto, **kwargs):
return proto.send_QUERY(query)

@defer.inlineCallbacks
def configure(self, proto):
def configure(self, proto, slaveok):
"""
Configures the protocol using the information gathered from the
remote Mongo instance. Such information may contain the max
Expand Down Expand Up @@ -134,11 +133,12 @@ def configure(self, proto):
if host not in self.__allnodes:
self.__allnodes.append(host)

# Check if this node is the master.
ismaster = config.get("ismaster")
if not ismaster:
msg = "TxMongo: MongoDB host `%s` is not master." % config.get('me')
raise AutoReconnect(msg)
if not slaveok:
# Check if this node is the master.
ismaster = config.get("ismaster")
if not ismaster:
msg = "TxMongo: MongoDB host `%s` is not master." % config.get('me')
raise AutoReconnect(msg)

def clientConnectionFailed(self, connector, reason):
self.instance = None
Expand Down
Loading