Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ProxyHeadersMiddleware #1611

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/middleware/test_proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,16 @@ async def app(
# trusted proxy list
(["127.0.0.1", "10.0.0.1"], "Remote: https://1.2.3.4:0"),
("127.0.0.1, 10.0.0.1", "Remote: https://1.2.3.4:0"),
# trusted proxy network
# https://github.com/encode/uvicorn/issues/1068#issuecomment-1004813267
("127.0.0.0/24, 10.0.0.1", "Remote: https://1.2.3.4:0"),
# request from untrusted proxy
("192.168.0.1", "Remote: http://127.0.0.1:123"),
# request from untrusted proxy network
("192.168.0.0/16", "Remote: http://127.0.0.1:123"),
# request from client running on proxy server itself
# https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576
(["127.0.0.1", "1.2.3.4"], "Remote: https://1.2.3.4:0"),
],
)
async def test_proxy_headers_trusted_hosts(
Expand Down Expand Up @@ -68,6 +76,8 @@ async def test_proxy_headers_trusted_hosts(
),
# should set first untrusted as remote address
(["192.168.0.2", "127.0.0.1"], "Remote: https://10.0.2.1:0"),
# Mixed literals and networks
(["127.0.0.1", "10.0.0.0/8", "192.168.0.2"], "Remote: https://1.2.3.4:0"),
],
)
async def test_proxy_headers_multiple_proxies(
Expand Down Expand Up @@ -103,3 +113,24 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Remote: https://1.2.3.4:0"


@pytest.mark.anyio
async def test_proxy_headers_empty_x_forwarded_for() -> None:
# fallback to the default behavior if x-forwarded-for is an empty list
# https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576
app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts="*")
transport = httpx.ASGITransport(app=app_with_middleware, client=("1.2.3.4", 8080))
async with httpx.AsyncClient(
transport=transport, base_url="http://testserver"
) as client:
headers = httpx.Headers(
{
"X-Forwarded-Proto": "https",
"X-Forwarded-For": "",
},
encoding="latin-1",
)
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Remote: https://1.2.3.4:8080"
85 changes: 60 additions & 25 deletions uvicorn/middleware/proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies
"""
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
import ipaddress
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast

if TYPE_CHECKING:
from asgiref.typing import (
Expand All @@ -21,30 +22,62 @@
)


def _parse_raw_hosts(value: str) -> List[str]:
return [item.strip() for item in value.split(",")]


class _TrustedHosts:
def __init__(self, trusted_hosts: Union[List[str], str]) -> None:
self.trusted_networks: Set[ipaddress.IPv4Network] = set()
self.trusted_literals: Set[str] = set()
self.always_trust = trusted_hosts == "*"

if not self.always_trust:
if isinstance(trusted_hosts, str):
trusted_hosts = _parse_raw_hosts(trusted_hosts)
for host in trusted_hosts:
try:
# Try parsing the trusted host as an IPv4Network
# to allow checking a whole range.
# https://github.com/encode/uvicorn/issues/1068
self.trusted_networks.add(ipaddress.IPv4Network(host))
Copy link
Contributor

@nhairs nhairs Jan 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for using IPv4Network instead of ipaddress.ip_network which is version agnostic?

except ValueError:
self.trusted_literals.add(host)
Comment on lines +44 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a test case for this?


def __contains__(self, item: Optional[str]) -> bool:
if self.always_trust:
return True

try:
ip = ipaddress.IPv4Address(item)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we're using IPv4Address instead of ipaddress.ip_address which is version agnostic?

return any(ip in net for net in self.trusted_networks)
except ValueError:
return item in self.trusted_literals

def get_trusted_client_host(self, x_forwarded_for: str) -> Optional[str]:
x_forwarded_for_hosts = _parse_raw_hosts(x_forwarded_for)
if self.always_trust:
return x_forwarded_for_hosts[0]

host = None
for host in reversed(x_forwarded_for_hosts):
if host not in self:
return host
# The request came from a client on the proxy itself. Trust it.
# See https://github.com/encode/uvicorn/issues/1068#issuecomment-855371576
if host in self:
return x_forwarded_for_hosts[0]
return host
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be reached? What's the scenario?



class ProxyHeadersMiddleware:
def __init__(
self,
app: "ASGI3Application",
trusted_hosts: Union[List[str], str] = "127.0.0.1",
) -> None:
self.app = app
if isinstance(trusted_hosts, str):
self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")}
else:
self.trusted_hosts = set(trusted_hosts)
self.always_trust = "*" in self.trusted_hosts

def get_trusted_client_host(
self, x_forwarded_for_hosts: List[str]
) -> Optional[str]:
if self.always_trust:
return x_forwarded_for_hosts[0]

for host in reversed(x_forwarded_for_hosts):
if host not in self.trusted_hosts:
return host

return None
self.trusted_hosts = _TrustedHosts(trusted_hosts)

async def __call__(
self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
Expand All @@ -54,7 +87,7 @@ async def __call__(
client_addr: Optional[Tuple[str, int]] = scope.get("client")
client_host = client_addr[0] if client_addr else None

if self.always_trust or client_host in self.trusted_hosts:
if client_host in self.trusted_hosts:
headers = dict(scope["headers"])

if b"x-forwarded-proto" in headers:
Expand All @@ -68,11 +101,13 @@ async def __call__(
# X-Forwarded-For header. We've lost the connecting client's port
# information by now, so only include the host.
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
x_forwarded_for_hosts = [
item.strip() for item in x_forwarded_for.split(",")
]
host = self.get_trusted_client_host(x_forwarded_for_hosts)
port = 0
scope["client"] = (host, port) # type: ignore[arg-type]
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)

# Host is None or an empty string
# if the x-forwarded-for header is empty.
# See https://github.com/encode/uvicorn/issues/1068
if host:
port = 0
scope["client"] = (host, port) # type: ignore[arg-type]

return await self.app(scope, receive, send)