Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-3064 Add typings to test package #844

Merged
merged 20 commits into from
Feb 8, 2022
3 changes: 2 additions & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ jobs:
pip install -e ".[zstd, srv]"
- name: Run mypy
run: |
mypy --install-types --non-interactive bson gridfs tools
mypy --install-types --non-interactive bson gridfs tools pymongo
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test
2 changes: 1 addition & 1 deletion bson/son.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# This is essentially the same as re._pattern_type
RE_TYPE: Type[Pattern[Any]] = type(re.compile(""))

_Key = TypeVar("_Key", bound=str)
_Key = TypeVar("_Key")
_Value = TypeVar("_Value")
_T = TypeVar("_T")

Expand Down
10 changes: 10 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true

[mypy-gevent.*]
ignore_missing_imports = True

[mypy-kerberos.*]
ignore_missing_imports = True

Expand All @@ -29,5 +32,12 @@ ignore_missing_imports = True
[mypy-snappy.*]
ignore_missing_imports = True

[mypy-test.*]
allow_redefinition = true
ShaneHarvey marked this conversation as resolved.
Show resolved Hide resolved
allow_untyped_globals = true

[mypy-winkerberos.*]
ignore_missing_imports = True

[mypy-xmlrunner.*]
ignore_missing_imports = True
5 changes: 2 additions & 3 deletions pymongo/socket_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

import errno
import select
import socket
import sys
from typing import Any, Optional
from typing import Any, Optional, Union
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Union is now unused.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


# PYTHON-2320: Jython does not fully support poll on SSL sockets,
# https://bugs.jython.org/issue2900
Expand All @@ -43,7 +42,7 @@ def __init__(self) -> None:
else:
self._poller = None

def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0) -> bool:
"""Select for reads or writes with a timeout in seconds (or None).

Returns True if the socket is readable/writable, False on timeout.
Expand Down
2 changes: 1 addition & 1 deletion pymongo/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def maybe_decode(text):
def _resolve(*args, **kwargs):
if hasattr(resolver, 'resolve'):
# dnspython >= 2
return resolver.resolve(*args, **kwargs) # type: ignore
return resolver.resolve(*args, **kwargs)
# dnspython 1.X
return resolver.query(*args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions pymongo/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Type aliases used by PyMongo"""
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
Tuple, Type, TypeVar, Union)
Sequence, Tuple, Type, TypeVar, Union)

if TYPE_CHECKING:
from bson.raw_bson import RawBSONDocument
Expand All @@ -25,5 +25,5 @@
_Address = Tuple[str, Optional[int]]
_CollationIn = Union[Mapping[str, Any], "Collation"]
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
_Pipeline = List[Mapping[str, Any]]
_Pipeline = Sequence[Mapping[str, Any]]
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])
29 changes: 22 additions & 7 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from contextlib import contextmanager
from functools import wraps
from typing import Dict, no_type_check
from unittest import SkipTest

import pymongo
Expand All @@ -48,7 +49,9 @@
from bson.son import SON
from pymongo import common, message
from pymongo.common import partition_node
from pymongo.database import Database
from pymongo.hello import HelloCompat
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl
from pymongo.uri_parser import parse_uri
Expand Down Expand Up @@ -86,7 +89,7 @@
os.path.join(CERT_PATH, 'client.pem'))
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))

TLS_OPTIONS = dict(tls=True)
TLS_OPTIONS: Dict = dict(tls=True)
if CLIENT_PEM:
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
if CA_PEM:
Expand All @@ -102,13 +105,13 @@
# Remove after PYTHON-2712
from pymongo import pool
pool._MOCK_SERVICE_ID = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
elif TEST_SERVERLESS:
TEST_LOADBALANCER = True
res = parse_uri(SINGLE_MONGOS_LB_URI)
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
host, port = res['nodelist'][0]
db_user = res['username'] or db_user
db_pwd = res['password'] or db_pwd
Expand Down Expand Up @@ -184,6 +187,7 @@ def enable(self):
def __enter__(self):
self.enable()

@no_type_check
def disable(self):
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
Expand Down Expand Up @@ -224,6 +228,8 @@ def _all_users(db):


class ClientContext(object):
client: MongoClient

MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI

def __init__(self):
Expand All @@ -247,9 +253,9 @@ def __init__(self):
self.tls = False
self.tlsCertificateKeyFile = False
self.server_is_resolvable = is_server_resolvable()
self.default_client_options = {}
self.default_client_options: Dict = {}
self.sessions_enabled = False
self.client = None
self.client = None # type: ignore
self.conn_lock = threading.Lock()
self.is_data_lake = False
self.load_balancer = TEST_LOADBALANCER
Expand Down Expand Up @@ -340,6 +346,7 @@ def _init_client(self):
try:
self.cmd_line = self.client.admin.command('getCmdLineOpts')
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
# Unauthorized.
Expand Down Expand Up @@ -418,6 +425,7 @@ def _init_client(self):
else:
self.server_parameters = self.client.admin.command(
'getParameter', '*')
assert self.cmd_line is not None
if 'enableTestCommands=1' in self.cmd_line['argv']:
self.test_commands_enabled = True
elif 'parsed' in self.cmd_line:
Expand All @@ -436,7 +444,8 @@ def _init_client(self):
self.mongoses.append(address)
if not self.serverless:
# Check for another mongos on the next port.
next_address = address[0], address[1] + 1
assert address is not None
next_address = address[0], address[1] + 1
mongos_client = self._connect(
*next_address, **self.default_client_options)
if mongos_client:
Expand Down Expand Up @@ -496,6 +505,7 @@ def _check_user_provided(self):
try:
return db_user in _all_users(client.admin)
except pymongo.errors.OperationFailure as e:
assert e.details is not None
msg = e.details.get('errmsg', '')
if e.code == 18 or 'auth fails' in msg:
# Auth failed.
Expand All @@ -505,6 +515,7 @@ def _check_user_provided(self):

def _server_started_with_auth(self):
# MongoDB >= 2.0
assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
parsed = self.cmd_line['parsed']
# MongoDB >= 2.6
Expand All @@ -525,6 +536,7 @@ def _server_started_with_ipv6(self):
if not socket.has_ipv6:
return False

assert self.cmd_line is not None
if 'parsed' in self.cmd_line:
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
return False
Expand Down Expand Up @@ -932,6 +944,9 @@ def fail_point(self, command_args):

class IntegrationTest(PyMongoTestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""
client: MongoClient
db: Database
credentials: Dict[str, str]

@classmethod
@client_context.require_connection
Expand Down Expand Up @@ -1073,7 +1088,7 @@ def run(self, test):


if HAVE_XML:
class PymongoXMLTestRunner(XMLTestRunner):
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
def run(self, test):
setup()
result = super(PymongoXMLTestRunner, self).run(test)
Expand Down
1 change: 1 addition & 0 deletions test/auth_aws/test_auth_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


class TestAuthAWS(unittest.TestCase):
uri: str

@classmethod
def setUpClass(cls):
Expand Down
6 changes: 6 additions & 0 deletions test/mockupdb/test_cursor_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@


class TestCursorNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
Expand Down Expand Up @@ -69,6 +72,9 @@ def op():


class TestKillCursorsNamespace(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
Expand Down
2 changes: 1 addition & 1 deletion test/mockupdb/test_getmore_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_getmore_sharded(self):
servers = [MockupDB(), MockupDB()]

# Collect queries to either server in one queue.
q = Queue()
q: Queue = Queue()
for server in servers:
server.subscribe(q.put)
server.autoresponds('ismaster', ismaster=True, msg='isdbgrid',
Expand Down
6 changes: 3 additions & 3 deletions test/mockupdb/test_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ def respond(r):
ServerApiVersion.V1))}
client = MongoClient("mongodb://"+primary.address_string,
appname='my app', # For _check_handshake_data()
**dict([k_map.get((k, v), (k, v)) for k, v
**dict([k_map.get((k, v), (k, v)) for k, v # type: ignore[arg-type]
in kwargs.items()]))

self.addCleanup(client.close)

# We have an autoresponder luckily, so no need for `go()`.
assert client.db.command(hello)

# We do this checking here rather than in the autoresponder `respond()`
# because it runs in another Python thread so there are some funky things
# with error handling within that thread, and we want to be able to use
# with error handling within that thread, and we want to be able to use
# self.assertRaises().
self.handshake_req.assert_matches(protocol(hello, **kwargs))
_check_handshake_data(self.handshake_req)
Expand Down
4 changes: 2 additions & 2 deletions test/mockupdb/test_mixed_version_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup_server(self, upgrade):
self.mongos_old, self.mongos_new = MockupDB(), MockupDB()

# Collect queries to either server in one queue.
self.q = Queue()
self.q: Queue = Queue()
for server in self.mongos_old, self.mongos_new:
server.subscribe(self.q.put)
server.autoresponds('getlasterror')
Expand Down Expand Up @@ -59,7 +59,7 @@ def create_mixed_version_sharded_test(upgrade):
def test(self):
self.setup_server(upgrade)
start = time.time()
servers_used = set()
servers_used: set = set()
while len(servers_used) < 2:
go(upgrade.function, self.client)
request = self.q.get(timeout=1)
Expand Down
2 changes: 2 additions & 0 deletions test/mockupdb/test_op_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@


class TestOpMsg(unittest.TestCase):
server: MockupDB
client: MongoClient

@classmethod
def setUpClass(cls):
Expand Down
5 changes: 4 additions & 1 deletion test/mockupdb/test_op_msg_read_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy
import itertools
from typing import Any

from mockupdb import MockupDB, going, CommandBase
from pymongo import MongoClient, ReadPreference
Expand All @@ -27,6 +28,8 @@

class OpMsgReadPrefBase(unittest.TestCase):
single_mongod = False
primary: MockupDB
secondary: MockupDB

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -142,7 +145,7 @@ def test(self):
tag_sets=None)

client = self.setup_client(read_preference=pref)

expected_pref: Any
if operation.op_type == 'always-use-secondary':
expected_server = self.secondary
expected_pref = ReadPreference.SECONDARY
Expand Down
12 changes: 10 additions & 2 deletions test/mockupdb/test_reset_and_request_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class TestResetAndRequestCheck(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestResetAndRequestCheck, self).__init__(*args, **kwargs)
self.ismaster_time = 0
self.ismaster_time = 0.0
self.client = None
self.server = None

Expand All @@ -45,7 +45,7 @@ def responder(request):
kwargs = {'socketTimeoutMS': 100}
# Disable retryable reads when pymongo supports it.
kwargs['retryReads'] = False
self.client = MongoClient(self.server.uri, **kwargs)
self.client = MongoClient(self.server.uri, **kwargs) # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's up with the type ignore here? It seems perfectly reasonable to pass **kwargs to MongoClient because it accepts **kwargs. Is this something we should report to mypy?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it is a type limitation of mypy, I'll look for/file an issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found this comment: I'll try TypedDict

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it really is a bug, already reported: python/mypy#8862

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait_until(lambda: self.client.nodes, 'connect to standalone')

def tearDown(self):
Expand All @@ -56,6 +56,8 @@ def _test_disconnect(self, operation):
# Application operation fails. Test that client resets server
# description and does *not* schedule immediate check.
self.setup_server()
assert self.server is not None
assert self.client is not None

# Network error on application operation.
with self.assertRaises(ConnectionFailure):
Expand All @@ -81,6 +83,8 @@ def _test_timeout(self, operation):
# Application operation times out. Test that client does *not* reset
# server description and does *not* schedule immediate check.
self.setup_server()
assert self.server is not None
assert self.client is not None

with self.assertRaises(ConnectionFailure):
with going(operation.function, self.client):
Expand All @@ -91,6 +95,7 @@ def _test_timeout(self, operation):
# Server is *not* Unknown.
topology = self.client._topology
server = topology.select_server_by_address(self.server.address, 0)
assert server is not None
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)

after = self.ismaster_time
Expand All @@ -99,6 +104,8 @@ def _test_timeout(self, operation):
def _test_not_master(self, operation):
# Application operation gets a "not master" error.
self.setup_server()
assert self.server is not None
assert self.client is not None

with self.assertRaises(ConnectionFailure):
with going(operation.function, self.client):
Expand All @@ -110,6 +117,7 @@ def _test_not_master(self, operation):
# Server is rediscovered.
topology = self.client._topology
server = topology.select_server_by_address(self.server.address, 0)
assert server is not None
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)

after = self.ismaster_time
Expand Down
Loading