diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 648407248..a0170672a 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -32,8 +32,16 @@ async def app( # trusted proxy list (["127.0.0.1", "10.0.0.1"], "Remote: https://1.2.3.4:0"), ("127.0.0.1, 10.0.0.1", "Remote: https://1.2.3.4:0"), + # trusted proxy network + # https://github.com/encode/uvicorn/issues/1068#issuecomment-1004813267 + ("127.0.0.0/24, 10.0.0.1", "Remote: https://1.2.3.4:0"), # request from untrusted proxy ("192.168.0.1", "Remote: http://127.0.0.1:123"), + # request from untrusted proxy network + ("192.168.0.0/16", "Remote: http://127.0.0.1:123"), + # request from client running on proxy server itself + # https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576 + (["127.0.0.1", "1.2.3.4"], "Remote: https://1.2.3.4:0"), ], ) async def test_proxy_headers_trusted_hosts( @@ -68,6 +76,8 @@ async def test_proxy_headers_trusted_hosts( ), # should set first untrusted as remote address (["192.168.0.2", "127.0.0.1"], "Remote: https://10.0.2.1:0"), + # Mixed literals and networks + (["127.0.0.1", "10.0.0.0/8", "192.168.0.2"], "Remote: https://1.2.3.4:0"), ], ) async def test_proxy_headers_multiple_proxies( @@ -103,3 +113,24 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None: response = await client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Remote: https://1.2.3.4:0" + + +@pytest.mark.anyio +async def test_proxy_headers_empty_x_forwarded_for() -> None: + # fallback to the default behavior if x-forwarded-for is an empty list + # https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576 + app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts="*") + transport = httpx.ASGITransport(app=app_with_middleware, client=("1.2.3.4", 8080)) + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: + headers = httpx.Headers( + { + "X-Forwarded-Proto": "https", + "X-Forwarded-For": "", + }, + encoding="latin-1", + ) + response = await client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Remote: https://1.2.3.4:8080" diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 4b62cf209..91ada58c1 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -8,7 +8,8 @@ https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies """ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast +import ipaddress +from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast if TYPE_CHECKING: from asgiref.typing import ( @@ -21,6 +22,54 @@ ) +def _parse_raw_hosts(value: str) -> List[str]: + return [item.strip() for item in value.split(",")] + + +class _TrustedHosts: + def __init__(self, trusted_hosts: Union[List[str], str]) -> None: + self.trusted_networks: Set[ipaddress.IPv4Network] = set() + self.trusted_literals: Set[str] = set() + self.always_trust = trusted_hosts == "*" + + if not self.always_trust: + if isinstance(trusted_hosts, str): + trusted_hosts = _parse_raw_hosts(trusted_hosts) + for host in trusted_hosts: + try: + # Try parsing the trusted host as an IPv4Network + # to allow checking a whole range. + # https://github.com/encode/uvicorn/issues/1068 + self.trusted_networks.add(ipaddress.IPv4Network(host)) + except ValueError: + self.trusted_literals.add(host) + + def __contains__(self, item: Optional[str]) -> bool: + if self.always_trust: + return True + + try: + ip = ipaddress.IPv4Address(item) + return any(ip in net for net in self.trusted_networks) + except ValueError: + return item in self.trusted_literals + + def get_trusted_client_host(self, x_forwarded_for: str) -> Optional[str]: + x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for) + if self.always_trust: + return x_forwarded_for_hosts[0] + + host = None + for host in reversed(x_forwarded_for_hosts): + if host not in self: + return host + # The request came from a client on the proxy itself. Trust it. + # See https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576 + if host in self: + return x_forwarded_for_hosts[0] + return host + + class ProxyHeadersMiddleware: def __init__( self, @@ -28,23 +77,7 @@ def __init__( trusted_hosts: Union[List[str], str] = "127.0.0.1", ) -> None: self.app = app - if isinstance(trusted_hosts, str): - self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")} - else: - self.trusted_hosts = set(trusted_hosts) - self.always_trust = "*" in self.trusted_hosts - - def get_trusted_client_host( - self, x_forwarded_for_hosts: List[str] - ) -> Optional[str]: - if self.always_trust: - return x_forwarded_for_hosts[0] - - for host in reversed(x_forwarded_for_hosts): - if host not in self.trusted_hosts: - return host - - return None + self.trusted_hosts = _TrustedHosts(trusted_hosts) async def __call__( self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable" @@ -54,7 +87,7 @@ async def __call__( client_addr: Optional[Tuple[str, int]] = scope.get("client") client_host = client_addr[0] if client_addr else None - if self.always_trust or client_host in self.trusted_hosts: + if client_host in self.trusted_hosts: headers = dict(scope["headers"]) if b"x-forwarded-proto" in headers: @@ -68,11 +101,13 @@ async def __call__( # X-Forwarded-For header. We've lost the connecting client's port # information by now, so only include the host. x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") - x_forwarded_for_hosts = [ - item.strip() for item in x_forwarded_for.split(",") - ] - host = self.get_trusted_client_host(x_forwarded_for_hosts) - port = 0 - scope["client"] = (host, port) # type: ignore[arg-type] + host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for) + + # Host is None or an empty string + # if the x-forwarded-for header is empty. + # See https://github.com/encode/uvicorn/issues/1068 + if host: + port = 0 + scope["client"] = (host, port) # type: ignore[arg-type] return await self.app(scope, receive, send)