From 9422c3688eec75f825949e367ea7cfb26b1c49c5 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Thu, 9 Jan 2025 14:02:43 -0500 Subject: [PATCH] Bring us under 30 test_ssl type-check issues (#1406) --- tests/test_ssl.py | 110 +++++++++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 46 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index b59b7578..f28fa05e 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -901,7 +901,9 @@ def test_set_passwd_cb(self, tmpfile: bytes) -> None: pemFile = self._write_encrypted_pem(passphrase, tmpfile) calledWith = [] - def passphraseCallback(maxlen, verify, extra): + def passphraseCallback( + maxlen: int, verify: bool, extra: None + ) -> bytes: calledWith.append((maxlen, verify, extra)) return passphrase @@ -920,7 +922,9 @@ def test_passwd_callback_exception(self, tmpfile: bytes) -> None: """ pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) - def passphraseCallback(maxlen, verify, extra): + def passphraseCallback( + maxlen: int, verify: bool, extra: None + ) -> bytes: raise RuntimeError("Sorry, I am a fail.") context = Context(SSLv23_METHOD) @@ -935,7 +939,9 @@ def test_passwd_callback_false(self, tmpfile: bytes) -> None: """ pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) - def passphraseCallback(maxlen, verify, extra): + def passphraseCallback( + maxlen: int, verify: bool, extra: None + ) -> bytes: return b"" context = Context(SSLv23_METHOD) @@ -950,11 +956,11 @@ def test_passwd_callback_non_string(self, tmpfile: bytes) -> None: """ pemFile = self._write_encrypted_pem(b"monkeys are nice", tmpfile) - def passphraseCallback(maxlen, verify, extra): + def passphraseCallback(maxlen: int, verify: bool, extra: None) -> int: return 10 context = Context(SSLv23_METHOD) - context.set_passwd_cb(passphraseCallback) + context.set_passwd_cb(passphraseCallback) # type: ignore[arg-type] # TODO: Surely this is the wrong error? with pytest.raises(ValueError): context.use_privatekey_file(pemFile) @@ -968,7 +974,9 @@ def test_passwd_callback_too_long(self, tmpfile: bytes) -> None: passphrase = b"x" * 1024 pemFile = self._write_encrypted_pem(passphrase, tmpfile) - def passphraseCallback(maxlen, verify, extra): + def passphraseCallback( + maxlen: int, verify: bool, extra: None + ) -> bytes: assert maxlen == 1024 return passphrase + b"y" @@ -990,7 +998,7 @@ def test_set_info_callback(self) -> None: called = [] - def info(conn, where, ret): + def info(conn: Connection, where: int, ret: int) -> None: called.append((conn, where, ret)) context = Context(SSLv23_METHOD) @@ -1028,7 +1036,7 @@ def test_set_keylog_callback(self) -> None: """ called = [] - def keylog(conn, line): + def keylog(conn: Connection, line: bytes) -> None: called.append((conn, line)) server_context = Context(TLSv1_2_METHOD) @@ -1385,9 +1393,9 @@ def test_set_verify_callback_connection_argument(self) -> None: serverConnection = Connection(serverContext, None) class VerifyCallback: - def callback(self, connection, *args): + def callback(self, connection: Connection, *args) -> bool: self.connection = connection - return 1 + return True verify = VerifyCallback() clientContext = Context(SSLv23_METHOD) @@ -1415,9 +1423,11 @@ def test_x509_in_verify_works(self) -> None: ) serverConnection = Connection(serverContext, None) - def verify_cb_get_subject(conn, cert, errnum, depth, ok): + def verify_cb_get_subject( + conn: Connection, cert: X509, errnum: int, depth: int, ok: int + ) -> bool: assert cert.get_subject() - return 1 + return True clientContext = Context(SSLv23_METHOD) clientContext.set_verify(VERIFY_PEER, verify_cb_get_subject) @@ -1817,10 +1827,10 @@ def test_old_callback_forgotten(self) -> None: a new callback, the one it replaces is dereferenced. """ - def callback(connection): # pragma: no cover + def callback(connection: Connection) -> None: # pragma: no cover pass - def replacement(connection): # pragma: no cover + def replacement(connection: Connection) -> None: # pragma: no cover pass context = Context(SSLv23_METHOD) @@ -1851,7 +1861,7 @@ def test_no_servername(self) -> None: """ args = [] - def servername(conn): + def servername(conn: Connection) -> None: args.append((conn, conn.get_servername())) context = Context(SSLv23_METHOD) @@ -1888,7 +1898,7 @@ def test_servername(self) -> None: """ args = [] - def servername(conn): + def servername(conn: Connection) -> None: args.append((conn, conn.get_servername())) context = Context(SSLv23_METHOD) @@ -1926,7 +1936,7 @@ def test_alpn_success(self) -> None: """ select_args = [] - def select(conn, options): + def select(conn: Connection, options: list[bytes]) -> bytes: select_args.append((conn, options)) return b"spdy/2" @@ -1974,7 +1984,7 @@ def test_alpn_set_on_connection(self) -> None: """ select_args = [] - def select(conn, options): + def select(conn: Connection, options: list[bytes]) -> bytes: select_args.append((conn, options)) return b"spdy/2" @@ -2015,7 +2025,7 @@ def test_alpn_server_fail(self) -> None: """ select_args = [] - def select(conn, options): + def select(conn: Connection, options: list[bytes]) -> bytes: select_args.append((conn, options)) return b"" @@ -2054,7 +2064,7 @@ def test_alpn_no_server_overlap(self) -> None: """ refusal_args = [] - def refusal(conn, options): + def refusal(conn: Connection, options: list[bytes]): refusal_args.append((conn, options)) return NO_OVERLAPPING_PROTOCOLS @@ -2094,7 +2104,7 @@ def test_alpn_select_cb_returns_invalid_value(self) -> None: """ invalid_cb_args = [] - def invalid_cb(conn, options): + def invalid_cb(conn: Connection, options: list[bytes]) -> str: invalid_cb_args.append((conn, options)) return "can't return unicode" @@ -2102,7 +2112,7 @@ def invalid_cb(conn, options): client_context.set_alpn_protos([b"http/1.1", b"spdy/2"]) server_context = Context(SSLv23_METHOD) - server_context.set_alpn_select_callback(invalid_cb) + server_context.set_alpn_select_callback(invalid_cb) # type: ignore[arg-type] # Necessary to actually accept the connection server_context.use_privatekey( @@ -2163,7 +2173,7 @@ def test_alpn_callback_exception(self) -> None: """ select_args = [] - def select(conn, options): + def select(conn: Connection, options: list[bytes]) -> bytes: select_args.append((conn, options)) raise TypeError() @@ -2790,8 +2800,10 @@ def test_set_verify_callback_reference(self) -> None: the context and all connections created by it do not use it anymore. """ - def callback(conn, cert, errnum, depth, ok): # pragma: no cover - return ok + def callback( + conn: Connection, cert: X509, errnum: int, depth: int, ok: int + ) -> bool: # pragma: no cover + return bool(ok) tracker = ref(callback) @@ -2872,7 +2884,7 @@ def test_client_set_session(self) -> None: ctx.use_certificate(cert) ctx.set_session_id(b"unity-test") - def makeServer(socket): + def makeServer(socket: socket) -> Connection: server = Connection(ctx, socket) server.set_accept_state() return server @@ -2881,7 +2893,7 @@ def makeServer(socket): originalSession = originalClient.get_session() assert originalSession is not None - def makeClient(socket): + def makeClient(socket: socket) -> Connection: client = loopback_client_factory(socket) client.set_session(originalSession) return client @@ -2914,12 +2926,12 @@ def test_set_session_wrong_method(self) -> None: ctx.use_certificate(cert) ctx.set_session_id(b"unity-test") - def makeServer(socket): + def makeServer(socket: socket) -> Connection: server = Connection(ctx, socket) server.set_accept_state() return server - def makeOriginalClient(socket): + def makeOriginalClient(socket: socket) -> Connection: client = Connection(Context(v1), socket) client.set_connect_state() return client @@ -2930,7 +2942,7 @@ def makeOriginalClient(socket): originalSession = originalClient.get_session() assert originalSession is not None - def makeClient(socket): + def makeClient(socket: socket) -> Connection: # Intentionally use a different, incompatible method here. client = Connection(Context(v2), socket) client.set_connect_state() @@ -3193,7 +3205,7 @@ class VeryLarge(bytes): Mock object so that we don't have to allocate 2**31 bytes """ - def __len__(self): + def __len__(self) -> int: return 2**31 @@ -3275,7 +3287,7 @@ def test_buf_too_large(self) -> None: exc_info.match(r"Cannot send more than .+ bytes at once") -def _make_memoryview(size): +def _make_memoryview(size: int) -> memoryview: """ Create a new ``memoryview`` wrapped around a ``bytearray`` of the given size. @@ -3933,7 +3945,7 @@ def test_set_empty_ca_list(self) -> None: after the connection is set up. """ - def no_ca(ctx): + def no_ca(ctx: Context) -> list[X509Name]: ctx.set_client_ca_list([]) return [] @@ -3950,7 +3962,7 @@ def test_set_one_ca_list(self) -> None: cacert = load_certificate(FILETYPE_PEM, root_cert_pem) cadesc = cacert.get_subject() - def single_ca(ctx): + def single_ca(ctx: Context) -> list[X509Name]: ctx.set_client_ca_list([cadesc]) return [cadesc] @@ -3970,7 +3982,7 @@ def test_set_multiple_ca_list(self) -> None: sedesc = secert.get_subject() cldesc = clcert.get_subject() - def multiple_ca(ctx): + def multiple_ca(ctx: Context) -> list[X509Name]: L = [sedesc, cldesc] ctx.set_client_ca_list(L) return L @@ -3991,7 +4003,7 @@ def test_reset_ca_list(self) -> None: sedesc = secert.get_subject() cldesc = clcert.get_subject() - def changed_ca(ctx): + def changed_ca(ctx: Context) -> list[X509Name]: ctx.set_client_ca_list([sedesc, cldesc]) ctx.set_client_ca_list([cadesc]) return [cadesc] @@ -4010,7 +4022,7 @@ def test_mutated_ca_list(self) -> None: cadesc = cacert.get_subject() sedesc = secert.get_subject() - def mutated_ca(ctx): + def mutated_ca(ctx: Context) -> list[X509Name]: L = [cadesc] ctx.set_client_ca_list([cadesc]) L.append(sedesc) @@ -4035,7 +4047,7 @@ def test_one_add_client_ca(self) -> None: cacert = load_certificate(FILETYPE_PEM, root_cert_pem) cadesc = cacert.get_subject() - def single_ca(ctx): + def single_ca(ctx: Context) -> list[X509Name]: ctx.add_client_ca(cacert) return [cadesc] @@ -4052,7 +4064,7 @@ def test_multiple_add_client_ca(self) -> None: cadesc = cacert.get_subject() sedesc = secert.get_subject() - def multiple_ca(ctx): + def multiple_ca(ctx: Context) -> list[X509Name]: ctx.add_client_ca(cacert) ctx.add_client_ca(secert.to_cryptography()) return [cadesc, sedesc] @@ -4073,7 +4085,7 @@ def test_set_and_add_client_ca(self) -> None: sedesc = secert.get_subject() cldesc = clcert.get_subject() - def mixed_set_add_ca(ctx): + def mixed_set_add_ca(ctx: Context) -> list[X509Name]: ctx.set_client_ca_list([cadesc, sedesc]) ctx.add_client_ca(clcert) return [cadesc, sedesc, cldesc] @@ -4093,7 +4105,7 @@ def test_set_after_add_client_ca(self) -> None: cadesc = cacert.get_subject() sedesc = secert.get_subject() - def set_replaces_add_ca(ctx): + def set_replaces_add_ca(ctx: Context) -> list[X509Name]: ctx.add_client_ca(clcert.to_cryptography()) ctx.set_client_ca_list([cadesc]) ctx.add_client_ca(secert) @@ -4253,7 +4265,9 @@ def test_client_negotiates_without_server(self) -> None: """ called = [] - def ocsp_callback(conn, ocsp_data, ignored): + def ocsp_callback( + conn: Connection, ocsp_data: bytes, ignored: None + ) -> bool: called.append(ocsp_data) return True @@ -4273,7 +4287,9 @@ def test_client_receives_servers_data(self) -> None: def server_callback(*args, **kwargs): return self.sample_ocsp_data - def client_callback(conn, ocsp_data, ignored): + def client_callback( + conn: Connection, ocsp_data: bytes, ignored: None + ) -> bool: calls.append(ocsp_data) return True @@ -4347,7 +4363,9 @@ def test_server_returns_empty_string(self) -> None: def server_callback(*args): return b"" - def client_callback(conn, ocsp_data, ignored): + def client_callback( + conn: Connection, ocsp_data: bytes, ignored: None + ) -> bool: client_calls.append(ocsp_data) return True @@ -4509,10 +4527,10 @@ class TestDTLS: def _test_handshake_and_data(self, srtp_profile: bytes | None) -> None: s_ctx = Context(DTLS_METHOD) - def generate_cookie(ssl): + def generate_cookie(ssl: Connection) -> bytes: return b"xyzzy" - def verify_cookie(ssl, cookie): + def verify_cookie(ssl: Connection, cookie: bytes) -> bool: return cookie == b"xyzzy" s_ctx.set_cookie_generate_callback(generate_cookie)