diff --git a/omnikinverter/omnikinverter.py b/omnikinverter/omnikinverter.py index ae81bda5..fecf2477 100644 --- a/omnikinverter/omnikinverter.py +++ b/omnikinverter/omnikinverter.py @@ -137,16 +137,24 @@ async def tcp_request(self) -> dict[str, Any]: raise OmnikInverterAuthError(msg) try: - reader, writer = await asyncio.open_connection(self.host, self.tcp_port) + async with async_timeout.timeout(self.request_timeout): + reader, writer = await asyncio.open_connection(self.host, self.tcp_port) except OSError as exception: msg = "Failed to open a TCP connection to the Omnik Inverter device" raise OmnikInverterConnectionError(msg) from exception + except asyncio.TimeoutError as exception: # pragma: no cover + msg = "Timeout occurred while connecting to the Omnik Inverter device" + raise OmnikInverterConnectionError(msg) from exception try: - writer.write(tcp.create_information_request(self.serial_number)) - await writer.drain() + async with async_timeout.timeout(self.request_timeout): + writer.write(tcp.create_information_request(self.serial_number)) + await writer.drain() - raw_msg = await reader.read(1024) + raw_msg = await reader.read(1024) + except asyncio.TimeoutError as exception: + msg = "Timeout occurred while communicating with the Omnik Inverter device" + raise OmnikInverterConnectionError(msg) from exception finally: writer.close() try: diff --git a/tests/test_tcp_models.py b/tests/test_tcp_models.py index a552b3f8..6f2e60aa 100644 --- a/tests/test_tcp_models.py +++ b/tests/test_tcp_models.py @@ -3,6 +3,7 @@ from __future__ import annotations import struct +import time from socket import SHUT_RDWR, SO_LINGER, SOL_SOCKET, socket from threading import Thread from typing import TYPE_CHECKING @@ -403,13 +404,43 @@ def close_immediately(conn: socket) -> None: assert await client.inverter() assert ( - excinfo.value.args[0] + str(excinfo.value) == "Failed to communicate with the Omnik Inverter device over TCP" ) await server_exit +async def test_communication_timeout() -> None: + """Test on timed out connection attempt - TCP source.""" + serial_number = 1 + + def long_timeout(conn: socket) -> None: + """Close the connection and send RST.""" + time.sleep(0.2) + conn.close() + + (server_exit, port) = tcp_server(serial_number, long_timeout) + + client = OmnikInverter( + host="localhost", + source_type="tcp", + serial_number=serial_number, + tcp_port=port, + request_timeout=0.1, + ) + + with pytest.raises(OmnikInverterConnectionError) as excinfo: + assert await client.inverter() + + assert ( + str(excinfo.value) + == "Timeout occurred while communicating with the Omnik Inverter device" + ) + + await server_exit + + async def test_connection_failed() -> None: """Test on failed connection attempt - TCP source.""" serial_number = 1