Skip to content

Commit

Permalink
Set content-type to geojson for search results (#288)
Browse files Browse the repository at this point in the history
* Set content-type to geojson for search results

* Update changelog

* Test for ORJSON availability

Co-authored-by: Rob Emanuele <rdemanuele@gmail.com>
  • Loading branch information
moradology and lossyrob authored Nov 30, 2021
1 parent 185b09c commit c844034
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

### Fixed

* Content-type response headers for the /search endpoint now reflect the geojson response expected in the STAC api spec ([#220](https://github.com/stac-utils/stac-fastapi/issues/220)
* The minimum `limit` value for searches is now 1 ([#296](https://github.com/stac-utils/stac-fastapi/pull/296))
* Links stored with Collections and Items (e.g. license links) are now returned with those STAC objects ([#282](https://github.com/stac-utils/stac-fastapi/pull/282))
* Content-type response headers for the /api endpoint now reflect those expected in the STAC api spec ([#287](https://github.com/stac-utils/stac-fastapi/pull/287))
Expand Down
46 changes: 29 additions & 17 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
APIRequest,
CollectionUri,
EmptyRequest,
GeoJSONResponse,
ItemCollectionUri,
ItemUri,
SearchGetRequest,
Expand Down Expand Up @@ -96,17 +97,16 @@ def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]
return None

def _create_endpoint(
self, func: Callable, request_type: Union[Type[APIRequest], Type[BaseModel]]
self,
func: Callable,
request_type: Union[Type[APIRequest], Type[BaseModel]],
resp_class: Type[Response],
) -> Callable:
"""Create a FastAPI endpoint."""
if isinstance(self.client, AsyncBaseCoreClient):
return create_async_endpoint(
func, request_type, response_class=self.response_class
)
return create_async_endpoint(func, request_type, response_class=resp_class)
elif isinstance(self.client, BaseCoreClient):
return create_sync_endpoint(
func, request_type, response_class=self.response_class
)
return create_sync_endpoint(func, request_type, response_class=resp_class)
raise NotImplementedError

def register_landing_page(self):
Expand All @@ -125,7 +125,9 @@ def register_landing_page(self):
response_model_exclude_unset=False,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.landing_page, EmptyRequest),
endpoint=self._create_endpoint(
self.client.landing_page, EmptyRequest, self.response_class
),
)

def register_conformance_classes(self):
Expand All @@ -144,7 +146,9 @@ def register_conformance_classes(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.conformance, EmptyRequest),
endpoint=self._create_endpoint(
self.client.conformance, EmptyRequest, self.response_class
),
)

def register_get_item(self):
Expand All @@ -161,7 +165,9 @@ def register_get_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.get_item, ItemUri),
endpoint=self._create_endpoint(
self.client.get_item, ItemUri, self.response_class
),
)

def register_post_search(self):
Expand All @@ -178,12 +184,12 @@ def register_post_search(self):
response_model=(ItemCollection if not fields_ext else None)
if self.settings.enable_response_models
else None,
response_class=self.response_class,
response_class=GeoJSONResponse,
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["POST"],
endpoint=self._create_endpoint(
self.client.post_search, search_request_model
self.client.post_search, search_request_model, GeoJSONResponse
),
)

Expand All @@ -200,12 +206,12 @@ def register_get_search(self):
response_model=(ItemCollection if not fields_ext else None)
if self.settings.enable_response_models
else None,
response_class=self.response_class,
response_class=GeoJSONResponse,
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.get_search, self.search_get_request
self.client.get_search, self.search_get_request, GeoJSONResponse
),
)

Expand All @@ -225,7 +231,9 @@ def register_get_collections(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.all_collections, EmptyRequest),
endpoint=self._create_endpoint(
self.client.all_collections, EmptyRequest, self.response_class
),
)

def register_get_collection(self):
Expand All @@ -242,7 +250,9 @@ def register_get_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(self.client.get_collection, CollectionUri),
endpoint=self._create_endpoint(
self.client.get_collection, CollectionUri, self.response_class
),
)

def register_get_item_collection(self):
Expand All @@ -262,7 +272,9 @@ def register_get_item_collection(self):
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
self.client.item_collection, self.item_collection_uri
self.client.item_collection,
self.item_collection_uri,
self.response_class,
),
)

Expand Down
20 changes: 20 additions & 0 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""api request/response models."""

import abc
import importlib
from typing import Dict, Optional, Type, Union

import attr
Expand Down Expand Up @@ -127,3 +128,22 @@ def kwargs(self) -> Dict:
"fields": self.fields.split(",") if self.fields else self.fields,
"sortby": self.sortby.split(",") if self.sortby else self.sortby,
}


# Test for ORJSON and use it rather than stdlib JSON where supported
if importlib.util.find_spec("orjson") is not None:
from fastapi.responses import ORJSONResponse

class GeoJSONResponse(ORJSONResponse):
"""JSON with custom, vendor content-type."""

media_type = "application/geo+json"


else:
from starlette.responses import JSONResponse

class GeoJSONResponse(JSONResponse):
"""JSON with custom, vendor content-type."""

media_type = "application/geo+json"
13 changes: 13 additions & 0 deletions stac_fastapi/pgstac/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@
]


@pytest.mark.asyncio
async def test_post_search_content_type(app_client):
params = {"limit": 1}
resp = await app_client.post("search", json=params)
assert resp.headers["content-type"] == "application/geo+json"


@pytest.mark.asyncio
async def test_get_search_content_type(app_client):
resp = await app_client.get("search")
assert resp.headers["content-type"] == "application/geo+json"


@pytest.mark.asyncio
async def test_api_headers(app_client):
resp = await app_client.get("/api")
Expand Down
11 changes: 11 additions & 0 deletions stac_fastapi/sqlalchemy/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
]


def test_post_search_content_type(app_client):
params = {"limit": 1}
resp = app_client.post("search", json=params)
assert resp.headers["content-type"] == "application/geo+json"


def test_get_search_content_type(app_client):
resp = app_client.get("search")
assert resp.headers["content-type"] == "application/geo+json"


def test_api_headers(app_client):
resp = app_client.get("/api")
assert (
Expand Down

0 comments on commit c844034

Please sign in to comment.