Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cancellations being swallowed #9030

Merged
merged 17 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/9030.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed (on Python 3.11+) some edge cases where a task cancellation may get incorrectly suppressed -- by :user:`Dreamsorcerer`.
37 changes: 29 additions & 8 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,11 +557,8 @@
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
try:
await writer.drain()
await self._continue
except asyncio.CancelledError:
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
return
await writer.drain()
await self._continue
Dismissed Show dismissed Hide dismissed

protocol = conn.protocol
assert protocol is not None
Expand Down Expand Up @@ -590,6 +587,7 @@
except asyncio.CancelledError:
# Body hasn't been fully sent, so connection can't be reused.
conn.close()
raise
except Exception as underlying_exc:
set_exception(
protocol,
Expand Down Expand Up @@ -696,8 +694,15 @@

async def close(self) -> None:
if self._writer is not None:
with contextlib.suppress(asyncio.CancelledError):
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

def terminate(self) -> None:
if self._writer is not None:
Expand Down Expand Up @@ -1040,7 +1045,15 @@

async def _wait_released(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

Check warning on line 1056 in aiohttp/client_reqrep.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/client_reqrep.py#L1056

Added line #L1056 was not covered by tests
self._release_connection()

def _cleanup_writer(self) -> None:
Expand All @@ -1057,7 +1070,15 @@

async def wait_for_close(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

Check warning on line 1081 in aiohttp/client_reqrep.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/client_reqrep.py#L1081

Added line #L1081 was not covered by tests
self.release()

async def read(self) -> bytes:
Expand Down
38 changes: 29 additions & 9 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,32 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
# down while the handler is still processing a request
# to avoid creating a future for every request.
self._handler_waiter = self._loop.create_future()
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
await self._handler_waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._handler_waiter = None
if (
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# Then cancel handler and wait
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

if self._task_handler is not None and not self._task_handler.done():
await self._task_handler
await asyncio.shield(self._task_handler)
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

# force-close non-idle handler
if self._task_handler is not None:
Expand Down Expand Up @@ -534,8 +549,6 @@ async def start(self) -> None:
# wait for next request
self._waiter = loop.create_future()
await self._waiter
except asyncio.CancelledError:
break
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
finally:
self._waiter = None

Expand All @@ -562,7 +575,7 @@ async def start(self) -> None:
task = loop.create_task(coro)
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
except ConnectionError:
self.log_debug("Ignored premature client disconnection")
break

Expand All @@ -588,12 +601,19 @@ async def start(self) -> None:
now = loop.time()
end_t = now + lingering_time

with suppress(asyncio.TimeoutError, asyncio.CancelledError):
try:
while not payload.is_eof() and now < end_t:
async with ceil_timeout(end_t - now):
# read and ignore
await payload.readany()
now = loop.time()
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task())
and t.cancelling()
):
raise

# if payload still uncompleted
if not payload.is_eof() and not self._force_close:
Expand All @@ -603,8 +623,8 @@ async def start(self) -> None:
payload.set_exception(_PAYLOAD_ACCESS_ERROR)

except asyncio.CancelledError:
self.log_debug("Ignored premature client disconnection ")
break
self.log_debug("Ignored premature client disconnection")
raise
except Exception as exc:
self.log_exception("Unhandled exception", exc_info=exc)
self.force_close()
Expand Down
17 changes: 17 additions & 0 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import io
import pathlib
import sys
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Protocol
Expand Down Expand Up @@ -1234,6 +1235,22 @@ async def test_oserror_on_write_bytes(
assert isinstance(exc, aiohttp.ClientOSError)


@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_close(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
req._writer = asyncio.Future() # type: ignore[assignment]

t = asyncio.create_task(req.close())

# Start waiting on _writer
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed.
with pytest.raises(asyncio.CancelledError):
await t
Dismissed Show dismissed Hide dismissed


async def test_terminate(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)

Expand Down
35 changes: 35 additions & 0 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import pathlib
import socket
import sys
import zlib
from typing import AsyncIterator, Awaitable, Callable, Dict, List, NoReturn, Optional
from unittest import mock
Expand Down Expand Up @@ -195,6 +196,40 @@ async def handler(request: web.Request) -> web.Response:
resp.release()


@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_shutdown(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
t = asyncio.create_task(request.protocol.shutdown())
# Ensure it's started waiting
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed
with pytest.raises(asyncio.CancelledError):
await t
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Dismissed Show dismissed Hide dismissed

# Repeat for second waiter in shutdown()
with mock.patch.object(request.protocol, "_request_in_progress", False):
with mock.patch.object(request.protocol, "_current_request", None):
t = asyncio.create_task(request.protocol.shutdown())
await asyncio.sleep(0)

t.cancel()
with pytest.raises(asyncio.CancelledError):
await t
Dismissed Show dismissed Hide dismissed

return web.Response(body=b"OK")

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

async with client.get("/") as resp:
assert resp.status == 200
txt = await resp.text()
assert txt == "OK"


async def test_post_form(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
data = await request.post()
Expand Down
Loading