Skip to content

Commit

Permalink
Merge branch 'master' into router_prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
geospatial-jeff authored Aug 1, 2022
2 parents 7f12a57 + b7580fe commit 2a04d38
Show file tree
Hide file tree
Showing 13 changed files with 616 additions and 7 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,7 @@ docs/api/*
.envrc

# Virtualenv
venv
venv

# IDE
.vscode
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Add STAC API - Collections conformance class. ([383](https://github.com/stac-utils/stac-fastapi/pull/383))
* Bulk item inserts for pgstac implementation. ([411](https://github.com/stac-utils/stac-fastapi/pull/411))
* Add APIRouter prefix support for pgstac implementation. ([429](https://github.com/stac-utils/stac-fastapi/pull/429))
* Respect `Forwarded` or `X-Forwarded-*` request headers when building links to better accommodate load balancers and proxies.

### Changed

Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ test-sqlalchemy: run-joplin-sqlalchemy
test-pgstac:
$(run_pgstac) /bin/bash -c 'export && ./scripts/wait-for-it.sh database:5432 && cd /app/stac_fastapi/pgstac/tests/ && pytest -vvv'

.PHONY: test-api
test-api:
$(run_sqlalchemy) /bin/bash -c 'cd /app/stac_fastapi/api && pytest -svvv'

.PHONY: run-database
run-database:
docker-compose run --rm database
Expand Down
5 changes: 4 additions & 1 deletion stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from starlette.responses import JSONResponse, Response

from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
from stac_fastapi.api.middleware import ProxyHeaderMiddleware
from stac_fastapi.api.models import (
APIRequest,
CollectionUri,
Expand Down Expand Up @@ -91,7 +92,9 @@ class StacApi:
)
pagination_extension = attr.ib(default=TokenPaginationExtension)
response_class: Type[Response] = attr.ib(default=JSONResponse)
middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware]))
middlewares: List = attr.ib(
default=attr.Factory(lambda: [BrotliMiddleware, ProxyHeaderMiddleware])
)
route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[])

def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]:
Expand Down
97 changes: 96 additions & 1 deletion stac_fastapi/api/stac_fastapi/api/middleware.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""api middleware."""

from typing import Callable
import re
from http.client import HTTP_PORT, HTTPS_PORT
from typing import Callable, List, Tuple

from fastapi import APIRouter, FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.routing import Match
from starlette.types import ASGIApp, Receive, Scope, Send


def router_middleware(app: FastAPI, router: APIRouter):
Expand All @@ -29,3 +32,95 @@ async def _middleware(request: Request, call_next):
return func

return deco


class ProxyHeaderMiddleware:
"""
Account for forwarding headers when deriving base URL.
Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing.
Default to what can be derived from the URL if no headers provided.
Middleware updates the host header that is interpreted by starlette when deriving Request.base_url.
"""

def __init__(self, app: ASGIApp):
"""Create proxy header middleware."""
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Call from stac-fastapi framework."""
if scope["type"] == "http":
proto, domain, port = self._get_forwarded_url_parts(scope)
scope["scheme"] = proto
if domain is not None:
port_suffix = ""
if port is not None:
if (proto == "http" and port != HTTP_PORT) or (
proto == "https" and port != HTTPS_PORT
):
port_suffix = f":{port}"
scope["headers"] = self._replace_header_value_by_name(
scope,
"host",
f"{domain}{port_suffix}",
)
await self.app(scope, receive, send)

def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:
print(scope)
proto = scope.get("scheme", "http")
header_host = self._get_header_value_by_name(scope, "host")
if header_host is None:
domain, port = scope.get("server")
else:
header_host_parts = header_host.split(":")
if len(header_host_parts) == 2:
domain, port = header_host_parts
else:
domain = header_host_parts[0]
port = None
forwarded = self._get_header_value_by_name(scope, "forwarded")
if forwarded is not None:
parts = forwarded.split(";")
for part in parts:
if len(part) > 0 and re.search("=", part):
key, value = part.split("=")
if key == "proto":
proto = value
elif key == "host":
host_parts = value.split(":")
domain = host_parts[0]
try:
port = int(host_parts[1]) if len(host_parts) == 2 else None
except ValueError:
# ignore ports that are not valid integers
pass
else:
proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto)
port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port)
try:
port = int(port_str) if port_str is not None else None
except ValueError:
# ignore ports that are not valid integers
pass

return (proto, domain, port)

def _get_header_value_by_name(
self, scope: Scope, header_name: str, default_value: str = None
) -> str:
headers = scope["headers"]
candidates = [
value.decode() for key, value in headers if key.decode() == header_name
]
return candidates[0] if len(candidates) == 1 else default_value

@staticmethod
def _replace_header_value_by_name(
scope: Scope, header_name: str, new_value: str
) -> List[Tuple[str]]:
return [
(name, value)
for name, value in scope["headers"]
if name.decode() != header_name
] + [(str.encode(header_name), str.encode(new_value))]
140 changes: 140 additions & 0 deletions stac_fastapi/api/tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import pytest
from starlette.applications import Starlette

from stac_fastapi.api.middleware import ProxyHeaderMiddleware


@pytest.fixture
def proxy_header_middleware() -> ProxyHeaderMiddleware:
app = Starlette()
return ProxyHeaderMiddleware(app)


@pytest.mark.parametrize(
"headers,key,expected",
[
([(b"host", b"testserver")], "host", "testserver"),
([(b"host", b"testserver")], "user-agent", None),
(
[(b"host", b"testserver"), (b"accept-encoding", b"gzip, deflate, br")],
"accept-encoding",
"gzip, deflate, br",
),
],
)
def test_get_header_value_by_name(
proxy_header_middleware: ProxyHeaderMiddleware, headers, key, expected
):
scope = {"headers": headers}
actual = proxy_header_middleware._get_header_value_by_name(scope, key)
assert actual == expected


@pytest.mark.parametrize(
"headers,key,value",
[
([(b"host", b"testserver")], "host", "another-server"),
([(b"host", b"testserver")], "user-agent", "agent"),
(
[(b"host", b"testserver"), (b"accept-encoding", b"gzip, deflate, br")],
"accept-encoding",
"deflate",
),
],
)
def test_replace_header_value_by_name(
proxy_header_middleware: ProxyHeaderMiddleware, headers, key, value
):
scope = {"headers": headers}
updated_headers = proxy_header_middleware._replace_header_value_by_name(
scope, key, value
)

header_value = proxy_header_middleware._get_header_value_by_name(
{"headers": updated_headers}, key
)
assert header_value == value


@pytest.mark.parametrize(
"scope,expected",
[
(
{"scheme": "https", "server": ["testserver", 80], "headers": []},
("https", "testserver", 80),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [(b"host", b"testserver:81")],
},
("http", "testserver", 81),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [(b"host", b"testserver")],
},
("http", "testserver", None),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [(b"forwarded", b"proto=https;host=test:1234")],
},
("https", "test", 1234),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [(b"forwarded", b"proto=https;host=test:not-an-integer")],
},
("https", "test", 80),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [(b"x-forwarded-proto", b"https")],
},
("https", "testserver", 80),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [(b"x-forwarded-port", b"1111")],
},
("http", "testserver", 1111),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [(b"x-forwarded-port", b"not-an-integer")],
},
("http", "testserver", 80),
),
(
{
"scheme": "http",
"server": ["testserver", 80],
"headers": [
(b"forwarded", b"proto=https;host=test:1234"),
(b"x-forwarded-port", b"1111"),
(b"x-forwarded-proto", b"https"),
],
},
("https", "test", 1234),
),
],
)
def test_get_forwarded_url_parts(
proxy_header_middleware: ProxyHeaderMiddleware, scope, expected
):
actual = proxy_header_middleware._get_forwarded_url_parts(scope)
assert actual == expected
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ class FilterConformanceClasses(str, Enum):
See https://github.com/radiantearth/stac-api-spec/tree/v1.0.0-rc.1/fragments/filter
"""

FILTER = "https://api.stacspec.org/v1.0.0-rc.1/item-search#filter:filter"
ITEM_SEARCH_FILTER = (
"https://api.stacspec.org/v1.0.0-rc.1/item-search#filter:item-search-filter"
FILTER = "http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/filter"
FEATURES_FILTER = (
"http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/features-filter"
)
ITEM_SEARCH_FILTER = "https://api.stacspec.org/v1.0.0-rc.1/item-search#filter"
CQL_TEXT = "https://api.stacspec.org/v1.0.0-rc.1/item-search#filter:cql-text"
CQL_JSON = "https://api.stacspec.org/v1.0.0-rc.1/item-search#filter:cql-json"
BASIC_CQL = "https://api.stacspec.org/v1.0.0-rc.1/item-search#filter:basic-cql"
Expand Down Expand Up @@ -70,6 +71,7 @@ class FilterExtension(ApiExtension):
conformance_classes: List[str] = attr.ib(
default=[
FilterConformanceClasses.FILTER,
FilterConformanceClasses.FEATURES_FILTER,
FilterConformanceClasses.ITEM_SEARCH_FILTER,
FilterConformanceClasses.BASIC_CQL,
FilterConformanceClasses.CQL_TEXT,
Expand Down
Loading

0 comments on commit 2a04d38

Please sign in to comment.