diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index b195d19..25b057d 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -204,6 +204,11 @@ def _free_sockets(self, force: bool = False) -> None: for socket in open_sockets: self.close_socket(socket) + def _register_connected_socket(self, key, socket): + """Register a socket as managed.""" + self._key_by_managed_socket[socket] = key + self._managed_socket_by_key[key] = socket + def _get_connected_socket( # pylint: disable=too-many-arguments self, addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]], @@ -264,10 +269,6 @@ def free_socket(self, socket: SocketType) -> None: raise RuntimeError("Socket not managed") self._available_sockets.add(socket) - def _register_connected_socket(self, key, socket): - self._key_by_managed_socket[socket] = key - self._managed_socket_by_key[key] = socket - # pylint: disable=too-many-arguments def get_socket( self, @@ -290,7 +291,7 @@ def get_socket( used for multiple simultaneous connections to the same host :param float timeout: how long to wait to connect :param bool is_ssl: ``True`` If the connection is to be over SSL; - automatically set when ``proto`` is ``"https:"` + automatically set when ``proto`` is ``"https:"`` :param Optional[SSLContextType]: SSL context to use when making SSL requests """ if session_id: diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index 9abbf98..46d053b 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -91,7 +91,7 @@ def test_get_socket_not_flagged_free(): # get a socket for the same host, should be a different one with pytest.raises(RuntimeError) as context: socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Socket already connected" in str(context) + assert "An existing socket is already connected" in str(context) def test_get_socket_os_error(): @@ -105,9 +105,8 @@ def test_get_socket_os_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a OSError - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: OSError" in str(context) def test_get_socket_runtime_error(): @@ -121,9 +120,8 @@ def test_get_socket_runtime_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a RuntimeError - with pytest.raises(RuntimeError) as context: + with pytest.raises(RuntimeError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: RuntimeError" in str(context) def test_get_socket_connect_memory_error(): @@ -139,9 +137,8 @@ def test_get_socket_connect_memory_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a MemoryError - with pytest.raises(RuntimeError) as context: + with pytest.raises(MemoryError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: MemoryError" in str(context) def test_get_socket_connect_os_error(): @@ -157,9 +154,8 @@ def test_get_socket_connect_os_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a OSError - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: OSError" in str(context) def test_get_socket_runtime_error_ties_again_at_least_one_free(): @@ -211,9 +207,8 @@ def test_get_socket_runtime_error_ties_again_only_once(): free_sockets_mock.assert_not_called() # try to get a socket that returns a RuntimeError twice - with pytest.raises(RuntimeError) as context: + with pytest.raises(RuntimeError): connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") - assert "Error connecting socket: error 2, first error: error 1" in str(context) free_sockets_mock.assert_called_once() @@ -248,8 +243,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) - assert "Error connecting socket: [Errno 12] RuntimeError" in str(context)