Skip to content

Commit

Permalink
Recover in more cases when a socket cannot be created.
Browse files Browse the repository at this point in the history
Also:
- Clarify some documentation. Use sphinx argument documentation style.
- Fix some typos.
- Remove a few internal comments marking code sections.
- Clarify an error message.
- Internally, catch exceptions instead of passing them back.
- Change one exception.
- Update to pylint 3.1.0 so pre-commit can run under Python 3.12
  • Loading branch information
dhalbert committed May 10, 2024
1 parent 2c79732 commit 99f8972
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/pylint
rev: v2.17.4
rev: v3.1.0
hooks:
- id: pylint
name: pylint (library code)
Expand Down
115 changes: 51 additions & 64 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
"""

# imports

__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git"

Expand All @@ -31,9 +29,6 @@

WIZNET5K_SSL_SUPPORT_VERSION = (9, 1)

# typing


if not sys.implementation.name == "circuitpython":
from typing import List, Optional, Tuple

Expand All @@ -46,9 +41,6 @@
)


# ssl and pool helpers


class _FakeSSLSocket:
def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
self._socket = socket
Expand Down Expand Up @@ -189,11 +181,8 @@ def get_radio_ssl_context(radio):
return _global_ssl_contexts[_get_radio_hash_key(radio)]


# main class


class ConnectionManager:
"""A library for managing sockets accross libraries."""
"""A library for managing sockets across multiple hardware platforms and libraries."""

def __init__(
self,
Expand Down Expand Up @@ -224,23 +213,24 @@ def _get_connected_socket( # pylint: disable=too-many-arguments
is_ssl: bool,
ssl_context: Optional[SSLContextType] = None,
):
try:
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
except (OSError, RuntimeError) as exc:
return exc

socket = self._socket_pool.socket(addr_info[0], addr_info[1])

if is_ssl:
socket = ssl_context.wrap_socket(socket, server_hostname=host)
connect_host = host
else:
connect_host = addr_info[-1][0]
socket.settimeout(timeout) # socket read timeout

# Set socket read and connect timeout.
socket.settimeout(timeout)

try:
socket.connect((connect_host, port))
except (MemoryError, OSError) as exc:
except (MemoryError, OSError):
# If any connect problems, clean up and re-raise the problem exception.
socket.close()
return exc
raise

return socket

Expand Down Expand Up @@ -269,82 +259,82 @@ def close_socket(self, socket: SocketType) -> None:
self._available_sockets.remove(socket)

def free_socket(self, socket: SocketType) -> None:
"""Mark a managed socket as available so it can be reused."""
"""Mark a managed socket as available so it can be reused. The socket is not closed."""
if socket not in self._managed_socket_by_key.values():
raise RuntimeError("Socket not managed")
self._available_sockets.add(socket)

def _register_connected_socket(self, key, socket):
self._key_by_managed_socket[socket] = key
self._managed_socket_by_key[key] = socket

# pylint: disable=too-many-arguments
def get_socket(
self,
host: str,
port: int,
proto: str,
session_id: Optional[str] = None,
*,
timeout: float = 1,
timeout: float = 1.0,
is_ssl: bool = False,
ssl_context: Optional[SSLContextType] = None,
) -> CircuitPythonSocketType:
"""
Get a new socket and connect.
- **host** *(str)* – The host you are want to connect to: "www.adaftuit.com"
- **port** *(int)* – The port you want to connect to: 80
- **proto** *(str)* – The protocal you want to use: "http:"
- **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open
connections to the same host
- **timeout** *(float)* – Time timeout used for connecting
- **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is
"https:")
- **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL
requests
Get a new socket and connect to the given host.
:param str host: host to connect to, such as ``"www.example.org"``
:param int port: port to use for connection, such as ``80`` or ``443``
:param str proto: connection protocol: ``"http:"``, ``"https:"``, etc.
:param Optional[str]: unique session ID,
used for multiple simultaneous connections to the same host
:param float timeout: how long to wait to connect
:param bool is_ssl: ``True`` If the connection is to be over SSL;
automatically set when ``proto`` is ``"https:"`
:param Optional[SSLContextType]: SSL context to use when making SSL requests
"""
if session_id:
session_id = str(session_id)
key = (host, port, proto, session_id)

# Do we have already have a socket available for the requested connection?
if key in self._managed_socket_by_key:
socket = self._managed_socket_by_key[key]
if socket in self._available_sockets:
self._available_sockets.remove(socket)
return socket

raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
raise RuntimeError(
f"An existing socket is already connected to {proto}//{host}:{port}"
)

if proto == "https:":
is_ssl = True
if is_ssl and not ssl_context:
raise AttributeError(
"ssl_context must be set before using adafruit_requests for https"
)
raise ValueError("ssl_context must be provided if using ssl")

addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

first_exception = None
result = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
# Got an error, if there are any available sockets, free them and try again
try:
socket = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
self._register_connected_socket(key, socket)
return socket
except (MemoryError, OSError, RuntimeError):
# Could not get a new socket (or two, if SSL).
# If there are any available sockets, free them all and try again.
if self.available_socket_count:
first_exception = result
self._free_sockets()
result = self._get_connected_socket(
socket = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
last_result = f", first error: {first_exception}" if first_exception else ""
raise RuntimeError(
f"Error connecting socket: {result}{last_result}"
) from result

self._key_by_managed_socket[result] = key
self._managed_socket_by_key[key] = result
return result


# global helpers
self._register_connected_socket(key, socket)
return socket
# Re-raise exception if no sockets could be freed.
raise


def connection_manager_close_all(
Expand All @@ -353,9 +343,9 @@ def connection_manager_close_all(
"""
Close all open sockets for pool, optionally release references.
- **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close
sockets for, leave blank for all SocketPools
- **release_references** *(bool)* – Set to True if you want to also clear stored references to
:param Optional[SocketpoolModuleType] socket_pool:
a specific `SocketPool` whose sockets you want to close; `None`` means all `SocketPool`s
:param bool release_references: ``True`` if you want to also clear stored references to
the SocketPool and SSL contexts
"""
if socket_pool:
Expand Down Expand Up @@ -383,10 +373,7 @@ def connection_manager_close_all(

def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
"""
Get the ConnectionManager singleton for the given pool.
- **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the
ConnectionManager for
Get or create the ConnectionManager singleton for the given pool.
"""
if socket_pool not in _global_connection_managers:
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
Expand Down

0 comments on commit 99f8972

Please sign in to comment.