From c844034788792cddd65f4a97912e28974d9af080 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerman Date: Tue, 30 Nov 2021 11:08:57 -0500 Subject: [PATCH] Set content-type to geojson for search results (#288) * Set content-type to geojson for search results * Update changelog * Test for ORJSON availability Co-authored-by: Rob Emanuele --- CHANGES.md | 1 + stac_fastapi/api/stac_fastapi/api/app.py | 46 ++++++++++++------- stac_fastapi/api/stac_fastapi/api/models.py | 20 ++++++++ stac_fastapi/pgstac/tests/api/test_api.py | 13 ++++++ stac_fastapi/sqlalchemy/tests/api/test_api.py | 11 +++++ 5 files changed, 74 insertions(+), 17 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index c1c8cd2b4..02e616a60 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,6 +10,7 @@ ### Fixed +* Content-type response headers for the /search endpoint now reflect the geojson response expected in the STAC api spec ([#220](https://github.com/stac-utils/stac-fastapi/issues/220) * The minimum `limit` value for searches is now 1 ([#296](https://github.com/stac-utils/stac-fastapi/pull/296)) * Links stored with Collections and Items (e.g. license links) are now returned with those STAC objects ([#282](https://github.com/stac-utils/stac-fastapi/pull/282)) * Content-type response headers for the /api endpoint now reflect those expected in the STAC api spec ([#287](https://github.com/stac-utils/stac-fastapi/pull/287)) diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 65ef21e9e..fef83ef76 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -17,6 +17,7 @@ APIRequest, CollectionUri, EmptyRequest, + GeoJSONResponse, ItemCollectionUri, ItemUri, SearchGetRequest, @@ -96,17 +97,16 @@ def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension] return None def _create_endpoint( - self, func: Callable, request_type: Union[Type[APIRequest], Type[BaseModel]] + self, + func: Callable, + request_type: Union[Type[APIRequest], Type[BaseModel]], + resp_class: Type[Response], ) -> Callable: """Create a FastAPI endpoint.""" if isinstance(self.client, AsyncBaseCoreClient): - return create_async_endpoint( - func, request_type, response_class=self.response_class - ) + return create_async_endpoint(func, request_type, response_class=resp_class) elif isinstance(self.client, BaseCoreClient): - return create_sync_endpoint( - func, request_type, response_class=self.response_class - ) + return create_sync_endpoint(func, request_type, response_class=resp_class) raise NotImplementedError def register_landing_page(self): @@ -125,7 +125,9 @@ def register_landing_page(self): response_model_exclude_unset=False, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint(self.client.landing_page, EmptyRequest), + endpoint=self._create_endpoint( + self.client.landing_page, EmptyRequest, self.response_class + ), ) def register_conformance_classes(self): @@ -144,7 +146,9 @@ def register_conformance_classes(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint(self.client.conformance, EmptyRequest), + endpoint=self._create_endpoint( + self.client.conformance, EmptyRequest, self.response_class + ), ) def register_get_item(self): @@ -161,7 +165,9 @@ def register_get_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint(self.client.get_item, ItemUri), + endpoint=self._create_endpoint( + self.client.get_item, ItemUri, self.response_class + ), ) def register_post_search(self): @@ -178,12 +184,12 @@ def register_post_search(self): response_model=(ItemCollection if not fields_ext else None) if self.settings.enable_response_models else None, - response_class=self.response_class, + response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], endpoint=self._create_endpoint( - self.client.post_search, search_request_model + self.client.post_search, search_request_model, GeoJSONResponse ), ) @@ -200,12 +206,12 @@ def register_get_search(self): response_model=(ItemCollection if not fields_ext else None) if self.settings.enable_response_models else None, - response_class=self.response_class, + response_class=GeoJSONResponse, response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.get_search, self.search_get_request + self.client.get_search, self.search_get_request, GeoJSONResponse ), ) @@ -225,7 +231,9 @@ def register_get_collections(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint(self.client.all_collections, EmptyRequest), + endpoint=self._create_endpoint( + self.client.all_collections, EmptyRequest, self.response_class + ), ) def register_get_collection(self): @@ -242,7 +250,9 @@ def register_get_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint(self.client.get_collection, CollectionUri), + endpoint=self._create_endpoint( + self.client.get_collection, CollectionUri, self.response_class + ), ) def register_get_item_collection(self): @@ -262,7 +272,9 @@ def register_get_item_collection(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.item_collection, self.item_collection_uri + self.client.item_collection, + self.item_collection_uri, + self.response_class, ), ) diff --git a/stac_fastapi/api/stac_fastapi/api/models.py b/stac_fastapi/api/stac_fastapi/api/models.py index f44ae0b38..c93fef7b4 100644 --- a/stac_fastapi/api/stac_fastapi/api/models.py +++ b/stac_fastapi/api/stac_fastapi/api/models.py @@ -1,6 +1,7 @@ """api request/response models.""" import abc +import importlib from typing import Dict, Optional, Type, Union import attr @@ -127,3 +128,22 @@ def kwargs(self) -> Dict: "fields": self.fields.split(",") if self.fields else self.fields, "sortby": self.sortby.split(",") if self.sortby else self.sortby, } + + +# Test for ORJSON and use it rather than stdlib JSON where supported +if importlib.util.find_spec("orjson") is not None: + from fastapi.responses import ORJSONResponse + + class GeoJSONResponse(ORJSONResponse): + """JSON with custom, vendor content-type.""" + + media_type = "application/geo+json" + + +else: + from starlette.responses import JSONResponse + + class GeoJSONResponse(JSONResponse): + """JSON with custom, vendor content-type.""" + + media_type = "application/geo+json" diff --git a/stac_fastapi/pgstac/tests/api/test_api.py b/stac_fastapi/pgstac/tests/api/test_api.py index c284d21cd..9d93ff73f 100644 --- a/stac_fastapi/pgstac/tests/api/test_api.py +++ b/stac_fastapi/pgstac/tests/api/test_api.py @@ -23,6 +23,19 @@ ] +@pytest.mark.asyncio +async def test_post_search_content_type(app_client): + params = {"limit": 1} + resp = await app_client.post("search", json=params) + assert resp.headers["content-type"] == "application/geo+json" + + +@pytest.mark.asyncio +async def test_get_search_content_type(app_client): + resp = await app_client.get("search") + assert resp.headers["content-type"] == "application/geo+json" + + @pytest.mark.asyncio async def test_api_headers(app_client): resp = await app_client.get("/api") diff --git a/stac_fastapi/sqlalchemy/tests/api/test_api.py b/stac_fastapi/sqlalchemy/tests/api/test_api.py index c78433e56..de1b9ecc4 100644 --- a/stac_fastapi/sqlalchemy/tests/api/test_api.py +++ b/stac_fastapi/sqlalchemy/tests/api/test_api.py @@ -23,6 +23,17 @@ ] +def test_post_search_content_type(app_client): + params = {"limit": 1} + resp = app_client.post("search", json=params) + assert resp.headers["content-type"] == "application/geo+json" + + +def test_get_search_content_type(app_client): + resp = app_client.get("search") + assert resp.headers["content-type"] == "application/geo+json" + + def test_api_headers(app_client): resp = app_client.get("/api") assert (