diff --git a/CHANGES.md b/CHANGES.md index 1b93d96..7d2912b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Fixed + +- Pass `request` by name when calling endpoints from other endpoints ([#44](https://github.com/stac-utils/stac-fastapi-pgstac/pull/44)) + ## [2.4.8] - 2023-06-08 ### Changed diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7e1b4a3..ef20f25 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -213,7 +213,7 @@ async def _add_item_links( if settings.use_api_hydrate: async def _get_base_item(collection_id: str) -> Dict[str, Any]: - return await self._get_base_item(collection_id, request) + return await self._get_base_item(collection_id, request=request) base_item_cache = settings.base_item_cache( fetch_base_item=_get_base_item, request=request @@ -267,7 +267,7 @@ async def item_collection( An ItemCollection. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collection_id, request) + await self.get_collection(collection_id, request=request) base_args = { "collections": [collection_id], @@ -285,7 +285,7 @@ async def item_collection( search_request = self.post_request_model( **clean, ) - item_collection = await self._search_base(search_request, request) + item_collection = await self._search_base(search_request, request=request) links = await ItemCollectionLinks( collection_id=collection_id, request=request ).get_links(extra_links=item_collection["links"]) @@ -307,12 +307,12 @@ async def get_item( Item. """ # If collection does not exist, NotFoundError wil be raised - await self.get_collection(collection_id, request) + await self.get_collection(collection_id, request=request) search_request = self.post_request_model( ids=[item_id], collections=[collection_id], limit=1 ) - item_collection = await self._search_base(search_request, request) + item_collection = await self._search_base(search_request, request=request) if not item_collection["features"]: raise NotFoundError( f"Item {item_id} in Collection {collection_id} does not exist." @@ -333,7 +333,7 @@ async def post_search( Returns: ItemCollection containing items which match the search criteria. """ - item_collection = await self._search_base(search_request, request) + item_collection = await self._search_base(search_request, request=request) return ItemCollection(**item_collection) async def get_search( diff --git a/tests/api/test_api.py b/tests/api/test_api.py index ae766eb..02f9505 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -1,10 +1,21 @@ from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar from urllib.parse import quote_plus import orjson import pytest +from fastapi import Request +from httpx import AsyncClient from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent +from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import create_post_request_model +from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension +from stac_fastapi.types import stac as stac_types + +from stac_fastapi.pgstac.core import CoreCrudClient, Settings +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db +from stac_fastapi.pgstac.transactions import TransactionsClient +from stac_fastapi.pgstac.types.search import PgstacSearch STAC_CORE_ROUTES = [ "GET /", @@ -622,3 +633,74 @@ async def search(query: Dict[str, Any]) -> List[Item]: } items = await search(query) assert len(items) == 10, items + + +@pytest.mark.asyncio +async def test_wrapped_function(load_test_data) -> None: + # Ensure wrappers, e.g. Planetary Computer's rate limiting, work. + # https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238 + + T = TypeVar("T") + + def wrap() -> ( + Callable[ + [Callable[..., Coroutine[Any, Any, T]]], + Callable[..., Coroutine[Any, Any, T]], + ] + ): + def decorator( + fn: Callable[..., Coroutine[Any, Any, T]] + ) -> Callable[..., Coroutine[Any, Any, T]]: + async def _wrapper(*args: Any, **kwargs: Any) -> T: + request: Optional[Request] = kwargs.get("request") + if request: + pass # This is where rate limiting would be applied + else: + raise ValueError(f"Missing request in {fn.__name__}") + return await fn(*args, **kwargs) + + return _wrapper + + return decorator + + class Client(CoreCrudClient): + @wrap() + async def get_collection( + self, collection_id: str, request: Request, **kwargs + ) -> stac_types.Item: + return await super().get_collection( + collection_id, request=request, **kwargs + ) + + settings = Settings(testing=True) + extensions = [ + TransactionExtension(client=TransactionsClient(), settings=settings), + FieldsExtension(), + ] + post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) + api = StacApi( + client=Client(post_request_model=post_request_model), + settings=settings, + extensions=extensions, + search_post_request_model=post_request_model, + ) + app = api.app + await connect_to_db(app) + try: + async with AsyncClient(app=app) as client: + response = await client.post( + "http://test/collections", + json=load_test_data("test_collection.json"), + ) + assert response.status_code == 200 + response = await client.post( + "http://test/collections/test-collection/items", + json=load_test_data("test_item.json"), + ) + assert response.status_code == 200 + response = await client.get( + "http://test/collections/test-collection/items/test-item" + ) + assert response.status_code == 200 + finally: + await close_db_connection(app)