Skip to content

Commit

Permalink
Add support for APIRouter prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
drnextgis committed Jul 24, 2022
1 parent 0483406 commit 0b89b2e
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 27 deletions.
6 changes: 5 additions & 1 deletion stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def customize_openapi(self) -> Optional[Dict[str, Any]]:

def add_health_check(self):
"""Add a health check."""
mgmt_router = APIRouter()
mgmt_router = APIRouter(prefix=self.app.state.router_prefix)

@mgmt_router.get("/_mgmt/ping")
async def ping():
Expand Down Expand Up @@ -381,6 +381,10 @@ def __attrs_post_init__(self):
self.register_core()
self.app.include_router(self.router)

# keep link to the router prefix value
router_prefix = self.router.prefix
self.app.state.router_prefix = router_prefix if router_prefix else ""

# register extensions
for ext in self.extensions:
ext.register(self.app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def register(self, app: FastAPI) -> None:
Returns:
None
"""
self.router.prefix = app.state.router_prefix
self.router.add_api_route(
name="Queryables",
path="/queryables",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def register(self, app: FastAPI) -> None:
Returns:
None
"""
self.router.prefix = app.state.router_prefix
self.register_create_item()
self.register_update_item()
self.register_delete_item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def register(self, app: FastAPI) -> None:
"""
items_request_model = create_request_model("Items", base_model=Items)

router = APIRouter()
router = APIRouter(prefix=app.state.router_prefix)
router.add_api_route(
name="Bulk Create Item",
path="/collections/{collection_id}/bulk_items",
Expand Down
3 changes: 2 additions & 1 deletion stac_fastapi/pgstac/stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from stac_fastapi.pgstac.utils import filter_fields
from stac_fastapi.types.core import AsyncBaseCoreClient
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
from stac_fastapi.types.requests import get_base_url
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection

NumType = Union[float, int]
Expand All @@ -35,7 +36,7 @@ class CoreCrudClient(AsyncBaseCoreClient):
async def all_collections(self, **kwargs) -> Collections:
"""Read all collections from the database."""
request: Request = kwargs["request"]
base_url = str(request.base_url)
base_url = get_base_url(request)
pool = request.app.state.readpool

async with pool.acquire() as conn:
Expand Down
4 changes: 3 additions & 1 deletion stac_fastapi/pgstac/stac_fastapi/pgstac/models/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from stac_pydantic.shared import MimeTypes
from starlette.requests import Request

from stac_fastapi.types.requests import get_base_url

# These can be inferred from the item/collection so they aren't included in the database
# Instead they are dynamically generated when querying the database using the classes defined below
INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"]
Expand Down Expand Up @@ -45,7 +47,7 @@ class BaseLinks:
@property
def base_url(self):
"""Get the base url."""
return str(self.request.base_url)
return get_base_url(self.request)

@property
def url(self):
Expand Down
16 changes: 12 additions & 4 deletions stac_fastapi/pgstac/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,24 @@ async def test_api_headers(app_client):
assert resp.status_code == 200


async def test_core_router(api_client):
core_routes = set(STAC_CORE_ROUTES)
async def test_core_router(api_client, app):
core_routes = set()
for core_route in STAC_CORE_ROUTES:
method, path = core_route.split(" ")
core_routes.add("{} {}".format(method, app.state.router_prefix + path))

api_routes = set(
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
)
assert not core_routes - api_routes


async def test_transactions_router(api_client):
transaction_routes = set(STAC_TRANSACTION_ROUTES)
async def test_transactions_router(api_client, app):
transaction_routes = set()
for transaction_route in STAC_TRANSACTION_ROUTES:
method, path = transaction_route.split(" ")
transaction_routes.add("{} {}".format(method, app.state.router_prefix + path))

api_routes = set(
[f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes]
)
Expand Down
33 changes: 29 additions & 4 deletions stac_fastapi/pgstac/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import os
import time
from typing import Callable, Dict
from urllib.parse import urljoin

import asyncpg
import pytest
from fastapi import APIRouter
from fastapi.responses import ORJSONResponse
from httpx import AsyncClient
from pypgstac.db import PgstacDB
Expand Down Expand Up @@ -107,9 +109,26 @@ async def pgstac(pg):


# Run all the tests that use the api_client in both db hydrate and api hydrate mode
@pytest.fixture(params=[settings, pgstac_api_hydrate_settings], scope="session")
@pytest.fixture(
params=[
(settings, ""),
(settings, "/router_prefix"),
(pgstac_api_hydrate_settings, ""),
(pgstac_api_hydrate_settings, "/router_prefix"),
],
scope="session",
)
def api_client(request, pg):
print("creating client with settings, hydrate:", request.param.use_api_hydrate)
api_settings, prefix = request.param

api_settings.openapi_url = prefix + api_settings.openapi_url
api_settings.docs_url = prefix + api_settings.docs_url

print(
"creating client with settings, hydrate: {}, router prefix: '{}'".format(
api_settings.use_api_hydrate, prefix
)
)

extensions = [
TransactionExtension(client=TransactionsClient(), settings=settings),
Expand All @@ -122,12 +141,13 @@ def api_client(request, pg):
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
api = StacApi(
settings=request.param,
settings=api_settings,
extensions=extensions,
client=CoreCrudClient(post_request_model=post_request_model),
search_get_request_model=create_get_request_model(extensions),
search_post_request_model=post_request_model,
response_class=ORJSONResponse,
router=APIRouter(prefix=prefix),
)

return api
Expand All @@ -150,7 +170,12 @@ async def app(api_client):
@pytest.fixture(scope="function")
async def app_client(app):
print("creating app_client")
async with AsyncClient(app=app, base_url="http://test") as c:

base_url = "http://test"
if app.state.router_prefix != "":
base_url = urljoin(base_url, app.state.router_prefix)

async with AsyncClient(app=app, base_url=base_url) as c:
yield c


Expand Down
10 changes: 5 additions & 5 deletions stac_fastapi/pgstac/tests/resources/test_conformance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,25 @@ def test_landing_page_health(response):

@pytest.mark.parametrize("rel_type,expected_media_type,expected_path", link_tests)
async def test_landing_page_links(
response_json: Dict, app_client, rel_type, expected_media_type, expected_path
response_json: Dict, app_client, app, rel_type, expected_media_type, expected_path
):
link = get_link(response_json, rel_type)

assert link is not None, f"Missing {rel_type} link in landing page"
assert link.get("type") == expected_media_type

link_path = urllib.parse.urlsplit(link.get("href")).path
assert link_path == expected_path
assert link_path == app.state.router_prefix + expected_path

resp = await app_client.get(link_path)
resp = await app_client.get(link_path.rsplit("/", 1)[-1])
assert resp.status_code == 200


# This endpoint currently returns a 404 for empty result sets, but testing for this response
# code here seems meaningless since it would be the same as if the endpoint did not exist. Once
# https://github.com/stac-utils/stac-fastapi/pull/227 has been merged we can add this to the
# parameterized tests above.
def test_search_link(response_json: Dict):
def test_search_link(response_json: Dict, app):
for search_link in [
get_link(response_json, "search", "GET"),
get_link(response_json, "search", "POST"),
Expand All @@ -73,4 +73,4 @@ def test_search_link(response_json: Dict):
assert search_link.get("type") == "application/geo+json"

search_path = urllib.parse.urlsplit(search_link.get("href")).path
assert search_path == "/search"
assert search_path == app.state.router_prefix + "/search"
7 changes: 5 additions & 2 deletions stac_fastapi/pgstac/tests/resources/test_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,7 @@ async def test_get_missing_item(app_client, load_test_data):
assert resp.status_code == 404


async def test_relative_link_construction():
async def test_relative_link_construction(app):
req = Request(
scope={
"type": "http",
Expand All @@ -1175,10 +1175,13 @@ async def test_relative_link_construction():
"raw_path": b"/tab/abc",
"query_string": b"",
"headers": {},
"app": app,
}
)
links = CollectionLinks(collection_id="naip", request=req)
assert links.link_items()["href"] == "http://test/stac/collections/naip/items"
assert links.link_items()["href"] == (
"http://test/stac{}/collections/naip/items".format(app.state.router_prefix)
)


async def test_search_bbox_errors(app_client):
Expand Down
23 changes: 15 additions & 8 deletions stac_fastapi/types/stac_fastapi/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from stac_fastapi.types import stac as stac_types
from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES
from stac_fastapi.types.extension import ApiExtension
from stac_fastapi.types.requests import get_base_url
from stac_fastapi.types.search import BaseSearchPostRequest
from stac_fastapi.types.stac import Conformance

Expand Down Expand Up @@ -349,12 +350,10 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
API landing page, serving as an entry point to the API.
"""
request: Request = kwargs["request"]
base_url = str(request.base_url)
base_url = get_base_url(request)
extension_schemas = [
schema.schema_href for schema in self.extensions if schema.schema_href
]
request: Request = kwargs["request"]
base_url = str(request.base_url)
landing_page = self._landing_page(
base_url=base_url,
conformance_classes=self.conformance_classes(),
Expand All @@ -379,7 +378,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
"rel": "service-desc",
"type": "application/vnd.oai.openapi+json;version=3.0",
"title": "OpenAPI service description",
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
"href": urljoin(
str(request.base_url), request.app.openapi_url.lstrip("/")
),
}
)

Expand All @@ -389,7 +390,9 @@ def landing_page(self, **kwargs) -> stac_types.LandingPage:
"rel": "service-doc",
"type": "text/html",
"title": "OpenAPI service documentation",
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
"href": urljoin(
str(request.base_url), request.app.docs_url.lstrip("/")
),
}
)

Expand Down Expand Up @@ -540,7 +543,7 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
API landing page, serving as an entry point to the API.
"""
request: Request = kwargs["request"]
base_url = str(request.base_url)
base_url = get_base_url(request)
extension_schemas = [
schema.schema_href for schema in self.extensions if schema.schema_href
]
Expand All @@ -566,7 +569,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
"rel": "service-desc",
"type": "application/vnd.oai.openapi+json;version=3.0",
"title": "OpenAPI service description",
"href": urljoin(base_url, request.app.openapi_url.lstrip("/")),
"href": urljoin(
str(request.base_url), request.app.openapi_url.lstrip("/")
),
}
)

Expand All @@ -576,7 +581,9 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
"rel": "service-doc",
"type": "text/html",
"title": "OpenAPI service documentation",
"href": urljoin(base_url, request.app.docs_url.lstrip("/")),
"href": urljoin(
str(request.base_url), request.app.docs_url.lstrip("/")
),
}
)

Expand Down
14 changes: 14 additions & 0 deletions stac_fastapi/types/stac_fastapi/types/requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""requests helpers."""

from starlette.requests import Request


def get_base_url(request: Request) -> str:
"""Get base URL with respect of APIRouter prefix."""
app = request.app
if not app.state.router_prefix:
return str(request.base_url)
else:
return "{}{}/".format(
str(request.base_url), app.state.router_prefix.lstrip("/")
)

0 comments on commit 0b89b2e

Please sign in to comment.