Skip to content

Commit

Permalink
Pass request by name into methods (#44)
Browse files Browse the repository at this point in the history
* fix: pass request by name into methods

This makes #22 less breaking.

* chore: update changelog
  • Loading branch information
gadomski authored Jun 20, 2023
1 parent cbe55e0 commit 659e46f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Fixed

- Pass `request` by name when calling endpoints from other endpoints ([#44](https://github.com/stac-utils/stac-fastapi-pgstac/pull/44))

## [2.4.8] - 2023-06-08

### Changed
Expand Down
12 changes: 6 additions & 6 deletions stac_fastapi/pgstac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def _add_item_links(
if settings.use_api_hydrate:

async def _get_base_item(collection_id: str) -> Dict[str, Any]:
return await self._get_base_item(collection_id, request)
return await self._get_base_item(collection_id, request=request)

base_item_cache = settings.base_item_cache(
fetch_base_item=_get_base_item, request=request
Expand Down Expand Up @@ -267,7 +267,7 @@ async def item_collection(
An ItemCollection.
"""
# If collection does not exist, NotFoundError wil be raised
await self.get_collection(collection_id, request)
await self.get_collection(collection_id, request=request)

base_args = {
"collections": [collection_id],
Expand All @@ -285,7 +285,7 @@ async def item_collection(
search_request = self.post_request_model(
**clean,
)
item_collection = await self._search_base(search_request, request)
item_collection = await self._search_base(search_request, request=request)
links = await ItemCollectionLinks(
collection_id=collection_id, request=request
).get_links(extra_links=item_collection["links"])
Expand All @@ -307,12 +307,12 @@ async def get_item(
Item.
"""
# If collection does not exist, NotFoundError wil be raised
await self.get_collection(collection_id, request)
await self.get_collection(collection_id, request=request)

search_request = self.post_request_model(
ids=[item_id], collections=[collection_id], limit=1
)
item_collection = await self._search_base(search_request, request)
item_collection = await self._search_base(search_request, request=request)
if not item_collection["features"]:
raise NotFoundError(
f"Item {item_id} in Collection {collection_id} does not exist."
Expand All @@ -333,7 +333,7 @@ async def post_search(
Returns:
ItemCollection containing items which match the search criteria.
"""
item_collection = await self._search_base(search_request, request)
item_collection = await self._search_base(search_request, request=request)
return ItemCollection(**item_collection)

async def get_search(
Expand Down
84 changes: 83 additions & 1 deletion tests/api/test_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List
from typing import Any, Callable, Coroutine, Dict, List, Optional, TypeVar
from urllib.parse import quote_plus

import orjson
import pytest
from fastapi import Request
from httpx import AsyncClient
from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent
from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import create_post_request_model
from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension
from stac_fastapi.types import stac as stac_types

from stac_fastapi.pgstac.core import CoreCrudClient, Settings
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.transactions import TransactionsClient
from stac_fastapi.pgstac.types.search import PgstacSearch

STAC_CORE_ROUTES = [
"GET /",
Expand Down Expand Up @@ -622,3 +633,74 @@ async def search(query: Dict[str, Any]) -> List[Item]:
}
items = await search(query)
assert len(items) == 10, items


@pytest.mark.asyncio
async def test_wrapped_function(load_test_data) -> None:
# Ensure wrappers, e.g. Planetary Computer's rate limiting, work.
# https://github.com/gadomski/planetary-computer-apis/blob/2719ccf6ead3e06de0784c39a2918d4d1811368b/pccommon/pccommon/redis.py#L205-L238

T = TypeVar("T")

def wrap() -> (
Callable[
[Callable[..., Coroutine[Any, Any, T]]],
Callable[..., Coroutine[Any, Any, T]],
]
):
def decorator(
fn: Callable[..., Coroutine[Any, Any, T]]
) -> Callable[..., Coroutine[Any, Any, T]]:
async def _wrapper(*args: Any, **kwargs: Any) -> T:
request: Optional[Request] = kwargs.get("request")
if request:
pass # This is where rate limiting would be applied
else:
raise ValueError(f"Missing request in {fn.__name__}")
return await fn(*args, **kwargs)

return _wrapper

return decorator

class Client(CoreCrudClient):
@wrap()
async def get_collection(
self, collection_id: str, request: Request, **kwargs
) -> stac_types.Item:
return await super().get_collection(
collection_id, request=request, **kwargs
)

settings = Settings(testing=True)
extensions = [
TransactionExtension(client=TransactionsClient(), settings=settings),
FieldsExtension(),
]
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
api = StacApi(
client=Client(post_request_model=post_request_model),
settings=settings,
extensions=extensions,
search_post_request_model=post_request_model,
)
app = api.app
await connect_to_db(app)
try:
async with AsyncClient(app=app) as client:
response = await client.post(
"http://test/collections",
json=load_test_data("test_collection.json"),
)
assert response.status_code == 200
response = await client.post(
"http://test/collections/test-collection/items",
json=load_test_data("test_item.json"),
)
assert response.status_code == 200
response = await client.get(
"http://test/collections/test-collection/items/test-item"
)
assert response.status_code == 200
finally:
await close_db_connection(app)

0 comments on commit 659e46f

Please sign in to comment.