From cb94533dd50426809b7fcbb8bbad0ef17509de5c Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 12 Nov 2023 23:58:15 +0000 Subject: [PATCH] Ensure writer is always reset on completion (#7815) (#7826) (cherry picked from commit 8f2f048ed7b0a01630ba620c521f12c673a006e3) --- CHANGES/7815.bugfix | 1 + aiohttp/client_reqrep.py | 74 +++++++++++++++++++++++------------ tests/test_client_request.py | 20 ++++++++-- tests/test_client_response.py | 4 ++ tests/test_proxy.py | 18 ++++----- 5 files changed, 79 insertions(+), 38 deletions(-) create mode 100644 CHANGES/7815.bugfix diff --git a/CHANGES/7815.bugfix b/CHANGES/7815.bugfix new file mode 100644 index 00000000000..269c2680d0b --- /dev/null +++ b/CHANGES/7815.bugfix @@ -0,0 +1 @@ +Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer` diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 46f01a071f9..7f1f244639a 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -53,7 +53,13 @@ reify, set_result, ) -from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter +from .http import ( + SERVER_SOFTWARE, + HttpVersion, + HttpVersion10, + HttpVersion11, + StreamWriter, +) from .log import client_logger from .streams import StreamReader from .typedefs import ( @@ -241,7 +247,7 @@ class ClientRequest: auth = None response = None - _writer = None # async task for streaming data + __writer = None # async task for streaming data _continue = None # waiter future for '100 Continue' response # N.B. @@ -332,6 +338,21 @@ def __init__( traces = [] self._traces = traces + def __reset_writer(self, _: object = None) -> None: + self.__writer = None + + @property + def _writer(self) -> Optional["asyncio.Task[None]"]: + return self.__writer + + @_writer.setter + def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + if self.__writer is not None: + self.__writer.remove_done_callback(self.__reset_writer) + self.__writer = writer + if writer is not None: + writer.add_done_callback(self.__reset_writer) + def is_ssl(self) -> bool: return self.url.scheme in ("https", "wss") @@ -625,8 +646,6 @@ async def write_bytes( else: await writer.write_eof() protocol.start_timeout() - finally: - self._writer = None async def send(self, conn: "Connection") -> "ClientResponse": # Specify request target: @@ -711,16 +730,14 @@ async def send(self, conn: "Connection") -> "ClientResponse": async def close(self) -> None: if self._writer is not None: - try: - with contextlib.suppress(asyncio.CancelledError): - await self._writer - finally: - self._writer = None + with contextlib.suppress(asyncio.CancelledError): + await self._writer def terminate(self) -> None: if self._writer is not None: if not self.loop.is_closed(): self._writer.cancel() + self._writer.remove_done_callback(self.__reset_writer) self._writer = None async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None: @@ -740,9 +757,9 @@ class ClientResponse(HeadersMixin): # but will be set by the start() method. # As the end user will likely never see the None values, we cheat the types below. # from the Status-Line of the response - version = None # HTTP-Version - status: int = None # type: ignore[assignment] # Status-Code - reason = None # Reason-Phrase + version: Optional[HttpVersion] = None # HTTP-Version + status: int = None # type: ignore[assignment] # Status-Code + reason: Optional[str] = None # Reason-Phrase content: StreamReader = None # type: ignore[assignment] # Payload stream _headers: CIMultiDictProxy[str] = None # type: ignore[assignment] @@ -754,6 +771,7 @@ class ClientResponse(HeadersMixin): # post-init stage allows to not change ctor signature _closed = True # to allow __del__ for non-initialized properly response _released = False + __writer = None def __init__( self, @@ -799,6 +817,21 @@ def __init__( if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) + def __reset_writer(self, _: object = None) -> None: + self.__writer = None + + @property + def _writer(self) -> Optional["asyncio.Task[None]"]: + return self.__writer + + @_writer.setter + def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + if self.__writer is not None: + self.__writer.remove_done_callback(self.__reset_writer) + self.__writer = writer + if writer is not None: + writer.add_done_callback(self.__reset_writer) + @reify def url(self) -> URL: return self._url @@ -863,7 +896,7 @@ def __repr__(self) -> str: "ascii", "backslashreplace" ).decode("ascii") else: - ascii_encodable_reason = self.reason + ascii_encodable_reason = "None" print( "".format( ascii_encodable_url, self.status, ascii_encodable_reason @@ -1044,18 +1077,12 @@ def _release_connection(self) -> None: async def _wait_released(self) -> None: if self._writer is not None: - try: - await self._writer - finally: - self._writer = None + await self._writer self._release_connection() def _cleanup_writer(self) -> None: if self._writer is not None: - if self._writer.done(): - self._writer = None - else: - self._writer.cancel() + self._writer.cancel() self._session = None def _notify_content(self) -> None: @@ -1066,10 +1093,7 @@ def _notify_content(self) -> None: async def wait_for_close(self) -> None: if self._writer is not None: - try: - await self._writer - finally: - self._writer = None + await self._writer self.release() async def read(self) -> bytes: diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 0794bfd601d..0f58d752de2 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -5,7 +5,7 @@ import urllib.parse import zlib from http.cookies import BaseCookie, Morsel, SimpleCookie -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional from unittest import mock import pytest @@ -24,6 +24,17 @@ from aiohttp.test_utils import make_mocked_coro +class WriterMock(mock.AsyncMock): + def __await__(self) -> None: + return self().__await__() + + def add_done_callback(self, cb: Callable[[], None]) -> None: + """Dummy method.""" + + def remove_done_callback(self, cb: Callable[[], None]) -> None: + """Dummy method.""" + + @pytest.fixture def make_request(loop): request = None @@ -1167,7 +1178,7 @@ def read(self, decode=False): async def test_oserror_on_write_bytes(loop, conn) -> None: req = ClientRequest("POST", URL("http://python.org/"), loop=loop) - writer = mock.Mock() + writer = WriterMock() writer.write.side_effect = OSError await req.write_bytes(writer, conn) @@ -1183,7 +1194,8 @@ async def test_terminate(loop, conn) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) resp = await req.send(conn) assert req._writer is not None - writer = req._writer = mock.Mock() + writer = req._writer = WriterMock() + writer.cancel = mock.Mock() req.terminate() assert req._writer is None @@ -1201,7 +1213,7 @@ async def go(): req = ClientRequest("get", URL("http://python.org")) resp = await req.send(conn) assert req._writer is not None - writer = req._writer = mock.Mock() + writer = req._writer = WriterMock() await asyncio.sleep(0.05) diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 74027fcaf76..166089cc84a 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -2,6 +2,7 @@ import gc import sys +from typing import Callable from unittest import mock import pytest @@ -19,6 +20,9 @@ class WriterMock(mock.AsyncMock): def __await__(self) -> None: return self().__await__() + def add_done_callback(self, cb: Callable[[], None]) -> None: + cb() + def done(self) -> bool: return True diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 47aff68c98f..1ff53e3f899 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -202,7 +202,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -264,7 +264,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -326,7 +326,7 @@ def test_https_connect(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -386,7 +386,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -440,7 +440,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -496,7 +496,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -555,7 +555,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -666,7 +666,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -737,7 +737,7 @@ def test_https_auth(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[],