Skip to content

Commit

Permalink
PYTHON-3064 Add typings to test package (mongodb#844)
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 authored and juliusgeo committed Apr 4, 2022
1 parent 178a943 commit ee564c6
Show file tree
Hide file tree
Showing 67 changed files with 542 additions and 261 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ jobs:
pip install -e ".[zstd, srv]"
- name: Run mypy
run: |
mypy --install-types --non-interactive bson gridfs tools
mypy --install-types --non-interactive bson gridfs tools pymongo
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test
2 changes: 1 addition & 1 deletion bson/son.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# This is essentially the same as re._pattern_type
RE_TYPE: Type[Pattern[Any]] = type(re.compile(""))

_Key = TypeVar("_Key", bound=str)
_Key = TypeVar("_Key")
_Value = TypeVar("_Value")
_T = TypeVar("_T")

Expand Down
10 changes: 10 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true

[mypy-gevent.*]
ignore_missing_imports = True

[mypy-kerberos.*]
ignore_missing_imports = True

Expand All @@ -29,5 +32,12 @@ ignore_missing_imports = True
[mypy-snappy.*]
ignore_missing_imports = True

[mypy-test.*]
allow_redefinition = true
allow_untyped_globals = true

[mypy-winkerberos.*]
ignore_missing_imports = True

[mypy-xmlrunner.*]
ignore_missing_imports = True
5 changes: 2 additions & 3 deletions pymongo/socket_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

import errno
import select
import socket
import sys
from typing import Any, Optional
from typing import Any, Optional, Union

# PYTHON-2320: Jython does not fully support poll on SSL sockets,
# https://bugs.jython.org/issue2900
Expand All @@ -43,7 +42,7 @@ def __init__(self) -> None:
else:
self._poller = None

def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0) -> bool:
"""Select for reads or writes with a timeout in seconds (or None).
Returns True if the socket is readable/writable, False on timeout.
Expand Down
2 changes: 1 addition & 1 deletion pymongo/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def maybe_decode(text):
def _resolve(*args, **kwargs):
if hasattr(resolver, 'resolve'):
# dnspython >= 2
return resolver.resolve(*args, **kwargs) # type: ignore
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions pymongo/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Type aliases used by PyMongo"""
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
Tuple, Type, TypeVar, Union)
Sequence, Tuple, Type, TypeVar, Union)

if TYPE_CHECKING:
from bson.raw_bson import RawBSONDocument
Expand All @@ -25,5 +25,5 @@
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
_Pipeline = List[Mapping[str, Any]]
_Pipeline = Sequence[Mapping[str, Any]]
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])
29 changes: 22 additions & 7 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from contextlib import contextmanager
from functools import wraps
from typing import Dict, no_type_check
from unittest import SkipTest

import pymongo
Expand All @@ -48,7 +49,9 @@
from bson.son import SON
from pymongo import common, message
from pymongo.common import partition_node
from pymongo.database import Database
from pymongo.hello import HelloCompat
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl
from pymongo.uri_parser import parse_uri
Expand Down Expand Up @@ -86,7 +89,7 @@
os.path.join(CERT_PATH, 'client.pem'))
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))

TLS_OPTIONS = dict(tls=True)
TLS_OPTIONS: Dict = dict(tls=True)
if CLIENT_PEM:
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
if CA_PEM:
Expand All @@ -102,13 +105,13 @@
# Remove after PYTHON-2712
from pymongo import pool
pool._MOCK_SERVICE_ID = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
elif TEST_SERVERLESS:
TEST_LOADBALANCER = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
Expand Down Expand Up @@ -184,6 +187,7 @@ def enable(self):
def __enter__(self):
self.enable()

@no_type_check
def disable(self):
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
Expand Down Expand Up @@ -224,6 +228,8 @@ def _all_users(db):


class ClientContext(object):
client: MongoClient

MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI

def __init__(self):
Expand All @@ -247,9 +253,9 @@ def __init__(self):
self.tls = False
self.tlsCertificateKeyFile = False
self.server_is_resolvable = is_server_resolvable()
self.default_client_options = {}
self.default_client_options: Dict = {}
self.sessions_enabled = False
self.client = None
self.client = None # type: ignore
self.conn_lock = threading.Lock()
self.is_data_lake = False
self.load_balancer = TEST_LOADBALANCER
Expand Down Expand Up @@ -340,6 +346,7 @@ def _init_client(self):
try:
self.cmd_line = self.client.admin.command('getCmdLineOpts')
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
# Unauthorized.
Expand Down Expand Up @@ -418,6 +425,7 @@ def _init_client(self):
else:
self.server_parameters = self.client.admin.command(
'getParameter', '*')
assert self.cmd_line is not None
if 'enableTestCommands=1' in self.cmd_line['argv']:
self.test_commands_enabled = True
elif 'parsed' in self.cmd_line:
Expand All @@ -436,7 +444,8 @@ def _init_client(self):
self.mongoses.append(address)
if not self.serverless:
# Check for another mongos on the next port.
next_address = address[0], address[1] + 1
assert address is not None
next_address = address[0], address[1] + 1
mongos_client = self._connect(
*next_address, **self.default_client_options)
if mongos_client:
Expand Down Expand Up @@ -496,6 +505,7 @@ def _check_user_provided(self):
try:
return db_user in _all_users(client.admin)
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 18 or 'auth fails' in msg:
# Auth failed.
Expand All @@ -505,6 +515,7 @@ def _check_user_provided(self):

def _server_started_with_auth(self):
# MongoDB >= 2.0
assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
parsed = self.cmd_line['parsed']
# MongoDB >= 2.6
Expand All @@ -525,6 +536,7 @@ def _server_started_with_ipv6(self):
if not socket.has_ipv6:
return False

assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
return False
Expand Down Expand Up @@ -932,6 +944,9 @@ def fail_point(self, command_args):

class IntegrationTest(PyMongoTestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""
client: MongoClient
db: Database
credentials: Dict[str, str]

@classmethod
@client_context.require_connection
Expand Down Expand Up @@ -1073,7 +1088,7 @@ def run(self, test):


if HAVE_XML:
class PymongoXMLTestRunner(XMLTestRunner):
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
def run(self, test):
setup()
result = super(PymongoXMLTestRunner, self).run(test)
Expand Down
1 change: 1 addition & 0 deletions test/auth_aws/test_auth_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


class TestAuthAWS(unittest.TestCase):
uri: str

@classmethod
def setUpClass(cls):
Expand Down
6 changes: 6 additions & 0 deletions test/mockupdb/test_cursor_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@


class TestCursorNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
Expand Down Expand Up @@ -69,6 +72,9 @@ def op():


class TestKillCursorsNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
Expand Down
2 changes: 1 addition & 1 deletion test/mockupdb/test_getmore_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_getmore_sharded(self):
servers = [MockupDB(), MockupDB()]

# Collect queries to either server in one queue.
q = Queue()
q: Queue = Queue()
for server in servers:
server.subscribe(q.put)
server.autoresponds('ismaster', ismaster=True, msg='isdbgrid',
Expand Down
4 changes: 2 additions & 2 deletions test/mockupdb/test_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def respond(r):
ServerApiVersion.V1))}
client = MongoClient("mongodb://"+primary.address_string,
appname='my app', # For _check_handshake_data()
**dict([k_map.get((k, v), (k, v)) for k, v
**dict([k_map.get((k, v), (k, v)) for k, v # type: ignore[arg-type]
in kwargs.items()]))

self.addCleanup(client.close)
Expand All @@ -58,7 +58,7 @@ def respond(r):

# We do this checking here rather than in the autoresponder `respond()`
# because it runs in another Python thread so there are some funky things
# with error handling within that thread, and we want to be able to use
# with error handling within that thread, and we want to be able to use
# self.assertRaises().
self.handshake_req.assert_matches(protocol(hello, **kwargs))
_check_handshake_data(self.handshake_req)
Expand Down
4 changes: 2 additions & 2 deletions test/mockupdb/test_mixed_version_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup_server(self, upgrade):
self.mongos_old, self.mongos_new = MockupDB(), MockupDB()

# Collect queries to either server in one queue.
self.q = Queue()
self.q: Queue = Queue()
for server in self.mongos_old, self.mongos_new:
server.subscribe(self.q.put)
server.autoresponds('getlasterror')
Expand Down Expand Up @@ -59,7 +59,7 @@ def create_mixed_version_sharded_test(upgrade):
def test(self):
self.setup_server(upgrade)
start = time.time()
servers_used = set()
servers_used: set = set()
while len(servers_used) < 2:
go(upgrade.function, self.client)
request = self.q.get(timeout=1)
Expand Down
2 changes: 2 additions & 0 deletions test/mockupdb/test_op_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@


class TestOpMsg(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
Expand Down
5 changes: 4 additions & 1 deletion test/mockupdb/test_op_msg_read_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy
import itertools
from typing import Any

from mockupdb import MockupDB, going, CommandBase
from pymongo import MongoClient, ReadPreference
Expand All @@ -27,6 +28,8 @@

class OpMsgReadPrefBase(unittest.TestCase):
single_mongod = False
primary: MockupDB
secondary: MockupDB

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -142,7 +145,7 @@ def test(self):
tag_sets=None)

client = self.setup_client(read_preference=pref)

expected_pref: Any
if operation.op_type == 'always-use-secondary':
expected_server = self.secondary
expected_pref = ReadPreference.SECONDARY
Expand Down
12 changes: 10 additions & 2 deletions test/mockupdb/test_reset_and_request_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class TestResetAndRequestCheck(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestResetAndRequestCheck, self).__init__(*args, **kwargs)
self.ismaster_time = 0
self.ismaster_time = 0.0
self.client = None
self.server = None

Expand All @@ -45,7 +45,7 @@ def responder(request):
kwargs = {'socketTimeoutMS': 100}
# Disable retryable reads when pymongo supports it.
kwargs['retryReads'] = False
self.client = MongoClient(self.server.uri, **kwargs)
self.client = MongoClient(self.server.uri, **kwargs) # type: ignore
wait_until(lambda: self.client.nodes, 'connect to standalone')

def tearDown(self):
Expand All @@ -56,6 +56,8 @@ def _test_disconnect(self, operation):
# Application operation fails. Test that client resets server
# description and does *not* schedule immediate check.
self.setup_server()
assert self.server is not None
assert self.client is not None

# Network error on application operation.
with self.assertRaises(ConnectionFailure):
Expand All @@ -81,6 +83,8 @@ def _test_timeout(self, operation):
# Application operation times out. Test that client does *not* reset
# server description and does *not* schedule immediate check.
self.setup_server()
assert self.server is not None
assert self.client is not None

with self.assertRaises(ConnectionFailure):
with going(operation.function, self.client):
Expand All @@ -91,6 +95,7 @@ def _test_timeout(self, operation):
# Server is *not* Unknown.
topology = self.client._topology
server = topology.select_server_by_address(self.server.address, 0)
assert server is not None
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)

after = self.ismaster_time
Expand All @@ -99,6 +104,8 @@ def _test_timeout(self, operation):
def _test_not_master(self, operation):
# Application operation gets a "not master" error.
self.setup_server()
assert self.server is not None
assert self.client is not None

with self.assertRaises(ConnectionFailure):
with going(operation.function, self.client):
Expand All @@ -110,6 +117,7 @@ def _test_not_master(self, operation):
# Server is rediscovered.
topology = self.client._topology
server = topology.select_server_by_address(self.server.address, 0)
assert server is not None
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)

after = self.ismaster_time
Expand Down
Loading

0 comments on commit ee564c6

Please sign in to comment.