From 3b638872f4896b191faee9ed8c37f1ad1fe6d440 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 10 Feb 2023 14:11:08 -0600 Subject: [PATCH 1/5] Fix ConnectionResetError not being raised when the transport is closed (#7180) `ConnectionResetError` will always be raised when `StreamWriter.write` is called after `connection_lost` has been called on the `BaseProtocol` Restores pre 3.8.3 behavior fixes #7172 - [x] I think the code is well written - [x] Unit tests for the changes exist - [x] Documentation reflects the changes - [x] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is <Name> <Surname>. * Please keep alphabetical order, the file is sorted by names. - [x] Add a new news fragment into the `CHANGES` folder * name it `.` for example (588.bugfix) * if you don't have an `issue_id` change it to the pr id after creating the pr * ensure type is one of the following: * `.feature`: Signifying a new feature. * `.bugfix`: Signifying a bug fix. * `.doc`: Signifying a documentation improvement. * `.removal`: Signifying a deprecation or removal of public API. * `.misc`: A ticket has been closed, but it is not of interest to users. * Make sure to use full sentences with correct case and punctuation, for example: "Fix issue with non-ascii contents in doctest text files." --------- Co-authored-by: Sviatoslav Sydorenko Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sam Bull (cherry picked from commit 974323f63ba03bf720658cd738fc8722182322ae) --- CHANGES/7180.bugfix | 1 + aiohttp/base_protocol.py | 9 ++++++--- aiohttp/http_writer.py | 10 ++++------ tests/test_base_protocol.py | 8 ++++---- tests/test_client_proto.py | 14 ++++++++++++++ tests/test_http_writer.py | 17 +++++++++++++++++ 6 files changed, 46 insertions(+), 13 deletions(-) create mode 100644 CHANGES/7180.bugfix diff --git a/CHANGES/7180.bugfix b/CHANGES/7180.bugfix new file mode 100644 index 00000000000..66980638868 --- /dev/null +++ b/CHANGES/7180.bugfix @@ -0,0 +1 @@ +``ConnectionResetError`` will always be raised when ``StreamWriter.write`` is called after ``connection_lost`` has been called on the ``BaseProtocol`` diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index 8189835e211..4c9f0a752e3 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -18,11 +18,15 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._connection_lost = False self._reading_paused = False self.transport: Optional[asyncio.Transport] = None + @property + def connected(self) -> bool: + """Return True if the connection is open.""" + return self.transport is not None + def pause_writing(self) -> None: assert not self._paused self._paused = True @@ -59,7 +63,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = tr def connection_lost(self, exc: Optional[BaseException]) -> None: - self._connection_lost = True # Wake up the writer if currently paused. self.transport = None if not self._paused: @@ -76,7 +79,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: waiter.set_exception(exc) async def _drain_helper(self) -> None: - if self._connection_lost: + if not self.connected: raise ConnectionResetError("Connection lost") if not self._paused: return diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index db3d6a04897..73f0f96f0ae 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -35,7 +35,6 @@ def __init__( on_headers_sent: _T_OnHeadersSent = None, ) -> None: self._protocol = protocol - self._transport = protocol.transport self.loop = loop self.length = None @@ -52,7 +51,7 @@ def __init__( @property def transport(self) -> Optional[asyncio.Transport]: - return self._transport + return self._protocol.transport @property def protocol(self) -> BaseProtocol: @@ -71,10 +70,10 @@ def _write(self, chunk: bytes) -> None: size = len(chunk) self.buffer_size += size self.output_size += size - - if self._transport is None or self._transport.is_closing(): + transport = self.transport + if not self._protocol.connected or transport is None or transport.is_closing(): raise ConnectionResetError("Cannot write to closing transport") - self._transport.write(chunk) + transport.write(chunk) async def write( self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000 @@ -159,7 +158,6 @@ async def write_eof(self, chunk: bytes = b"") -> None: await self.drain() self._eof = True - self._transport = None async def drain(self) -> None: """Flush the write buffer. diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py index f3b966bff54..a16b1f10cb1 100644 --- a/tests/test_base_protocol.py +++ b/tests/test_base_protocol.py @@ -45,10 +45,10 @@ async def test_connection_lost_not_paused() -> None: pr = BaseProtocol(loop=loop) tr = mock.Mock() pr.connection_made(tr) - assert not pr._connection_lost + assert pr.connected pr.connection_lost(None) assert pr.transport is None - assert pr._connection_lost + assert not pr.connected async def test_connection_lost_paused_without_waiter() -> None: @@ -56,11 +56,11 @@ async def test_connection_lost_paused_without_waiter() -> None: pr = BaseProtocol(loop=loop) tr = mock.Mock() pr.connection_made(tr) - assert not pr._connection_lost + assert pr.connected pr.pause_writing() pr.connection_lost(None) assert pr.transport is None - assert pr._connection_lost + assert not pr.connected async def test_drain_lost() -> None: diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 85225c77dad..6be01c1d6f4 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -134,3 +134,17 @@ async def test_eof_received(loop) -> None: assert proto._read_timeout_handle is not None proto.eof_received() assert proto._read_timeout_handle is None + + +async def test_connection_lost_sets_transport_to_none(loop: Any, mocker: Any) -> None: + """Ensure that the transport is set to None when the connection is lost. + + This ensures the writer knows that the connection is closed. + """ + proto = ResponseHandler(loop=loop) + proto.connection_made(mocker.Mock()) + assert proto.transport is not None + + proto.connection_lost(OSError()) + + assert proto.transport is None diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 8ebcfc654a5..77b1c1b1452 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -236,6 +236,23 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None: await msg.write(b"After closing") +async def test_write_to_closed_transport( + protocol: Any, transport: Any, loop: Any +) -> None: + """Test that writing to a closed transport raises ConnectionResetError. + + The StreamWriter checks to see if protocol.transport is None before + writing to the transport. If it is None, it raises ConnectionResetError. + """ + msg = http.StreamWriter(protocol, loop) + + await msg.write(b"Before transport close") + protocol.transport = None + + with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"): + await msg.write(b"After transport closed") + + async def test_drain(protocol, transport, loop) -> None: msg = http.StreamWriter(protocol, loop) await msg.drain() From 67cfcb15724d1b0e12c3f70963430987c6822ccb Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 10 Feb 2023 20:21:16 +0000 Subject: [PATCH 2/5] Update test_http_writer.py --- tests/test_http_writer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 77b1c1b1452..5649f32f792 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -236,9 +236,7 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None: await msg.write(b"After closing") -async def test_write_to_closed_transport( - protocol: Any, transport: Any, loop: Any -) -> None: +async def test_write_to_closed_transport(protocol, transport, loop) -> None: """Test that writing to a closed transport raises ConnectionResetError. The StreamWriter checks to see if protocol.transport is None before From 541ef5882dad96488b578e9a969da0bea4f165fd Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 10 Feb 2023 20:21:31 +0000 Subject: [PATCH 3/5] Update test_client_proto.py --- tests/test_client_proto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 6be01c1d6f4..eea2830246a 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -136,7 +136,7 @@ async def test_eof_received(loop) -> None: assert proto._read_timeout_handle is None -async def test_connection_lost_sets_transport_to_none(loop: Any, mocker: Any) -> None: +async def test_connection_lost_sets_transport_to_none(loop, mocker) -> None: """Ensure that the transport is set to None when the connection is lost. This ensures the writer knows that the connection is closed. From f0d4b51f100eee8028fbb0fffcd790f9acc6cec3 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 10 Feb 2023 22:55:48 +0000 Subject: [PATCH 4/5] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 05f4eb33ec4..10f6c357083 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -117,7 +117,7 @@ jobs: needs: gen_llhttp strategy: matrix: - pyver: [3.6, 3.7, 3.8, 3.9, '3.10'] + pyver: ['3.6.15', 3.7, 3.8, 3.9, '3.10'] no-extensions: ['', 'Y'] os: [ubuntu, macos, windows] exclude: From 32845f3bc0e8905581a5b4e785534f30e2a67f73 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 10 Feb 2023 23:02:02 +0000 Subject: [PATCH 5/5] Update ci.yml --- .github/workflows/ci.yml | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 10f6c357083..8e61bf0d64c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -117,38 +117,38 @@ jobs: needs: gen_llhttp strategy: matrix: - pyver: ['3.6.15', 3.7, 3.8, 3.9, '3.10'] + pyver: [3.6, 3.7, 3.8, 3.9, '3.10'] no-extensions: ['', 'Y'] - os: [ubuntu, macos, windows] + os: [ubuntu-20.04, macos-latest, windows-latest] exclude: - - os: macos + - os: macos-latest no-extensions: 'Y' - - os: macos + - os: macos-latest pyver: 3.7 - - os: macos + - os: macos-latest pyver: 3.8 - - os: windows + - os: windows-latest no-extensions: 'Y' experimental: [false] include: - pyver: pypy-3.8 no-extensions: 'Y' - os: ubuntu + os: ubuntu-latest experimental: false - - os: macos + - os: macos-latest pyver: "3.11.0-alpha - 3.11.0" experimental: true no-extensions: 'Y' - - os: ubuntu + - os: ubuntu-latest pyver: "3.11.0-alpha - 3.11.0" experimental: false no-extensions: 'Y' - - os: windows + - os: windows-latest pyver: "3.11.0-alpha - 3.11.0" experimental: true no-extensions: 'Y' fail-fast: true - runs-on: ${{ matrix.os }}-latest + runs-on: ${{ matrix.os }} continue-on-error: ${{ matrix.experimental }} steps: - name: Checkout