diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index d18844e5c..aff21be58 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -1,12 +1,11 @@ """fastapi app creation.""" -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import attr from brotli_asgi import BrotliMiddleware from fastapi import APIRouter, FastAPI from fastapi.openapi.utils import get_openapi from fastapi.params import Depends -from pydantic import BaseModel from stac_pydantic import Collection, Item, ItemCollection from stac_pydantic.api import ConformanceClasses, LandingPage from stac_pydantic.api.collections import Collections @@ -16,7 +15,6 @@ from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware from stac_fastapi.api.models import ( - APIRequest, CollectionUri, EmptyRequest, GeoJSONResponse, @@ -25,12 +23,7 @@ create_request_model, ) from stac_fastapi.api.openapi import update_openapi -from stac_fastapi.api.routes import ( - Scope, - add_route_dependencies, - create_async_endpoint, - create_sync_endpoint, -) +from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint # TODO: make this module not depend on `stac_fastapi.extensions` from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension @@ -113,19 +106,6 @@ def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension] return ext return None - def _create_endpoint( - self, - func: Callable, - request_type: Union[Type[APIRequest], Type[BaseModel]], - resp_class: Type[Response], - ) -> Callable: - """Create a FastAPI endpoint.""" - if isinstance(self.client, AsyncBaseCoreClient): - return create_async_endpoint(func, request_type, response_class=resp_class) - elif isinstance(self.client, BaseCoreClient): - return create_sync_endpoint(func, request_type, response_class=resp_class) - raise NotImplementedError - def register_landing_page(self): """Register landing page (GET /). @@ -142,7 +122,7 @@ def register_landing_page(self): response_model_exclude_unset=False, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.landing_page, EmptyRequest, self.response_class ), ) @@ -163,7 +143,7 @@ def register_conformance_classes(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.conformance, EmptyRequest, self.response_class ), ) @@ -182,7 +162,7 @@ def register_get_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.get_item, ItemUri, self.response_class ), ) @@ -204,7 +184,7 @@ def register_post_search(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.post_search, self.search_post_request_model, GeoJSONResponse ), ) @@ -226,7 +206,7 @@ def register_get_search(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.get_search, self.search_get_request_model, GeoJSONResponse ), ) @@ -247,7 +227,7 @@ def register_get_collections(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.all_collections, EmptyRequest, self.response_class ), ) @@ -266,7 +246,7 @@ def register_get_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.get_collection, CollectionUri, self.response_class ), ) @@ -297,7 +277,7 @@ def register_get_item_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.item_collection, request_model, self.response_class ), ) diff --git a/stac_fastapi/api/stac_fastapi/api/routes.py b/stac_fastapi/api/stac_fastapi/api/routes.py index 941f05b0b..3c0186564 100644 --- a/stac_fastapi/api/stac_fastapi/api/routes.py +++ b/stac_fastapi/api/stac_fastapi/api/routes.py @@ -1,9 +1,12 @@ """route factories.""" +import functools +import inspect from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, Union from fastapi import Depends, params from fastapi.dependencies.utils import get_parameterless_sub_dependant from pydantic import BaseModel +from starlette.concurrency import run_in_threadpool from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import BaseRoute, Match @@ -21,12 +24,28 @@ def _wrap_response(resp: Any, response_class: Type[Response]) -> Response: return Response(status_code=HTTP_204_NO_CONTENT) +def sync_to_async(func): + """Run synchronous function asynchronously in a background thread.""" + + @functools.wraps(func) + async def run(*args, **kwargs): + return await run_in_threadpool(func, *args, **kwargs) + + return run + + def create_async_endpoint( func: Callable, request_model: Union[Type[APIRequest], Type[BaseModel], Dict], response_class: Type[Response] = JSONResponse, ): - """Wrap a coroutine in another coroutine which may be used to create a FastAPI endpoint.""" + """Wrap a function in a coroutine which may be used to create a FastAPI endpoint. + + Synchronous functions are executed asynchronously using a background thread. + """ + if not inspect.iscoroutinefunction(func): + func = sync_to_async(func) + if issubclass(request_model, APIRequest): async def _endpoint( @@ -63,44 +82,6 @@ async def _endpoint( return _endpoint -def create_sync_endpoint( - func: Callable, - request_model: Union[Type[APIRequest], Type[BaseModel], Dict], - response_class: Type[Response] = JSONResponse, -): - """Wrap a function in another function which may be used to create a FastAPI endpoint.""" - if issubclass(request_model, APIRequest): - - def _endpoint( - request: Request, - request_data: request_model = Depends(), # type:ignore - ): - """Endpoint.""" - return _wrap_response( - func(request=request, **request_data.kwargs()), response_class - ) - - elif issubclass(request_model, BaseModel): - - def _endpoint( - request: Request, - request_data: request_model, # type:ignore - ): - """Endpoint.""" - return _wrap_response(func(request_data, request=request), response_class) - - else: - - def _endpoint( - request: Request, - request_data: Dict[str, Any], # type:ignore - ): - """Endpoint.""" - return _wrap_response(func(request_data, request=request), response_class) - - return _endpoint - - class Scope(TypedDict, total=False): """More strict version of Starlette's Scope.""" diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py index 0854c9f4f..b9eebb708 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/filter/filter.py @@ -1,19 +1,14 @@ # encoding: utf-8 """Filter Extension.""" from enum import Enum -from typing import Callable, List, Type, Union +from typing import List, Type, Union import attr from fastapi import APIRouter, FastAPI from starlette.responses import Response -from stac_fastapi.api.models import ( - APIRequest, - CollectionUri, - EmptyRequest, - JSONSchemaResponse, -) -from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint +from stac_fastapi.api.models import CollectionUri, EmptyRequest, JSONSchemaResponse +from stac_fastapi.api.routes import create_async_endpoint from stac_fastapi.types.core import AsyncBaseFiltersClient, BaseFiltersClient from stac_fastapi.types.extension import ApiExtension @@ -80,24 +75,6 @@ class FilterExtension(ApiExtension): router: APIRouter = attr.ib(factory=APIRouter) response_class: Type[Response] = attr.ib(default=JSONSchemaResponse) - def _create_endpoint( - self, - func: Callable, - request_type: Union[ - Type[APIRequest], - ], - ) -> Callable: - """Create a FastAPI endpoint.""" - if isinstance(self.client, AsyncBaseFiltersClient): - return create_async_endpoint( - func, request_type, response_class=self.response_class - ) - if isinstance(self.client, BaseFiltersClient): - return create_sync_endpoint( - func, request_type, response_class=self.response_class - ) - raise NotImplementedError - def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. @@ -112,12 +89,16 @@ def register(self, app: FastAPI) -> None: name="Queryables", path="/queryables", methods=["GET"], - endpoint=self._create_endpoint(self.client.get_queryables, EmptyRequest), + endpoint=create_async_endpoint( + self.client.get_queryables, EmptyRequest, self.response_class + ), ) self.router.add_api_route( name="Collection Queryables", path="/collections/{collection_id}/queryables", methods=["GET"], - endpoint=self._create_endpoint(self.client.get_queryables, CollectionUri), + endpoint=create_async_endpoint( + self.client.get_queryables, CollectionUri, self.response_class + ), ) app.include_router(self.router, tags=["Filter Extension"]) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py index 5967e7128..46c4568b3 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/transaction.py @@ -1,14 +1,13 @@ """transaction extension.""" -from typing import Callable, List, Optional, Type, Union +from typing import List, Optional, Type, Union import attr from fastapi import APIRouter, Body, FastAPI -from pydantic import BaseModel from stac_pydantic import Collection, Item from starlette.responses import JSONResponse, Response -from stac_fastapi.api.models import APIRequest, CollectionUri, ItemUri -from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint +from stac_fastapi.api.models import CollectionUri, ItemUri +from stac_fastapi.api.routes import create_async_endpoint from stac_fastapi.types import stac as stac_types from stac_fastapi.types.config import ApiSettings from stac_fastapi.types.core import AsyncBaseTransactionsClient, BaseTransactionsClient @@ -60,27 +59,6 @@ class TransactionExtension(ApiExtension): router: APIRouter = attr.ib(factory=APIRouter) response_class: Type[Response] = attr.ib(default=JSONResponse) - def _create_endpoint( - self, - func: Callable, - request_type: Union[ - Type[APIRequest], - Type[BaseModel], - Type[stac_types.Item], - Type[stac_types.Collection], - ], - ) -> Callable: - """Create a FastAPI endpoint.""" - if isinstance(self.client, AsyncBaseTransactionsClient): - return create_async_endpoint( - func, request_type, response_class=self.response_class - ) - elif isinstance(self.client, BaseTransactionsClient): - return create_sync_endpoint( - func, request_type, response_class=self.response_class - ) - raise NotImplementedError - def register_create_item(self): """Register create item endpoint (POST /collections/{collection_id}/items).""" self.router.add_api_route( @@ -91,7 +69,7 @@ def register_create_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], - endpoint=self._create_endpoint(self.client.create_item, PostItem), + endpoint=create_async_endpoint(self.client.create_item, PostItem), ) def register_update_item(self): @@ -104,7 +82,7 @@ def register_update_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["PUT"], - endpoint=self._create_endpoint(self.client.update_item, PutItem), + endpoint=create_async_endpoint(self.client.update_item, PutItem), ) def register_delete_item(self): @@ -117,7 +95,7 @@ def register_delete_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["DELETE"], - endpoint=self._create_endpoint(self.client.delete_item, ItemUri), + endpoint=create_async_endpoint(self.client.delete_item, ItemUri), ) def register_create_collection(self): @@ -130,7 +108,7 @@ def register_create_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.create_collection, stac_types.Collection ), ) @@ -145,7 +123,7 @@ def register_update_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["PUT"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.update_collection, stac_types.Collection ), ) @@ -160,7 +138,7 @@ def register_delete_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["DELETE"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.delete_collection, CollectionUri ), ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py index 3fe25c9d1..1bc104059 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/third_party/bulk_transactions.py @@ -1,15 +1,14 @@ """bulk transactions extension.""" import abc -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Union import attr from fastapi import APIRouter, FastAPI from pydantic import BaseModel from stac_fastapi.api.models import create_request_model -from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint +from stac_fastapi.api.routes import create_async_endpoint from stac_fastapi.types.extension import ApiExtension -from stac_fastapi.types.search import APIRequest class Items(BaseModel): @@ -93,18 +92,6 @@ class BulkTransactionExtension(ApiExtension): conformance_classes: List[str] = attr.ib(default=list()) schema_href: Optional[str] = attr.ib(default=None) - def _create_endpoint( - self, - func: Callable, - request_type: Union[Type[APIRequest], Type[BaseModel], Dict], - ) -> Callable: - """Create a FastAPI endpoint.""" - if isinstance(self.client, AsyncBaseBulkTransactionsClient): - return create_async_endpoint(func, request_type) - elif isinstance(self.client, BaseBulkTransactionsClient): - return create_sync_endpoint(func, request_type) - raise NotImplementedError - def register(self, app: FastAPI) -> None: """Register the extension with a FastAPI application. @@ -124,7 +111,7 @@ def register(self, app: FastAPI) -> None: response_model_exclude_unset=True, response_model_exclude_none=True, methods=["POST"], - endpoint=self._create_endpoint( + endpoint=create_async_endpoint( self.client.bulk_item_insert, items_request_model ), )