Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove create_sync_endpoint, run sync functions in background thread #471

Merged
merged 3 commits into from
Oct 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 10 additions & 30 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""fastapi app creation."""
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import attr
from brotli_asgi import BrotliMiddleware
from fastapi import APIRouter, FastAPI
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from pydantic import BaseModel
from stac_pydantic import Collection, Item, ItemCollection
from stac_pydantic.api import ConformanceClasses, LandingPage
from stac_pydantic.api.collections import Collections
Expand All @@ -16,7 +15,6 @@
from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware
from stac_fastapi.api.models import (
APIRequest,
CollectionUri,
EmptyRequest,
GeoJSONResponse,
Expand All @@ -25,12 +23,7 @@
create_request_model,
)
from stac_fastapi.api.openapi import update_openapi
from stac_fastapi.api.routes import (
Scope,
add_route_dependencies,
create_async_endpoint,
create_sync_endpoint,
)
from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint

# TODO: make this module not depend on `stac_fastapi.extensions`
from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension
Expand Down Expand Up @@ -113,19 +106,6 @@ def get_extension(self, extension: Type[ApiExtension]) -> Optional[ApiExtension]
return ext
return None

def _create_endpoint(
self,
func: Callable,
request_type: Union[Type[APIRequest], Type[BaseModel]],
resp_class: Type[Response],
) -> Callable:
"""Create a FastAPI endpoint."""
if isinstance(self.client, AsyncBaseCoreClient):
return create_async_endpoint(func, request_type, response_class=resp_class)
elif isinstance(self.client, BaseCoreClient):
return create_sync_endpoint(func, request_type, response_class=resp_class)
raise NotImplementedError

def register_landing_page(self):
"""Register landing page (GET /).

Expand All @@ -142,7 +122,7 @@ def register_landing_page(self):
response_model_exclude_unset=False,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.landing_page, EmptyRequest, self.response_class
),
)
Expand All @@ -163,7 +143,7 @@ def register_conformance_classes(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.conformance, EmptyRequest, self.response_class
),
)
Expand All @@ -182,7 +162,7 @@ def register_get_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.get_item, ItemUri, self.response_class
),
)
Expand All @@ -204,7 +184,7 @@ def register_post_search(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["POST"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.post_search, self.search_post_request_model, GeoJSONResponse
),
)
Expand All @@ -226,7 +206,7 @@ def register_get_search(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.get_search, self.search_get_request_model, GeoJSONResponse
),
)
Expand All @@ -247,7 +227,7 @@ def register_get_collections(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.all_collections, EmptyRequest, self.response_class
),
)
Expand All @@ -266,7 +246,7 @@ def register_get_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.get_collection, CollectionUri, self.response_class
),
)
Expand Down Expand Up @@ -297,7 +277,7 @@ def register_get_item_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["GET"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.item_collection, request_model, self.response_class
),
)
Expand Down
59 changes: 20 additions & 39 deletions stac_fastapi/api/stac_fastapi/api/routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""route factories."""
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, Union

from fastapi import Depends, params
from fastapi.dependencies.utils import get_parameterless_sub_dependant
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import BaseRoute, Match
Expand All @@ -21,12 +24,28 @@ def _wrap_response(resp: Any, response_class: Type[Response]) -> Response:
return Response(status_code=HTTP_204_NO_CONTENT)


def sync_to_async(func):
"""Run synchronous function asynchronously in a background thread."""

@functools.wraps(func)
async def run(*args, **kwargs):
return await run_in_threadpool(func, *args, **kwargs)

return run


def create_async_endpoint(
func: Callable,
request_model: Union[Type[APIRequest], Type[BaseModel], Dict],
response_class: Type[Response] = JSONResponse,
):
"""Wrap a coroutine in another coroutine which may be used to create a FastAPI endpoint."""
"""Wrap a function in a coroutine which may be used to create a FastAPI endpoint.

Synchronous functions are executed asynchronously using a background thread.
"""
if not inspect.iscoroutinefunction(func):
func = sync_to_async(func)

if issubclass(request_model, APIRequest):

async def _endpoint(
Expand Down Expand Up @@ -63,44 +82,6 @@ async def _endpoint(
return _endpoint


def create_sync_endpoint(
func: Callable,
request_model: Union[Type[APIRequest], Type[BaseModel], Dict],
response_class: Type[Response] = JSONResponse,
):
"""Wrap a function in another function which may be used to create a FastAPI endpoint."""
if issubclass(request_model, APIRequest):

def _endpoint(
request: Request,
request_data: request_model = Depends(), # type:ignore
):
"""Endpoint."""
return _wrap_response(
func(request=request, **request_data.kwargs()), response_class
)

elif issubclass(request_model, BaseModel):

def _endpoint(
request: Request,
request_data: request_model, # type:ignore
):
"""Endpoint."""
return _wrap_response(func(request_data, request=request), response_class)

else:

def _endpoint(
request: Request,
request_data: Dict[str, Any], # type:ignore
):
"""Endpoint."""
return _wrap_response(func(request_data, request=request), response_class)

return _endpoint


class Scope(TypedDict, total=False):
"""More strict version of Starlette's Scope."""

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
# encoding: utf-8
"""Filter Extension."""
from enum import Enum
from typing import Callable, List, Type, Union
from typing import List, Type, Union

import attr
from fastapi import APIRouter, FastAPI
from starlette.responses import Response

from stac_fastapi.api.models import (
APIRequest,
CollectionUri,
EmptyRequest,
JSONSchemaResponse,
)
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint
from stac_fastapi.api.models import CollectionUri, EmptyRequest, JSONSchemaResponse
from stac_fastapi.api.routes import create_async_endpoint
from stac_fastapi.types.core import AsyncBaseFiltersClient, BaseFiltersClient
from stac_fastapi.types.extension import ApiExtension

Expand Down Expand Up @@ -80,24 +75,6 @@ class FilterExtension(ApiExtension):
router: APIRouter = attr.ib(factory=APIRouter)
response_class: Type[Response] = attr.ib(default=JSONSchemaResponse)

def _create_endpoint(
self,
func: Callable,
request_type: Union[
Type[APIRequest],
],
) -> Callable:
"""Create a FastAPI endpoint."""
if isinstance(self.client, AsyncBaseFiltersClient):
return create_async_endpoint(
func, request_type, response_class=self.response_class
)
if isinstance(self.client, BaseFiltersClient):
return create_sync_endpoint(
func, request_type, response_class=self.response_class
)
raise NotImplementedError

def register(self, app: FastAPI) -> None:
"""Register the extension with a FastAPI application.

Expand All @@ -112,12 +89,16 @@ def register(self, app: FastAPI) -> None:
name="Queryables",
path="/queryables",
methods=["GET"],
endpoint=self._create_endpoint(self.client.get_queryables, EmptyRequest),
endpoint=create_async_endpoint(
self.client.get_queryables, EmptyRequest, self.response_class
),
)
self.router.add_api_route(
name="Collection Queryables",
path="/collections/{collection_id}/queryables",
methods=["GET"],
endpoint=self._create_endpoint(self.client.get_queryables, CollectionUri),
endpoint=create_async_endpoint(
self.client.get_queryables, CollectionUri, self.response_class
),
)
app.include_router(self.router, tags=["Filter Extension"])
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""transaction extension."""
from typing import Callable, List, Optional, Type, Union
from typing import List, Optional, Type, Union

import attr
from fastapi import APIRouter, Body, FastAPI
from pydantic import BaseModel
from stac_pydantic import Collection, Item
from starlette.responses import JSONResponse, Response

from stac_fastapi.api.models import APIRequest, CollectionUri, ItemUri
from stac_fastapi.api.routes import create_async_endpoint, create_sync_endpoint
from stac_fastapi.api.models import CollectionUri, ItemUri
from stac_fastapi.api.routes import create_async_endpoint
from stac_fastapi.types import stac as stac_types
from stac_fastapi.types.config import ApiSettings
from stac_fastapi.types.core import AsyncBaseTransactionsClient, BaseTransactionsClient
Expand Down Expand Up @@ -60,27 +59,6 @@ class TransactionExtension(ApiExtension):
router: APIRouter = attr.ib(factory=APIRouter)
response_class: Type[Response] = attr.ib(default=JSONResponse)

def _create_endpoint(
self,
func: Callable,
request_type: Union[
Type[APIRequest],
Type[BaseModel],
Type[stac_types.Item],
Type[stac_types.Collection],
],
) -> Callable:
"""Create a FastAPI endpoint."""
if isinstance(self.client, AsyncBaseTransactionsClient):
return create_async_endpoint(
func, request_type, response_class=self.response_class
)
elif isinstance(self.client, BaseTransactionsClient):
return create_sync_endpoint(
func, request_type, response_class=self.response_class
)
raise NotImplementedError

def register_create_item(self):
"""Register create item endpoint (POST /collections/{collection_id}/items)."""
self.router.add_api_route(
Expand All @@ -91,7 +69,7 @@ def register_create_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["POST"],
endpoint=self._create_endpoint(self.client.create_item, PostItem),
endpoint=create_async_endpoint(self.client.create_item, PostItem),
)

def register_update_item(self):
Expand All @@ -104,7 +82,7 @@ def register_update_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["PUT"],
endpoint=self._create_endpoint(self.client.update_item, PutItem),
endpoint=create_async_endpoint(self.client.update_item, PutItem),
)

def register_delete_item(self):
Expand All @@ -117,7 +95,7 @@ def register_delete_item(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["DELETE"],
endpoint=self._create_endpoint(self.client.delete_item, ItemUri),
endpoint=create_async_endpoint(self.client.delete_item, ItemUri),
)

def register_create_collection(self):
Expand All @@ -130,7 +108,7 @@ def register_create_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["POST"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.create_collection, stac_types.Collection
),
)
Expand All @@ -145,7 +123,7 @@ def register_update_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["PUT"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.update_collection, stac_types.Collection
),
)
Expand All @@ -160,7 +138,7 @@ def register_delete_collection(self):
response_model_exclude_unset=True,
response_model_exclude_none=True,
methods=["DELETE"],
endpoint=self._create_endpoint(
endpoint=create_async_endpoint(
self.client.delete_collection, CollectionUri
),
)
Expand Down
Loading