Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-3064 Add typings to test package #844

Merged
merged 20 commits into from
Feb 8, 2022
2 changes: 1 addition & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ jobs:
pip install -e ".[zstd, srv]"
- name: Run mypy
run: |
mypy --install-types --non-interactive bson gridfs tools
mypy --install-types --non-interactive bson gridfs tools test
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to separate the test/ type checking so that we can globally ignore certain error codes in the test/ package? Eg:

mypy --install-types --non-interactive bson gridfs tools
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment test

Disabling some checks could make the tests easier to maintain.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

9 changes: 9 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true

[mypy-gevent.*]
ignore_missing_imports = True

[mypy-kerberos.*]
ignore_missing_imports = True

Expand All @@ -29,5 +32,11 @@ ignore_missing_imports = True
[mypy-snappy.*]
ignore_missing_imports = True

[mypy-test.*]
allow_redefinition = true
ShaneHarvey marked this conversation as resolved.
Show resolved Hide resolved

[mypy-winkerberos.*]
ignore_missing_imports = True

[mypy-xmlrunner.*]
ignore_missing_imports = True
5 changes: 2 additions & 3 deletions pymongo/socket_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

import errno
import select
import socket
import sys
from typing import Any, Optional
from typing import Any, Optional, Union
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Union is now unused.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


# PYTHON-2320: Jython does not fully support poll on SSL sockets,
# https://bugs.jython.org/issue2900
Expand All @@ -43,7 +42,7 @@ def __init__(self) -> None:
else:
self._poller = None

def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[Union[float, int]] = 0) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

timeout should be float. The float type supports int automatically. Also is Optional needed? Do we ever pass None?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We call it with None explicitly in test_pooling.py, and there is internal handling for None. I'll change it to Optional[float]

"""Select for reads or writes with a timeout in seconds (or None).

Returns True if the socket is readable/writable, False on timeout.
Expand Down
4 changes: 2 additions & 2 deletions pymongo/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Type aliases used by PyMongo"""
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
Tuple, Type, TypeVar, Union)
Sequence, Tuple, Type, TypeVar, Union)

if TYPE_CHECKING:
from bson.raw_bson import RawBSONDocument
Expand All @@ -25,5 +25,5 @@
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
_Pipeline = List[Mapping[str, Any]]
_Pipeline = Sequence[Mapping[str, Any]]
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])
33 changes: 24 additions & 9 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from contextlib import contextmanager
from functools import wraps
from typing import Dict, no_type_check
from unittest import SkipTest

import pymongo
Expand All @@ -48,7 +49,9 @@
from bson.son import SON
from pymongo import common, message
from pymongo.common import partition_node
from pymongo.database import Database
from pymongo.hello import HelloCompat
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl
from pymongo.uri_parser import parse_uri
Expand Down Expand Up @@ -86,7 +89,7 @@
os.path.join(CERT_PATH, 'client.pem'))
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))

TLS_OPTIONS = dict(tls=True)
TLS_OPTIONS: Dict = dict(tls=True)
if CLIENT_PEM:
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
if CA_PEM:
Expand All @@ -102,13 +105,13 @@
# Remove after PYTHON-2712
from pymongo import pool
pool._MOCK_SERVICE_ID = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
elif TEST_SERVERLESS:
TEST_LOADBALANCER = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
Expand Down Expand Up @@ -184,6 +187,7 @@ def enable(self):
def __enter__(self):
self.enable()

@no_type_check
def disable(self):
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
Expand Down Expand Up @@ -224,6 +228,8 @@ def _all_users(db):


class ClientContext(object):
client: MongoClient

MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI

def __init__(self):
Expand All @@ -247,9 +253,9 @@ def __init__(self):
self.tls = False
self.tlsCertificateKeyFile = False
self.server_is_resolvable = is_server_resolvable()
self.default_client_options = {}
self.default_client_options: Dict = {}
self.sessions_enabled = False
self.client = None
self.client = None # type: ignore
self.conn_lock = threading.Lock()
self.is_data_lake = False
self.load_balancer = TEST_LOADBALANCER
Expand Down Expand Up @@ -340,6 +346,7 @@ def _init_client(self):
try:
self.cmd_line = self.client.admin.command('getCmdLineOpts')
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
# Unauthorized.
Expand Down Expand Up @@ -418,6 +425,7 @@ def _init_client(self):
else:
self.server_parameters = self.client.admin.command(
'getParameter', '*')
assert self.cmd_line is not None
if 'enableTestCommands=1' in self.cmd_line['argv']:
self.test_commands_enabled = True
elif 'parsed' in self.cmd_line:
Expand All @@ -436,7 +444,8 @@ def _init_client(self):
self.mongoses.append(address)
if not self.serverless:
# Check for another mongos on the next port.
next_address = address[0], address[1] + 1
assert address is not None
next_address = address[0], address[1] + 1
mongos_client = self._connect(
*next_address, **self.default_client_options)
if mongos_client:
Expand Down Expand Up @@ -479,7 +488,7 @@ def has_secondaries(self):
@property
def storage_engine(self):
try:
return self.server_status.get("storageEngine", {}).get("name")
return self.server_status.get("storageEngine", {}).get("name") # type: ignore[union-attr]
except AttributeError:
# Raised if self.server_status is None.
return None
Expand All @@ -496,6 +505,7 @@ def _check_user_provided(self):
try:
return db_user in _all_users(client.admin)
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 18 or 'auth fails' in msg:
# Auth failed.
Expand All @@ -505,6 +515,7 @@ def _check_user_provided(self):

def _server_started_with_auth(self):
# MongoDB >= 2.0
assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
parsed = self.cmd_line['parsed']
# MongoDB >= 2.6
Expand All @@ -525,6 +536,7 @@ def _server_started_with_ipv6(self):
if not socket.has_ipv6:
return False

assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
return False
Expand Down Expand Up @@ -642,7 +654,7 @@ def supports_secondary_read_pref(self):
if self.has_secondaries:
return True
if self.is_mongos:
shard = self.client.config.shards.find_one()['host']
shard = self.client.config.shards.find_one()['host'] # type: ignore[index]
ShaneHarvey marked this conversation as resolved.
Show resolved Hide resolved
num_members = shard.count(',') + 1
return num_members > 1
return False
Expand Down Expand Up @@ -932,6 +944,9 @@ def fail_point(self, command_args):

class IntegrationTest(PyMongoTestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""
client: MongoClient
db: Database
credentials: Dict[str, str]

@classmethod
@client_context.require_connection
Expand Down Expand Up @@ -1073,7 +1088,7 @@ def run(self, test):


if HAVE_XML:
class PymongoXMLTestRunner(XMLTestRunner):
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
def run(self, test):
setup()
result = super(PymongoXMLTestRunner, self).run(test)
Expand Down
1 change: 1 addition & 0 deletions test/auth_aws/test_auth_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


class TestAuthAWS(unittest.TestCase):
uri: str

@classmethod
def setUpClass(cls):
Expand Down
6 changes: 6 additions & 0 deletions test/mockupdb/test_cursor_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@


class TestCursorNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
Expand Down Expand Up @@ -69,6 +72,9 @@ def op():


class TestKillCursorsNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
Expand Down
2 changes: 1 addition & 1 deletion test/mockupdb/test_getmore_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_getmore_sharded(self):
servers = [MockupDB(), MockupDB()]

# Collect queries to either server in one queue.
q = Queue()
q: Queue = Queue()
for server in servers:
server.subscribe(q.put)
server.autoresponds('ismaster', ismaster=True, msg='isdbgrid',
Expand Down
8 changes: 4 additions & 4 deletions test/mockupdb/test_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ def respond(r):
ServerApiVersion.V1))}
client = MongoClient("mongodb://"+primary.address_string,
appname='my app', # For _check_handshake_data()
**dict([k_map.get((k, v), (k, v)) for k, v
**dict([k_map.get((k, v), (k, v)) for k, v # type: ignore[arg-type]
in kwargs.items()]))

self.addCleanup(client.close)

# We have an autoresponder luckily, so no need for `go()`.
assert client.db.command(hello)

# We do this checking here rather than in the autoresponder `respond()`
# because it runs in another Python thread so there are some funky things
# with error handling within that thread, and we want to be able to use
# with error handling within that thread, and we want to be able to use
# self.assertRaises().
self.handshake_req.assert_matches(protocol(hello, **kwargs))
self.handshake_req.assert_matches(protocol(hello, **kwargs)) # type: ignore[attr-defined]
_check_handshake_data(self.handshake_req)


Expand Down
4 changes: 2 additions & 2 deletions test/mockupdb/test_mixed_version_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup_server(self, upgrade):
self.mongos_old, self.mongos_new = MockupDB(), MockupDB()

# Collect queries to either server in one queue.
self.q = Queue()
self.q: Queue = Queue()
for server in self.mongos_old, self.mongos_new:
server.subscribe(self.q.put)
server.autoresponds('getlasterror')
Expand Down Expand Up @@ -59,7 +59,7 @@ def create_mixed_version_sharded_test(upgrade):
def test(self):
self.setup_server(upgrade)
start = time.time()
servers_used = set()
servers_used: set = set()
while len(servers_used) < 2:
go(upgrade.function, self.client)
request = self.q.get(timeout=1)
Expand Down
2 changes: 2 additions & 0 deletions test/mockupdb/test_op_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@


class TestOpMsg(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
Expand Down
5 changes: 4 additions & 1 deletion test/mockupdb/test_op_msg_read_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy
import itertools
from typing import Any

from mockupdb import MockupDB, going, CommandBase
from pymongo import MongoClient, ReadPreference
Expand All @@ -27,6 +28,8 @@

class OpMsgReadPrefBase(unittest.TestCase):
single_mongod = False
primary: MockupDB
secondary: MockupDB

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -142,7 +145,7 @@ def test(self):
tag_sets=None)

client = self.setup_client(read_preference=pref)

expected_pref: Any
if operation.op_type == 'always-use-secondary':
expected_server = self.secondary
expected_pref = ReadPreference.SECONDARY
Expand Down
2 changes: 1 addition & 1 deletion test/mockupdb/test_query_read_pref_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_query_and_read_mode_sharded_op_msg(self):
for pref in read_prefs:
collection = client.db.get_collection('test',
read_preference=pref)
cursor = collection.find(query.copy())
cursor = collection.find(query.copy()) # type: ignore[attr-defined]
with going(next, cursor):
request = server.receives()
# Command is not nested in $query.
Expand Down
Loading