diff --git a/CHANGES/8632.bugfix.rst b/CHANGES/8632.bugfix.rst new file mode 100644 index 00000000000..c6da81d7ab3 --- /dev/null +++ b/CHANGES/8632.bugfix.rst @@ -0,0 +1 @@ +Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`. diff --git a/aiohttp/client.py b/aiohttp/client.py index 1d4ccc0814a..3d1045f355a 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -75,6 +75,7 @@ ) from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse from .connector import ( + HTTP_AND_EMPTY_SCHEMA_SET, BaseConnector as BaseConnector, NamedPipeConnector as NamedPipeConnector, TCPConnector as TCPConnector, @@ -209,9 +210,6 @@ class ClientTimeout: # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) -HTTP_SCHEMA_SET = frozenset({"http", "https", ""}) -WS_SCHEMA_SET = frozenset({"ws", "wss"}) -ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET _RetType = TypeVar("_RetType") _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -517,7 +515,8 @@ async def _request( except ValueError as e: raise InvalidUrlClientError(str_or_url) from e - if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET: + assert self._connector is not None + if url.scheme not in self._connector.allowed_protocol_schema_set: raise NonHttpUrlClientError(url) skip_headers = set(self._skip_auto_headers) @@ -655,7 +654,6 @@ async def _request( real_timeout.connect, ceil_threshold=real_timeout.ceil_threshold, ): - assert self._connector is not None conn = await self._connector.connect( req, traces=traces, timeout=real_timeout ) @@ -752,7 +750,7 @@ async def _request( ) from e scheme = parsed_redirect_url.scheme - if scheme not in HTTP_SCHEMA_SET: + if scheme not in HTTP_AND_EMPTY_SCHEMA_SET: resp.close() raise NonHttpUrlRedirectClientError(r_url) elif not scheme: diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 2e07395aece..d4691b10e6e 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -63,6 +63,14 @@ SSLContext = object # type: ignore[misc,assignment] +EMPTY_SCHEMA_SET = frozenset({""}) +HTTP_SCHEMA_SET = frozenset({"http", "https"}) +WS_SCHEMA_SET = frozenset({"ws", "wss"}) + +HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET +HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET + + __all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") @@ -211,6 +219,8 @@ class BaseConnector: # abort transport after 2 seconds (cleanup broken connections) _cleanup_closed_period = 2.0 + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET + def __init__( self, *, @@ -760,6 +770,8 @@ class TCPConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) + def __init__( self, *, @@ -1458,6 +1470,8 @@ class UnixConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"}) + def __init__( self, path: str, @@ -1514,6 +1528,8 @@ class NamedPipeConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"}) + def __init__( self, path: str, diff --git a/tests/test_client_session.py b/tests/test_client_session.py index a522094a287..051c0aeba24 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -4,7 +4,7 @@ import io import json from http.cookies import SimpleCookie -from typing import Any, List +from typing import Any, Awaitable, Callable, List from unittest import mock from uuid import uuid4 @@ -16,10 +16,12 @@ import aiohttp from aiohttp import client, hdrs, web from aiohttp.client import ClientSession +from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequest -from aiohttp.connector import BaseConnector, TCPConnector +from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector from aiohttp.helpers import DEBUG from aiohttp.test_utils import make_mocked_coro +from aiohttp.tracing import Trace @pytest.fixture @@ -487,15 +489,17 @@ async def test_ws_connect_allowed_protocols( hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - resp.url = URL(f"{protocol}://example.com") + resp.url = URL(f"{protocol}://example") resp.cookies = SimpleCookie() resp.start = mock.AsyncMock() req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) req.send = mock.AsyncMock(return_value=resp) + # BaseConnector allows all high level protocols by default + connector = BaseConnector() - session = await create_session(request_class=req_factory) + session = await create_session(connector=connector, request_class=req_factory) connections = [] original_connect = session._connector.connect @@ -515,7 +519,68 @@ async def create_connection(req, traces, timeout): "aiohttp.client.os" ) as m_os: m_os.urandom.return_value = key_data - await session.ws_connect(f"{protocol}://example.com") + await session.ws_connect(f"{protocol}://example") + + # normally called during garbage collection. triggers an exception + # if the connection wasn't already closed + for c in connections: + c.close() + c.__del__() + + await session.close() + + +@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"]) +async def test_ws_connect_unix_socket_allowed_protocols( + create_session: Callable[..., Awaitable[ClientSession]], + create_mocked_conn: Callable[[], ResponseHandler], + protocol: str, + ws_key: bytes, + key_data: bytes, +) -> None: + resp = mock.create_autospec(aiohttp.ClientResponse) + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + resp.url = URL(f"{protocol}://example") + resp.cookies = SimpleCookie() + resp.start = mock.AsyncMock() + + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req_factory = mock.Mock(return_value=req) + req.send = mock.AsyncMock(return_value=resp) + # UnixConnector allows all high level protocols by default and unix sockets + session = await create_session( + connector=UnixConnector(path=""), request_class=req_factory + ) + + connections = [] + assert session._connector is not None + original_connect = session._connector.connect + + async def connect( + req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout + ) -> Connection: + conn = await original_connect(req, traces, timeout) + connections.append(conn) + return conn + + async def create_connection( + req: object, traces: object, timeout: object + ) -> ResponseHandler: + return create_mocked_conn() + + connector = session._connector + with mock.patch.object(connector, "connect", connect), mock.patch.object( + connector, "_create_connection", create_connection + ), mock.patch.object(connector, "_release"), mock.patch( + "aiohttp.client.os" + ) as m_os: + m_os.urandom.return_value = key_data + await session.ws_connect(f"{protocol}://example") # normally called during garbage collection. triggers an exception # if the connection wasn't already closed diff --git a/tests/test_connector.py b/tests/test_connector.py index 2065adf7414..d146fb4ee51 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1481,7 +1481,19 @@ async def test_tcp_connector_ctor() -> None: assert conn.family == 0 -async def test_tcp_connector_ctor_fingerprint_valid(loop) -> None: +async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None: + conn = aiohttp.TCPConnector() + assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"} + + +async def test_invalid_ssl_param() -> None: + with pytest.raises(TypeError): + aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type] + + +async def test_tcp_connector_ctor_fingerprint_valid( + loop: asyncio.AbstractEventLoop, +) -> None: valid = aiohttp.Fingerprint(hashlib.sha256(b"foo").digest()) conn = aiohttp.TCPConnector(ssl=valid, loop=loop) assert conn._ssl is valid @@ -1639,8 +1651,23 @@ async def test_ctor_with_default_loop(loop) -> None: assert loop is conn._loop -async def test_connect_with_limit(loop, key) -> None: - proto = mock.Mock() +async def test_base_connector_allows_high_level_protocols( + loop: asyncio.AbstractEventLoop, +) -> None: + conn = aiohttp.BaseConnector() + assert conn.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + } + + +async def test_connect_with_limit( + loop: asyncio.AbstractEventLoop, key: ConnectionKey +) -> None: + proto = create_mocked_conn(loop) proto.is_connected.return_value = True req = ClientRequest( @@ -2412,6 +2439,14 @@ async def handler(request): connector = aiohttp.UnixConnector(unix_sockname) assert unix_sockname == connector.path + assert connector.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + "unix", + } session = client.ClientSession(connector=connector) r = await session.get(url) @@ -2437,6 +2472,14 @@ async def handler(request): connector = aiohttp.NamedPipeConnector(pipe_name) assert pipe_name == connector.path + assert connector.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + "npipe", + } session = client.ClientSession(connector=connector) r = await session.get(url)