From 99f8972dcec55a8a4ecb2db70aa58e5250b92910 Mon Sep 17 00:00:00 2001 From: Dan Halbert Date: Fri, 10 May 2024 16:47:37 -0400 Subject: [PATCH] Recover in more cases when a socket cannot be created. Also: - Clarify some documentation. Use sphinx argument documentation style. - Fix some typos. - Remove a few internal comments marking code sections. - Clarify an error message. - Internally, catch exceptions instead of passing them back. - Change one exception. - Update to pylint 3.1.0 so pre-commit can run under Python 3.12 --- .pre-commit-config.yaml | 2 +- adafruit_connection_manager.py | 115 +++++++++++++++------------------ 2 files changed, 52 insertions(+), 65 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77ed663..4d2e392 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/pylint - rev: v2.17.4 + rev: v3.1.0 hooks: - id: pylint name: pylint (library code) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 5b8a10c..27f8a7b 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -21,8 +21,6 @@ """ -# imports - __version__ = "0.0.0+auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git" @@ -31,9 +29,6 @@ WIZNET5K_SSL_SUPPORT_VERSION = (9, 1) -# typing - - if not sys.implementation.name == "circuitpython": from typing import List, Optional, Tuple @@ -46,9 +41,6 @@ ) -# ssl and pool helpers - - class _FakeSSLSocket: def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: self._socket = socket @@ -189,11 +181,8 @@ def get_radio_ssl_context(radio): return _global_ssl_contexts[_get_radio_hash_key(radio)] -# main class - - class ConnectionManager: - """A library for managing sockets accross libraries.""" + """A library for managing sockets across multiple hardware platforms and libraries.""" def __init__( self, @@ -224,23 +213,24 @@ def _get_connected_socket( # pylint: disable=too-many-arguments is_ssl: bool, ssl_context: Optional[SSLContextType] = None, ): - try: - socket = self._socket_pool.socket(addr_info[0], addr_info[1]) - except (OSError, RuntimeError) as exc: - return exc + + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) if is_ssl: socket = ssl_context.wrap_socket(socket, server_hostname=host) connect_host = host else: connect_host = addr_info[-1][0] - socket.settimeout(timeout) # socket read timeout + + # Set socket read and connect timeout. + socket.settimeout(timeout) try: socket.connect((connect_host, port)) - except (MemoryError, OSError) as exc: + except (MemoryError, OSError): + # If any connect problems, clean up and re-raise the problem exception. socket.close() - return exc + raise return socket @@ -269,11 +259,16 @@ def close_socket(self, socket: SocketType) -> None: self._available_sockets.remove(socket) def free_socket(self, socket: SocketType) -> None: - """Mark a managed socket as available so it can be reused.""" + """Mark a managed socket as available so it can be reused. The socket is not closed.""" if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") self._available_sockets.add(socket) + def _register_connected_socket(self, key, socket): + self._key_by_managed_socket[socket] = key + self._managed_socket_by_key[key] = socket + + # pylint: disable=too-many-arguments def get_socket( self, host: str, @@ -281,70 +276,65 @@ def get_socket( proto: str, session_id: Optional[str] = None, *, - timeout: float = 1, + timeout: float = 1.0, is_ssl: bool = False, ssl_context: Optional[SSLContextType] = None, ) -> CircuitPythonSocketType: """ - Get a new socket and connect. - - - **host** *(str)* – The host you are want to connect to: "www.adaftuit.com" - - **port** *(int)* – The port you want to connect to: 80 - - **proto** *(str)* – The protocal you want to use: "http:" - - **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open - connections to the same host - - **timeout** *(float)* – Time timeout used for connecting - - **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is - "https:") - - **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL - requests + Get a new socket and connect to the given host. + + :param str host: host to connect to, such as ``"www.example.org"`` + :param int port: port to use for connection, such as ``80`` or ``443`` + :param str proto: connection protocol: ``"http:"``, ``"https:"``, etc. + :param Optional[str]: unique session ID, + used for multiple simultaneous connections to the same host + :param float timeout: how long to wait to connect + :param bool is_ssl: ``True`` If the connection is to be over SSL; + automatically set when ``proto`` is ``"https:"` + :param Optional[SSLContextType]: SSL context to use when making SSL requests """ if session_id: session_id = str(session_id) key = (host, port, proto, session_id) + + # Do we have already have a socket available for the requested connection? if key in self._managed_socket_by_key: socket = self._managed_socket_by_key[key] if socket in self._available_sockets: self._available_sockets.remove(socket) return socket - raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") + raise RuntimeError( + f"An existing socket is already connected to {proto}//{host}:{port}" + ) if proto == "https:": is_ssl = True if is_ssl and not ssl_context: - raise AttributeError( - "ssl_context must be set before using adafruit_requests for https" - ) + raise ValueError("ssl_context must be provided if using ssl") addr_info = self._socket_pool.getaddrinfo( host, port, 0, self._socket_pool.SOCK_STREAM )[0] - first_exception = None - result = self._get_connected_socket( - addr_info, host, port, timeout, is_ssl, ssl_context - ) - if isinstance(result, Exception): - # Got an error, if there are any available sockets, free them and try again + try: + socket = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + self._register_connected_socket(key, socket) + return socket + except (MemoryError, OSError, RuntimeError): + # Could not get a new socket (or two, if SSL). + # If there are any available sockets, free them all and try again. if self.available_socket_count: - first_exception = result self._free_sockets() - result = self._get_connected_socket( + socket = self._get_connected_socket( addr_info, host, port, timeout, is_ssl, ssl_context ) - if isinstance(result, Exception): - last_result = f", first error: {first_exception}" if first_exception else "" - raise RuntimeError( - f"Error connecting socket: {result}{last_result}" - ) from result - - self._key_by_managed_socket[result] = key - self._managed_socket_by_key[key] = result - return result - - -# global helpers + self._register_connected_socket(key, socket) + return socket + # Re-raise exception if no sockets could be freed. + raise def connection_manager_close_all( @@ -353,9 +343,9 @@ def connection_manager_close_all( """ Close all open sockets for pool, optionally release references. - - **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close - sockets for, leave blank for all SocketPools - - **release_references** *(bool)* – Set to True if you want to also clear stored references to + :param Optional[SocketpoolModuleType] socket_pool: + a specific `SocketPool` whose sockets you want to close; `None`` means all `SocketPool`s + :param bool release_references: ``True`` if you want to also clear stored references to the SocketPool and SSL contexts """ if socket_pool: @@ -383,10 +373,7 @@ def connection_manager_close_all( def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """ - Get the ConnectionManager singleton for the given pool. - - - **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the - ConnectionManager for + Get or create the ConnectionManager singleton for the given pool. """ if socket_pool not in _global_connection_managers: _global_connection_managers[socket_pool] = ConnectionManager(socket_pool)