diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 3f67dffb..81aa7be1 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -7,7 +7,7 @@ from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import AsyncEvent, AsyncLock +from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation from .connection import AsyncHTTPConnection from .interfaces import AsyncConnectionInterface, AsyncRequestInterface @@ -193,9 +193,10 @@ async def handle_async_request(self, request: Request) -> Response: pool_request = AsyncPoolRequest(request) try: while True: - async with self._pool_lock: - self._requests.append(pool_request) - closing = self._assign_requests_to_connections() + with AsyncShieldCancellation(): + async with self._pool_lock: + self._requests.append(pool_request) + closing = self._assign_requests_to_connections() await self._close_connections(closing) connection = await pool_request.wait_for_connection(timeout=timeout) @@ -209,9 +210,10 @@ async def handle_async_request(self, request: Request) -> Response: break except BaseException as exc: - async with self._pool_lock: - self._requests.remove(pool_request) - closing = self._assign_requests_to_connections() + with AsyncShieldCancellation(): + async with self._pool_lock: + self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() await self._close_connections(closing) raise exc from None @@ -225,17 +227,6 @@ async def handle_async_request(self, request: Request) -> Response: extensions=response.extensions, ) - async def _request_closed(self, request: AsyncPoolRequest) -> None: - """ - Once a request completes we remove it from the pool, - and determine if we can now assign any queued requests - to a connection. - """ - async with self._pool_lock: - self._requests.remove(request) - closing = self._assign_requests_to_connections() - await self._close_connections(closing) - def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: """ Manage the state of the connection pool, assigning incoming @@ -248,12 +239,20 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: """ closing_connections = [] - # Close any expired connections. for connection in list(self._connections): - if connection.has_expired(): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): # log: "closing expired connection" self._connections.remove(connection) closing_connections.append(connection) + elif ( + connection.is_idle() and len(self._connections) > self._max_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) # Assign queued requests to connections. queued_requests = [ @@ -266,6 +265,9 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: for connection in self._connections if connection.can_handle_request(origin) and connection.is_available() ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] if avilable_connections: # log: "reusing existing connection" connection = avilable_connections[0] @@ -275,12 +277,7 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) - else: - idle_connections = [ - connection - for connection in self._connections - if connection.is_idle() - ] + elif idle_connections: # log: "closing idle connection" connection = idle_connections[0] self._connections.remove(connection) @@ -300,10 +297,9 @@ async def _close_connections(self, closing: List[AsyncConnectionInterface]) -> N await connection.aclose() async def aclose(self) -> None: - closing = list(self._connections) - self._requests = [] + closing_connections = list(self._connections) self._connections = [] - await self._close_connections(closing) + await self._close_connections(closing_connections) async def __aenter__(self) -> "AsyncConnectionPool": # Acquiring the pool lock here ensures that we have the @@ -331,19 +327,23 @@ def __init__( self._stream = stream self._pool_request = pool_request self._pool = pool + self._closed = False + assert self._pool_request in self._pool._requests async def __aiter__(self) -> AsyncIterator[bytes]: try: async for part in self._stream: yield part except BaseException: - async with self._pool._pool_lock: - self._pool._requests.remove(self._pool_request) - closing = self._pool._assign_requests_to_connections() - await self._pool._close_connections(closing) + await self.aclose() async def aclose(self) -> None: - async with self._pool._pool_lock: - self._pool._requests.remove(self._pool_request) - closing = self._pool._assign_requests_to_connections() - await self._pool._close_connections(closing) + if not self._closed: + self._closed = True + with AsyncShieldCancellation(): + if hasattr(self._stream, "aclose"): + await self._stream.aclose() + async with self._pool._pool_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + await self._pool._close_connections(closing) diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 050ffdba..ce2fca42 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -7,7 +7,7 @@ from .._backends.base import SOCKET_OPTION, NetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import Event, Lock +from .._synchronization import Event, Lock, ShieldCancellation from .connection import HTTPConnection from .interfaces import ConnectionInterface, RequestInterface @@ -193,9 +193,10 @@ def handle_request(self, request: Request) -> Response: pool_request = PoolRequest(request) try: while True: - with self._pool_lock: - self._requests.append(pool_request) - closing = self._assign_requests_to_connections() + with ShieldCancellation(): + with self._pool_lock: + self._requests.append(pool_request) + closing = self._assign_requests_to_connections() self._close_connections(closing) connection = pool_request.wait_for_connection(timeout=timeout) @@ -209,9 +210,10 @@ def handle_request(self, request: Request) -> Response: break except BaseException as exc: - with self._pool_lock: - self._requests.remove(pool_request) - closing = self._assign_requests_to_connections() + with ShieldCancellation(): + with self._pool_lock: + self._requests.remove(pool_request) + closing = self._assign_requests_to_connections() self._close_connections(closing) raise exc from None @@ -225,17 +227,6 @@ def handle_request(self, request: Request) -> Response: extensions=response.extensions, ) - def _request_closed(self, request: PoolRequest) -> None: - """ - Once a request completes we remove it from the pool, - and determine if we can now assign any queued requests - to a connection. - """ - with self._pool_lock: - self._requests.remove(request) - closing = self._assign_requests_to_connections() - self._close_connections(closing) - def _assign_requests_to_connections(self) -> List[ConnectionInterface]: """ Manage the state of the connection pool, assigning incoming @@ -248,12 +239,20 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]: """ closing_connections = [] - # Close any expired connections. for connection in list(self._connections): - if connection.has_expired(): + if connection.is_closed(): + # log: "removing closed connection" + self._connections.remove(connection) + elif connection.has_expired(): # log: "closing expired connection" self._connections.remove(connection) closing_connections.append(connection) + elif ( + connection.is_idle() and len(self._connections) > self._max_connections + ): + # log: "closing idle connection" + self._connections.remove(connection) + closing_connections.append(connection) # Assign queued requests to connections. queued_requests = [ @@ -266,6 +265,9 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]: for connection in self._connections if connection.can_handle_request(origin) and connection.is_available() ] + idle_connections = [ + connection for connection in self._connections if connection.is_idle() + ] if avilable_connections: # log: "reusing existing connection" connection = avilable_connections[0] @@ -275,12 +277,7 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]: connection = self.create_connection(origin) self._connections.append(connection) pool_request.assign_to_connection(connection) - else: - idle_connections = [ - connection - for connection in self._connections - if connection.is_idle() - ] + elif idle_connections: # log: "closing idle connection" connection = idle_connections[0] self._connections.remove(connection) @@ -300,10 +297,9 @@ def _close_connections(self, closing: List[ConnectionInterface]) -> None: connection.close() def close(self) -> None: - closing = list(self._connections) - self._requests = [] + closing_connections = list(self._connections) self._connections = [] - self._close_connections(closing) + self._close_connections(closing_connections) def __enter__(self) -> "ConnectionPool": # Acquiring the pool lock here ensures that we have the @@ -331,19 +327,23 @@ def __init__( self._stream = stream self._pool_request = pool_request self._pool = pool + self._closed = False + assert self._pool_request in self._pool._requests def __iter__(self) -> Iterator[bytes]: try: for part in self._stream: yield part except BaseException: - with self._pool._pool_lock: - self._pool._requests.remove(self._pool_request) - closing = self._pool._assign_requests_to_connections() - self._pool._close_connections(closing) + self.close() def close(self) -> None: - with self._pool._pool_lock: - self._pool._requests.remove(self._pool_request) - closing = self._pool._assign_requests_to_connections() - self._pool._close_connections(closing) + if not self._closed: + self._closed = True + with ShieldCancellation(): + if hasattr(self._stream, "close"): + self._stream.close() + with self._pool._pool_lock: + self._pool._requests.remove(self._pool_request) + closing = self._pool._assign_requests_to_connections() + self._pool._close_connections(closing) diff --git a/tests/_async/test_connection_pool.py b/tests/_async/test_connection_pool.py index 61ee1e54..4ee6ca33 100644 --- a/tests/_async/test_connection_pool.py +++ b/tests/_async/test_connection_pool.py @@ -66,8 +66,8 @@ async def test_connection_pool_with_keepalive(): async with pool.stream("GET", "http://example.com/") as response: info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] await response.aread() @@ -75,8 +75,8 @@ async def test_connection_pool_with_keepalive(): assert response.content == b"Hello, world!" info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] @@ -205,11 +205,16 @@ async def test_connection_pool_with_http2_goaway(): http2=True, ) + def debug(*args, **kwargs): + print(*args, **kwargs) + async with httpcore.AsyncConnectionPool( network_backend=network_backend, ) as pool: # Sending an intial request, which once complete will return to the pool, IDLE. - response = await pool.request("GET", "https://example.com/") + response = await pool.request( + "GET", "https://example.com/", exensions={"trace": debug} + ) assert response.status == 200 assert response.content == b"Hello, world!" @@ -225,8 +230,8 @@ async def test_connection_pool_with_http2_goaway(): info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] diff --git a/tests/_sync/test_connection_pool.py b/tests/_sync/test_connection_pool.py index c9621c7b..9ee63456 100644 --- a/tests/_sync/test_connection_pool.py +++ b/tests/_sync/test_connection_pool.py @@ -66,8 +66,8 @@ def test_connection_pool_with_keepalive(): with pool.stream("GET", "http://example.com/") as response: info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] response.read() @@ -75,8 +75,8 @@ def test_connection_pool_with_keepalive(): assert response.content == b"Hello, world!" info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ] @@ -205,11 +205,16 @@ def test_connection_pool_with_http2_goaway(): http2=True, ) + def debug(*args, **kwargs): + print(*args, **kwargs) + with httpcore.ConnectionPool( network_backend=network_backend, ) as pool: # Sending an intial request, which once complete will return to the pool, IDLE. - response = pool.request("GET", "https://example.com/") + response = pool.request( + "GET", "https://example.com/", exensions={"trace": debug} + ) assert response.status == 200 assert response.content == b"Hello, world!" @@ -225,8 +230,8 @@ def test_connection_pool_with_http2_goaway(): info = [repr(c) for c in pool.connections] assert info == [ - "", "", + "", ]