diff --git a/CHANGES/2958.misc b/CHANGES/2958.misc new file mode 100644 index 00000000000..bf75ef12dce --- /dev/null +++ b/CHANGES/2958.misc @@ -0,0 +1 @@ +Simplify StreamWriter constructor. \ No newline at end of file diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 249cacc2137..f1cebd4d97c 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -495,7 +495,7 @@ async def send(self, conn): path += '?' + self.url.raw_query_string writer = StreamWriter( - conn.protocol, conn.transport, self.loop, + conn.protocol, self.loop, on_chunk_sent=self._on_chunk_request_sent ) diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index bcb9b25cb32..879f1f47f59 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -17,9 +17,9 @@ class StreamWriter(AbstractStreamWriter): - def __init__(self, protocol, transport, loop, on_chunk_sent=None): + def __init__(self, protocol, loop, on_chunk_sent=None): self._protocol = protocol - self._transport = transport + self._transport = protocol.transport self.loop = loop self.length = None diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index fc54138df3c..c6c1ada9ac3 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -504,12 +504,13 @@ def make_mocked_request(method, path, headers=None, *, if app is None: app = _create_app_mock() - if protocol is sentinel: - protocol = mock.Mock() - if transport is sentinel: transport = _create_transport(sslcontext) + if protocol is sentinel: + protocol = mock.Mock() + protocol.transport = transport + if writer is sentinel: writer = mock.Mock() writer.write_headers = make_mocked_coro(None) diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 7475501a5c4..85eb6160a78 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -113,7 +113,6 @@ async def _sendfile_system(self, request, fobj, count): else: writer = SendfileStreamWriter( request.protocol, - transport, request.loop ) request._payload_writer = writer diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 008d8b114b3..32283e3ac7a 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -237,14 +237,14 @@ def data_received(self, data): # something happened during parsing self._error_handler = self._loop.create_task( self.handle_parse_error( - StreamWriter(self, self.transport, self._loop), + StreamWriter(self, self._loop), 400, exc, exc.message)) self.close() except Exception as exc: # 500: internal error self._error_handler = self._loop.create_task( self.handle_parse_error( - StreamWriter(self, self.transport, self._loop), + StreamWriter(self, self._loop), 500, exc)) self.close() else: @@ -377,7 +377,7 @@ async def start(self): now = loop.time() manager.requests_count += 1 - writer = StreamWriter(self, self.transport, loop) + writer = StreamWriter(self, loop) request = self._request_factory( message, payload, self, writer, handler) try: diff --git a/tests/test_client_request.py b/tests/test_client_request.py index c95419bbcbd..22374f03fdf 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -41,8 +41,9 @@ def buf(): @pytest.fixture -def protocol(loop): +def protocol(loop, transport): protocol = mock.Mock() + protocol.transport = transport protocol._drain_helper.return_value = loop.create_future() protocol._drain_helper.return_value.set_result(None) return protocol diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 7f9d17dea34..897da4c1ab4 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -34,14 +34,14 @@ def protocol(loop, transport): def test_payloadwriter_properties(transport, protocol, loop): - writer = http.StreamWriter(protocol, transport, loop) + writer = http.StreamWriter(protocol, loop) assert writer.protocol == protocol assert writer.transport == transport async def test_write_payload_eof(transport, protocol, loop): write = transport.write = mock.Mock() - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) await msg.write(b'data1') await msg.write(b'data2') @@ -52,7 +52,7 @@ async def test_write_payload_eof(transport, protocol, loop): async def test_write_payload_chunked(buf, protocol, transport, loop): - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b'data') await msg.write_eof() @@ -61,7 +61,7 @@ async def test_write_payload_chunked(buf, protocol, transport, loop): async def test_write_payload_chunked_multiple(buf, protocol, transport, loop): - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b'data1') await msg.write(b'data2') @@ -73,7 +73,7 @@ async def test_write_payload_chunked_multiple(buf, protocol, transport, loop): async def test_write_payload_length(protocol, transport, loop): write = transport.write = mock.Mock() - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.length = 2 await msg.write(b'd') await msg.write(b'ata') @@ -86,7 +86,7 @@ async def test_write_payload_length(protocol, transport, loop): async def test_write_payload_chunked_filter(protocol, transport, loop): write = transport.write = mock.Mock() - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b'da') await msg.write(b'ta') @@ -101,7 +101,7 @@ async def test_write_payload_chunked_filter_mutiple_chunks( transport, loop): write = transport.write = mock.Mock() - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.enable_chunking() await msg.write(b'da') await msg.write(b'ta') @@ -121,7 +121,7 @@ async def test_write_payload_chunked_filter_mutiple_chunks( async def test_write_payload_deflate_compression(protocol, transport, loop): write = transport.write = mock.Mock() - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.enable_compression('deflate') await msg.write(b'data') await msg.write_eof() @@ -137,7 +137,7 @@ async def test_write_payload_deflate_and_chunked( protocol, transport, loop): - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.enable_compression('deflate') msg.enable_chunking() @@ -149,7 +149,7 @@ async def test_write_payload_deflate_and_chunked( async def test_write_drain(protocol, transport, loop): - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg.drain = make_mocked_coro() await msg.write(b'1' * (64 * 1024 * 2), drain=False) assert not msg.drain.called @@ -162,7 +162,7 @@ async def test_write_drain(protocol, transport, loop): async def test_write_calls_callback(protocol, transport, loop): on_chunk_sent = make_mocked_coro() msg = http.StreamWriter( - protocol, transport, loop, + protocol, loop, on_chunk_sent=on_chunk_sent ) chunk = b'1' @@ -174,7 +174,7 @@ async def test_write_calls_callback(protocol, transport, loop): async def test_write_eof_calls_callback(protocol, transport, loop): on_chunk_sent = make_mocked_coro() msg = http.StreamWriter( - protocol, transport, loop, + protocol, loop, on_chunk_sent=on_chunk_sent ) chunk = b'1' @@ -184,7 +184,7 @@ async def test_write_eof_calls_callback(protocol, transport, loop): async def test_write_to_closing_transport(protocol, transport, loop): - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) await msg.write(b'Before closing') transport.is_closing.return_value = True @@ -194,13 +194,13 @@ async def test_write_to_closing_transport(protocol, transport, loop): async def test_drain(protocol, transport, loop): - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) await msg.drain() assert protocol._drain_helper.called async def test_drain_no_transport(protocol, transport, loop): - msg = http.StreamWriter(protocol, transport, loop) + msg = http.StreamWriter(protocol, loop) msg._protocol.transport = None await msg.drain() assert not protocol._drain_helper.called