From 2c390a0e283a6e0732412a53b86d8187031f79d1 Mon Sep 17 00:00:00 2001 From: Hanaasagi Date: Tue, 6 Jul 2021 17:58:18 +0000 Subject: [PATCH] Fix: fix the `ClientWebSocketResponse.receive` stuck for a long time if network is suddenly interrupted. Related to #2309; If we call `receive()` with a unlimited timeout, it will get stuck for long time, because `_pong_not_received` has no way to awake a block coroutine `receive()`. This PR add a `asyncio.Event` to awake this coroutine. --- CHANGES/2309.bugfix | 1 + aiohttp/client_ws.py | 27 +++++++++++++++++++- tests/test_client_ws_functional.py | 40 ++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 CHANGES/2309.bugfix diff --git a/CHANGES/2309.bugfix b/CHANGES/2309.bugfix new file mode 100644 index 00000000000..fe50d2db362 --- /dev/null +++ b/CHANGES/2309.bugfix @@ -0,0 +1 @@ +Fix the `ClientWebSocketResponse.receive` stuck for a long time if network is suddenly interrupted. diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 17690f2a076..1c268b45ffb 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -78,6 +78,8 @@ def __init__( self._compress = compress self._client_notakeover = client_notakeover + # A flag to indicate whether the latest heartbeat failed. + self._is_heartbeat_failed = asyncio.Event() self._reset_heartbeat() def _cancel_heartbeat(self) -> None: @@ -91,6 +93,7 @@ def _cancel_heartbeat(self) -> None: def _reset_heartbeat(self) -> None: self._cancel_heartbeat() + self._is_heartbeat_failed.clear() if self._heartbeat is not None: self._heartbeat_cb = call_later( @@ -116,6 +119,7 @@ def _pong_not_received(self) -> None: self._close_code = WSCloseCode.ABNORMAL_CLOSURE self._exception = asyncio.TimeoutError() self._response.close() + self._is_heartbeat_failed.set() @property def closed(self) -> bool: @@ -235,11 +239,32 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: try: self._waiting = self._loop.create_future() + + read_task: asyncio.Task[WSMessage] = asyncio.create_task( + self._reader.read() + ) + is_heartbeat_failed_task: asyncio.Task[bool] = asyncio.create_task( + self._is_heartbeat_failed.wait() + ) try: async with async_timeout.timeout( timeout or self._timeout.ws_receive ): - msg = await self._reader.read() + # Check the heartbeat status when waiting data from server + done, pending = await asyncio.wait( + (read_task, is_heartbeat_failed_task), + return_when=asyncio.FIRST_COMPLETED, + ) + # If server doesn't pong, but return data normally, + # supress the exception. + if read_task in done: + is_heartbeat_failed_task.cancel() + msg = read_task.result() + elif is_heartbeat_failed_task in done: + read_task.cancel() + assert self._exception is not None + raise self._exception + self._reset_heartbeat() finally: waiter = self._waiting diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index ffd140c9447..a6975688e7b 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,6 +1,8 @@ # type: ignore import asyncio +import time from typing import Any +from unittest import mock import async_timeout import pytest @@ -619,6 +621,44 @@ async def handler(request): assert ping_received +async def test_heartbeat_network_down(aiohttp_client: Any) -> None: + pong_should_received = 0 + + async def handler(request): + nonlocal pong_should_received + ws = web.WebSocketResponse(autoping=False) + await ws.prepare(request) + start_at = time.time() + async for msg in ws: + if time.time() - start_at >= 0.2: + # Server is not responding + await asyncio.sleep(0.5) + break + if msg.type == aiohttp.WSMsgType.PING: + await ws.pong() + pong_should_received += 1 + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + resp = await client.ws_connect("/", heartbeat=0.05, autoping=False) + + pong_received = 0 + # keep the connection + with mock.patch("aiohttp.streams.DataQueue.feed_eof"): + with pytest.raises(asyncio.TimeoutError): + async for msg in resp: + if msg.type == aiohttp.WSMsgType.PONG: + pong_received += 1 + continue + + assert pong_should_received == pong_received + await resp.close() + + async def test_send_recv_compress(aiohttp_client: Any) -> None: async def handler(request): ws = web.WebSocketResponse()