Skip to content

Commit

Permalink
A bit more test_ssl.py type check cleanup (#1405)
Browse files Browse the repository at this point in the history
We're basically down to a) a few hard things, b) typing a bunch of callbacks/utility functions in tests
  • Loading branch information
alex authored Jan 9, 2025
1 parent 9baefba commit a3972a0
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
TYPE_RSA,
X509,
PKey,
X509Name,
X509Store,
dump_certificate,
dump_privatekey,
Expand Down Expand Up @@ -3144,15 +3145,15 @@ def test_wantReadError(self) -> None:
conn.bio_read(1024)

@pytest.mark.parametrize("bufsize", [1.0, None, object(), "bufsize"])
def test_bio_read_wrong_args(self, bufsize) -> None:
def test_bio_read_wrong_args(self, bufsize: object) -> None:
"""
`Connection.bio_read` raises `TypeError` if passed a non-integer
argument.
"""
ctx = Context(SSLv23_METHOD)
conn = Connection(ctx, None)
with pytest.raises(TypeError):
conn.bio_read(bufsize)
conn.bio_read(bufsize) # type: ignore[arg-type]

def test_buffer_size(self) -> None:
"""
Expand Down Expand Up @@ -3287,7 +3288,9 @@ class TestConnectionRecvInto:
Tests for `Connection.recv_into`.
"""

def _no_length_test(self, factory):
def _no_length_test(
self, factory: typing.Callable[[int], typing.Any]
) -> None:
"""
Assert that when the given buffer is passed to `Connection.recv_into`,
whatever bytes are available to be received that fit into that buffer
Expand All @@ -3308,7 +3311,9 @@ def test_bytearray_no_length(self) -> None:
"""
self._no_length_test(bytearray)

def _respects_length_test(self, factory):
def _respects_length_test(
self, factory: typing.Callable[[int], typing.Any]
) -> None:
"""
Assert that when the given buffer is passed to `Connection.recv_into`
along with a value for `nbytes` that is less than the size of that
Expand All @@ -3330,7 +3335,9 @@ def test_bytearray_respects_length(self) -> None:
"""
self._respects_length_test(bytearray)

def _doesnt_overfill_test(self, factory):
def _doesnt_overfill_test(
self, factory: typing.Callable[[int], typing.Any]
) -> None:
"""
Assert that if there are more bytes available to be read from the
receive buffer than would fit into the buffer passed to
Expand Down Expand Up @@ -3881,7 +3888,9 @@ def test_unexpected_EOF(self) -> None:
(54, "ECONNRESET"),
]

def _check_client_ca_list(self, func):
def _check_client_ca_list(
self, func: typing.Callable[[Context], list[X509Name]]
) -> None:
"""
Verify the return value of the `get_client_ca_list` method for
server and client connections.
Expand Down Expand Up @@ -3912,9 +3921,9 @@ def test_set_client_ca_list_errors(self) -> None:
"""
ctx = Context(SSLv23_METHOD)
with pytest.raises(TypeError):
ctx.set_client_ca_list("spam")
ctx.set_client_ca_list("spam") # type: ignore[arg-type]
with pytest.raises(TypeError):
ctx.set_client_ca_list(["spam"])
ctx.set_client_ca_list(["spam"]) # type: ignore[list-item]

def test_set_empty_ca_list(self) -> None:
"""
Expand Down Expand Up @@ -4016,7 +4025,7 @@ def test_add_client_ca_wrong_args(self) -> None:
"""
ctx = Context(SSLv23_METHOD)
with pytest.raises(TypeError):
ctx.add_client_ca("spam")
ctx.add_client_ca("spam") # type: ignore[arg-type]

def test_one_add_client_ca(self) -> None:
"""
Expand Down Expand Up @@ -4142,7 +4151,7 @@ def test_available(self) -> None:
results = []

@feature_guard
def inner():
def inner() -> bool:
results.append(True)
return True

Expand All @@ -4157,7 +4166,7 @@ def test_unavailable(self) -> None:
feature_guard = _make_requires(False, "Error text")

@feature_guard
def inner(): # pragma: nocover
def inner() -> None: # pragma: nocover
pytest.fail("Should not be called")

with pytest.raises(NotImplementedError) as e:
Expand All @@ -4180,7 +4189,7 @@ def _client_connection(
self,
callback: typing.Callable[[Connection, bytes, T | None], bool],
data: T | None,
request_ocsp=True,
request_ocsp: bool = True,
) -> Connection:
"""
Builds a client connection suitable for using OCSP.
Expand Down Expand Up @@ -4586,6 +4595,7 @@ def pump() -> None:
s_listening = False
s_handshaking = True
# Write the duplicate ClientHello. See giant comment above.
assert latest_client_hello is not None
s.bio_write(latest_client_hello)

if s_handshaking:
Expand Down

0 comments on commit a3972a0

Please sign in to comment.