Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shutdown logic: Only wait on handlers #8495

Merged
merged 8 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,23 +300,6 @@ async def _run_app(
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
async def wait(
starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float
) -> None:
# Wait for pending tasks for a given time limit.
t = asyncio.current_task()
assert t is not None
starting_tasks.add(t)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout)

async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
t = asyncio.current_task()
assert t is not None
exclude.add(t)
while tasks := asyncio.all_tasks().difference(exclude):
await asyncio.wait(tasks)

# An internal function to actually do all dirty job for application running
if asyncio.iscoroutine(app):
app = await app
Expand All @@ -335,12 +318,6 @@ async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None:
)

await runner.setup()
# On shutdown we want to avoid waiting on tasks which run forever.
# It's very likely that all tasks which run forever will have been created by
# the time we have completed the application startup (in runner.setup()),
# so we just record all running tasks here and exclude them later.
starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks())
runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout)

sites: List[BaseSite] = []

Expand Down
8 changes: 6 additions & 2 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,12 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
if self._waiter:
self._waiter.cancel()

# wait for handlers
# Wait for graceful disconnection
if self._current_request is not None:
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with ceil_timeout(timeout):
await self._current_request.wait_for_disconnection()
# Then cancel handler and wait
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with ceil_timeout(timeout):
if self._current_request is not None:
Expand Down Expand Up @@ -461,7 +466,6 @@ async def _handle_request(
start_time: float,
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
try:
self._current_request = request
Expand Down
14 changes: 1 addition & 13 deletions aiohttp/web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,7 @@ async def start(self) -> None:


class BaseRunner(ABC):
__slots__ = (
"shutdown_callback",
"_handle_signals",
"_kwargs",
"_server",
"_sites",
"_shutdown_timeout",
)
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")

def __init__(
self,
Expand All @@ -246,7 +239,6 @@ def __init__(
shutdown_timeout: float = 60.0,
**kwargs: Any,
) -> None:
self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server: Optional[Server] = None
Expand Down Expand Up @@ -304,10 +296,6 @@ async def cleanup(self) -> None:
await asyncio.sleep(0)
self._server.pre_shutdown()
await self.shutdown()

if self.shutdown_callback:
await self.shutdown_callback()

await self._server.shutdown(self._shutdown_timeout)
await self._cleanup_server()

Expand Down
8 changes: 7 additions & 1 deletion aiohttp/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def connection_lost(
self, handler: RequestHandler, exc: Optional[BaseException] = None
) -> None:
if handler in self._connections:
del self._connections[handler]
if handler._task_handler:
handler._task_handler.add_done_callback(
lambda f: self._connections.pop(handler, None)
)
else:
del self._connections[handler]

def _make_request(
self,
Expand All @@ -69,6 +74,7 @@ def pre_shutdown(self) -> None:
async def shutdown(self, timeout: Optional[float] = None) -> None:
coros = (conn.shutdown(timeout) for conn in self._connections)
await asyncio.gather(*coros)
print("LENGTH", len(self._connections))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to assert the number of connections at this point, but can't think of a good way to get this in the tests. Any ideas?

Copy link
Member

@bdraco bdraco Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe patch the shutdown function and see how many calls are made to it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that works. The number of shutdown() calls doesn't matter, it's the number of them which didn't result in the connections being removed from the set as part of the done callback.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if I can figure out how to get a reference to the Server object, then I could patch the .clear() call to check how many are still in the set when this is called.

self._connections.clear()

def __call__(self) -> RequestHandler:
Expand Down
43 changes: 10 additions & 33 deletions tests/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest

from aiohttp import ClientConnectorError, ClientSession, WSCloseCode, web
from aiohttp import ClientConnectorError, ClientSession, ClientTimeout, WSCloseCode, web
from aiohttp.test_utils import make_mocked_coro
from aiohttp.web_runner import BaseRunner

Expand Down Expand Up @@ -935,8 +935,12 @@
async with ClientSession() as sess:
for _ in range(5): # pragma: no cover
try:
async with sess.get(f"http://localhost:{port}/"):
pass
with pytest.raises(asyncio.TimeoutError):
async with sess.get(
f"http://localhost:{port}/",
timeout=ClientTimeout(total=0.1),
):
pass
except ClientConnectorError:
await asyncio.sleep(0.5)
else:
Expand All @@ -956,6 +960,7 @@
async def handler(request: web.Request) -> web.Response:
nonlocal t
t = asyncio.create_task(task())
await t
Dismissed Show dismissed Hide dismissed
return web.Response(text="FOO")

t = test_task = None
Expand All @@ -968,7 +973,7 @@
assert test_task.exception() is None
return t

def test_shutdown_wait_for_task(
def test_shutdown_wait_for_handler(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
Expand All @@ -985,7 +990,7 @@
assert t.done()
assert not t.cancelled()

def test_shutdown_timeout_task(
def test_shutdown_timeout_handler(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
Expand All @@ -1002,34 +1007,6 @@
assert t.done()
assert t.cancelled()

def test_shutdown_wait_for_spawned_task(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
port = aiohttp_unused_port()
finished = False
finished_sub = False
sub_t = None

async def sub_task():
nonlocal finished_sub
await asyncio.sleep(1.5)
finished_sub = True

async def task():
nonlocal finished, sub_t
await asyncio.sleep(0.5)
sub_t = asyncio.create_task(sub_task())
finished = True

t = self.run_app(port, 3, task)

assert finished is True
assert t.done()
assert not t.cancelled()
assert finished_sub is True
assert sub_t.done()
assert not sub_t.cancelled()

def test_shutdown_timeout_not_reached(
self, aiohttp_unused_port: Callable[[], int]
) -> None:
Expand Down
Loading