Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #8685/e7c02ca4 backport][3.11] Fix exceptions from WebSocket ping task not being consumed #8730

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/8685.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed unconsumed exceptions raised by the WebSocket heartbeat -- by :user:`bdraco`.

If the heartbeat ping raised an exception, it would not be consumed and would be logged as an warning.
25 changes: 16 additions & 9 deletions aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,28 @@ def _send_heartbeat(self) -> None:
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None

def _pong_not_received(self) -> None:
if not self._closed:
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = ServerTimeoutError()
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(
WSMessage(WSMsgType.ERROR, self._exception, None)
)
self._handle_ping_pong_exception(ServerTimeoutError())

def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))

def _set_closed(self) -> None:
"""Set the connection to closed.
Expand Down
18 changes: 15 additions & 3 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,28 @@ def _send_heartbeat(self) -> None:
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)

def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None

def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._set_closed()
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = asyncio.TimeoutError()
self._handle_ping_pong_exception(asyncio.TimeoutError())

def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = exc
if self._waiting and not self._closing and self._reader is not None:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))

def _set_closed(self) -> None:
"""Set the connection to closed.
Expand Down
30 changes: 30 additions & 0 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,36 @@ async def handler(request):
assert ping_received


async def test_heartbeat_connection_closed(aiohttp_client: AiohttpClient) -> None:
"""Test that the connection is closed while ping is in progress."""

async def handler(request: web.Request) -> NoReturn:
ws = web.WebSocketResponse(autoping=False)
await ws.prepare(request)
await ws.receive()
assert False

app = web.Application()
app.router.add_route("GET", "/", handler)

client = await aiohttp_client(app)
resp = await client.ws_connect("/", heartbeat=0.1)
ping_count = 0
# We patch write here to simulate a connection reset error
# since if we closed the connection normally, the client would
# would cancel the heartbeat task and we wouldn't get a ping
assert resp._conn is not None
with mock.patch.object(
resp._conn.transport, "write", side_effect=ConnectionResetError
), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping:
await resp.receive()
ping_count = ping.call_count
# Connection should be closed roughly after 1.5x heartbeat.
await asyncio.sleep(0.2)
assert ping_count == 1
assert resp.close_code is WSCloseCode.ABNORMAL_CLOSURE


async def test_heartbeat_no_pong(aiohttp_client: AiohttpClient) -> None:
"""Test that the connection is closed if no pong is received without sending messages."""
ping_received = False
Expand Down
73 changes: 72 additions & 1 deletion tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import contextlib
import sys
import weakref
from typing import Any, Optional
from typing import Any, NoReturn, Optional
from unittest import mock

import pytest

Expand Down Expand Up @@ -724,6 +725,76 @@ async def handler(request):
await ws.close()


async def test_heartbeat_connection_closed(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
"""Test that the connection is closed while ping is in progress."""
ping_count = 0

async def handler(request: web.Request) -> NoReturn:
nonlocal ping_count
ws_server = web.WebSocketResponse(heartbeat=0.05)
await ws_server.prepare(request)
# We patch write here to simulate a connection reset error
# since if we closed the connection normally, the server would
# would cancel the heartbeat task and we wouldn't get a ping
with mock.patch.object(
ws_server._req.transport, "write", side_effect=ConnectionResetError
), mock.patch.object(
ws_server._writer, "ping", wraps=ws_server._writer.ping
) as ping:
try:
await ws_server.receive()
finally:
ping_count = ping.call_count
assert False

app = web.Application()
app.router.add_get("/", handler)

client = await aiohttp_client(app)
ws = await client.ws_connect("/", autoping=False)
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED
assert msg.extra is None
assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE
assert ping_count == 1
await ws.close()


async def test_heartbeat_failure_ends_receive(
loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient
) -> None:
"""Test that no heartbeat response to the server ends the receive call."""
ws_server_close_code = None
ws_server_exception = None

async def handler(request: web.Request) -> NoReturn:
nonlocal ws_server_close_code, ws_server_exception
ws_server = web.WebSocketResponse(heartbeat=0.05)
await ws_server.prepare(request)
try:
await ws_server.receive()
finally:
ws_server_close_code = ws_server.close_code
ws_server_exception = ws_server.exception()
assert False

app = web.Application()
app.router.add_get("/", handler)

client = await aiohttp_client(app)
ws = await client.ws_connect("/", autoping=False)
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.PING
msg = await ws.receive()
assert msg.type is aiohttp.WSMsgType.CLOSED
assert ws.close_code == WSCloseCode.ABNORMAL_CLOSURE
assert ws_server_close_code == WSCloseCode.ABNORMAL_CLOSURE
assert isinstance(ws_server_exception, asyncio.TimeoutError)
await ws.close()


async def test_heartbeat_no_pong_send_many_messages(
loop: Any, aiohttp_client: Any
) -> None:
Expand Down
Loading