Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cancellations being swallowed #9030

Merged
merged 17 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/9030.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed (on Python 3.11+) some edge cases where a task cancellation may get incorrectly suppressed -- by :user:`Dreamsorcerer`.
37 changes: 29 additions & 8 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,11 +561,8 @@
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
try:
await writer.drain()
await self._continue
except asyncio.CancelledError:
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
return
await writer.drain()
await self._continue
Dismissed Show dismissed Hide dismissed

protocol = conn.protocol
assert protocol is not None
Expand Down Expand Up @@ -594,6 +591,7 @@
except asyncio.CancelledError:
# Body hasn't been fully sent, so connection can't be reused.
conn.close()
raise
except Exception as underlying_exc:
set_exception(
protocol,
Expand Down Expand Up @@ -696,8 +694,15 @@

async def close(self) -> None:
if self._writer is not None:
with contextlib.suppress(asyncio.CancelledError):
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

def terminate(self) -> None:
if self._writer is not None:
Expand Down Expand Up @@ -1040,7 +1045,15 @@

async def _wait_released(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
self._release_connection()

def _cleanup_writer(self) -> None:
Expand All @@ -1057,7 +1070,15 @@

async def wait_for_close(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
self.release()

async def read(self) -> bytes:
Expand Down
36 changes: 27 additions & 9 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,17 +285,31 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:

# Wait for graceful handler completion
if self._handler_waiter is not None:
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
await self._handler_waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# Then cancel handler and wait
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

if self._task_handler is not None and not self._task_handler.done():
await self._task_handler
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

# force-close non-idle handler
if self._task_handler is not None:
Expand Down Expand Up @@ -523,8 +537,6 @@ async def start(self) -> None:
# wait for next request
self._waiter = loop.create_future()
await self._waiter
except asyncio.CancelledError:
break
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
finally:
self._waiter = None

Expand All @@ -551,7 +563,7 @@ async def start(self) -> None:
task = loop.create_task(coro)
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
except ConnectionError:
self.log_debug("Ignored premature client disconnection")
break

Expand All @@ -577,23 +589,29 @@ async def start(self) -> None:
now = loop.time()
end_t = now + lingering_time

with suppress(asyncio.TimeoutError, asyncio.CancelledError):
try:
while not payload.is_eof() and now < end_t:
async with ceil_timeout(end_t - now):
# read and ignore
await payload.readany()
now = loop.time()
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task())
and t.cancelling()
):
raise

# if payload still uncompleted
if not payload.is_eof() and not self._force_close:
self.log_debug("Uncompleted request.")
self.close()

set_exception(payload, PayloadAccessError())

except asyncio.CancelledError:
self.log_debug("Ignored premature client disconnection ")
break
self.log_debug("Ignored premature client disconnection")
raise
except Exception as exc:
self.log_exception("Unhandled exception", exc_info=exc)
self.force_close()
Expand Down
17 changes: 17 additions & 0 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import io
import pathlib
import sys
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Protocol
Expand Down Expand Up @@ -1234,6 +1235,22 @@
assert isinstance(exc, aiohttp.ClientOSError)


@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_close(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
req._writer = asyncio.Future() # type: ignore[assignment]

t = asyncio.create_task(req.close())

# Start waiting on _writer
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed.
with pytest.raises(asyncio.CancelledError):
await t
Dismissed Show dismissed Hide dismissed


async def test_terminate(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)

Expand Down
40 changes: 40 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import pathlib
import socket
import sys
import zlib
from typing import Any, NoReturn, Optional
from unittest import mock
Expand Down Expand Up @@ -188,6 +189,45 @@
resp.release()


@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_shutdown(aiohttp_client: Any) -> None:
async def handler(request):
t = asyncio.create_task(request.protocol.shutdown())
# Ensure it's started waiting
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed
with pytest.raises(asyncio.CancelledError):
await t
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Dismissed Show dismissed Hide dismissed

# Repeat for second waiter in shutdown()
with mock.patch.object(request.protocol, "_handler_waiter", None):
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
with mock.patch.object(
request.protocol,
"_current_request",
autospec=True,
spec_set=True,
):
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
t = asyncio.create_task(request.protocol.shutdown())
await asyncio.sleep(0)

t.cancel()
with pytest.raises(asyncio.CancelledError):
await t
Dismissed Show dismissed Hide dismissed

return web.Response(body=b"OK")

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

async with client.get("/") as resp:
assert resp.status == 200
txt = await resp.text()
assert txt == "OK"


async def test_post_form(aiohttp_client: Any) -> None:
async def handler(request):
data = await request.post()
Expand Down
Loading