Skip to content

Commit

Permalink
Fix: fix the ClientWebSocketResponse.receive stuck for a long time …
Browse files Browse the repository at this point in the history
…if network is suddenly interrupted.

Related to aio-libs#2309; If we call `receive()` with a unlimited timeout,
it will get stuck for long time, because `_pong_not_received` has no way
to awake a block coroutine `receive()`. This PR add a `asyncio.Event` to
awake this coroutine.
  • Loading branch information
Hanaasagi committed Jul 8, 2021
1 parent 7080a8b commit 2c390a0
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES/2309.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix the `ClientWebSocketResponse.receive` stuck for a long time if network is suddenly interrupted.
27 changes: 26 additions & 1 deletion aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def __init__(
self._compress = compress
self._client_notakeover = client_notakeover

# A flag to indicate whether the latest heartbeat failed.
self._is_heartbeat_failed = asyncio.Event()
self._reset_heartbeat()

def _cancel_heartbeat(self) -> None:
Expand All @@ -91,6 +93,7 @@ def _cancel_heartbeat(self) -> None:

def _reset_heartbeat(self) -> None:
self._cancel_heartbeat()
self._is_heartbeat_failed.clear()

if self._heartbeat is not None:
self._heartbeat_cb = call_later(
Expand All @@ -116,6 +119,7 @@ def _pong_not_received(self) -> None:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = asyncio.TimeoutError()
self._response.close()
self._is_heartbeat_failed.set()

@property
def closed(self) -> bool:
Expand Down Expand Up @@ -235,11 +239,32 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:

try:
self._waiting = self._loop.create_future()

read_task: asyncio.Task[WSMessage] = asyncio.create_task(
self._reader.read()
)
is_heartbeat_failed_task: asyncio.Task[bool] = asyncio.create_task(
self._is_heartbeat_failed.wait()
)
try:
async with async_timeout.timeout(
timeout or self._timeout.ws_receive
):
msg = await self._reader.read()
# Check the heartbeat status when waiting data from server
done, pending = await asyncio.wait(
(read_task, is_heartbeat_failed_task),
return_when=asyncio.FIRST_COMPLETED,
)
# If server doesn't pong, but return data normally,
# supress the exception.
if read_task in done:
is_heartbeat_failed_task.cancel()
msg = read_task.result()
elif is_heartbeat_failed_task in done:
read_task.cancel()
assert self._exception is not None
raise self._exception

self._reset_heartbeat()
finally:
waiter = self._waiting
Expand Down
40 changes: 40 additions & 0 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# type: ignore
import asyncio
import time
from typing import Any
from unittest import mock

import async_timeout
import pytest
Expand Down Expand Up @@ -619,6 +621,44 @@ async def handler(request):
assert ping_received


async def test_heartbeat_network_down(aiohttp_client: Any) -> None:
pong_should_received = 0

async def handler(request):
nonlocal pong_should_received
ws = web.WebSocketResponse(autoping=False)
await ws.prepare(request)
start_at = time.time()
async for msg in ws:
if time.time() - start_at >= 0.2:
# Server is not responding
await asyncio.sleep(0.5)
break
if msg.type == aiohttp.WSMsgType.PING:
await ws.pong()
pong_should_received += 1
await ws.close()
return ws

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)
resp = await client.ws_connect("/", heartbeat=0.05, autoping=False)

pong_received = 0
# keep the connection
with mock.patch("aiohttp.streams.DataQueue.feed_eof"):
with pytest.raises(asyncio.TimeoutError):
async for msg in resp:
if msg.type == aiohttp.WSMsgType.PONG:
pong_received += 1
continue

assert pong_should_received == pong_received
await resp.close()


async def test_send_recv_compress(aiohttp_client: Any) -> None:
async def handler(request):
ws = web.WebSocketResponse()
Expand Down

0 comments on commit 2c390a0

Please sign in to comment.