-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: fal.App
for multiple endpoints
#27
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
import os | ||
import fal.api | ||
from fal.toolkit import mainify | ||
from fastapi import FastAPI | ||
from typing import Any, NamedTuple, Callable, TypeVar, ClassVar | ||
from fal.logging import get_logger | ||
|
||
EndpointT = TypeVar("EndpointT", bound=Callable[..., Any]) | ||
logger = get_logger(__name__) | ||
|
||
|
||
def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction: | ||
def initialize_and_serve(): | ||
app = cls() | ||
app.serve() | ||
|
||
try: | ||
app = cls(_allow_init=True) | ||
metadata = app.openapi() | ||
except Exception as exc: | ||
logger.warning("Failed to build OpenAPI specification for %s", cls.__name__) | ||
metadata = {} | ||
|
||
wrapper = fal.api.function( | ||
"virtualenv", | ||
requirements=cls.requirements, | ||
machine_type=cls.machine_type, | ||
**cls.host_kwargs, | ||
**kwargs, | ||
metadata=metadata, | ||
serve=True, | ||
) | ||
return wrapper(initialize_and_serve).on( | ||
serve=False, | ||
exposed_port=8080, | ||
) | ||
|
||
|
||
@mainify | ||
class RouteSignature(NamedTuple): | ||
path: str | ||
|
||
|
||
@mainify | ||
class App: | ||
requirements: ClassVar[list[str]] = [] | ||
machine_type: ClassVar[str] = "S" | ||
host_kwargs: ClassVar[dict[str, Any]] = {} | ||
|
||
def __init_subclass__(cls, **kwargs): | ||
cls.host_kwargs = kwargs | ||
|
||
if cls.__init__ is not App.__init__: | ||
raise ValueError( | ||
"App classes should not override __init__ directly. " | ||
"Use setup() instead." | ||
) | ||
|
||
def __init__(self, *, _allow_init: bool = False): | ||
if not _allow_init and not os.getenv("IS_ISOLATE_AGENT"): | ||
raise NotImplementedError( | ||
"Running apps through SDK is not implemented yet." | ||
) | ||
|
||
def setup(self): | ||
"""Setup the application before serving.""" | ||
|
||
def serve(self) -> None: | ||
import uvicorn | ||
|
||
app = self._build_app() | ||
self.setup() | ||
uvicorn.run(app, host="0.0.0.0", port=8080) | ||
|
||
def _build_app(self) -> FastAPI: | ||
from fastapi import FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
|
||
_app = FastAPI() | ||
|
||
_app.add_middleware( | ||
CORSMiddleware, | ||
allow_credentials=True, | ||
allow_headers=("*"), | ||
allow_methods=("*"), | ||
allow_origins=("*"), | ||
) | ||
|
||
routes: dict[RouteSignature, Callable[..., Any]] = { | ||
signature: endpoint | ||
for _, endpoint in inspect.getmembers(self, inspect.ismethod) | ||
if (signature := getattr(endpoint, "route_signature", None)) | ||
} | ||
if not routes: | ||
raise ValueError("An application must have at least one route!") | ||
|
||
for signature, endpoint in routes.items(): | ||
_app.add_api_route( | ||
signature.path, | ||
endpoint, | ||
name=endpoint.__name__, | ||
methods=["POST"], | ||
) | ||
|
||
return _app | ||
|
||
def openapi(self) -> dict[str, Any]: | ||
""" | ||
Build the OpenAPI specification for the served function. | ||
Attach needed metadata for a better integration to fal. | ||
""" | ||
app = self._build_app() | ||
spec = app.openapi() | ||
self._mark_order_openapi(spec) | ||
return spec | ||
|
||
def _mark_order_openapi(self, spec: dict[str, Any]): | ||
""" | ||
Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI. | ||
|
||
NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties. | ||
""" | ||
|
||
def mark_order(obj: dict[str, Any], key: str): | ||
obj[f"x-fal-order-{key}"] = list(obj[key].keys()) | ||
|
||
mark_order(spec, "paths") | ||
|
||
def order_schema_object(schema: dict[str, Any]): | ||
""" | ||
Mark the order of properties in the schema object. | ||
They can have 'allOf', 'properties' or '$ref' key. | ||
""" | ||
if "allOf" in schema: | ||
for sub_schema in schema["allOf"]: | ||
order_schema_object(sub_schema) | ||
if "properties" in schema: | ||
mark_order(schema, "properties") | ||
|
||
for key in spec["components"].get("schemas") or {}: | ||
order_schema_object(spec["components"]["schemas"][key]) | ||
|
||
return spec | ||
|
||
|
||
@mainify | ||
def endpoint(path: str) -> Callable[[EndpointT], EndpointT]: | ||
"""Designate the decorated function as an application endpoint.""" | ||
|
||
def marker_fn(callable: EndpointT) -> EndpointT: | ||
if hasattr(callable, "route_signature"): | ||
raise ValueError( | ||
f"Can't set multiple routes for the same function: {callable.__name__}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't handle it yet, but in the future I want to be able to support multiple endpoints for a single function by stacking There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. aha ok 👍 |
||
) | ||
|
||
callable.route_signature = RouteSignature(path=path) # type: ignore | ||
return callable | ||
|
||
return marker_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this whole thing already existed for serve functions, can we merge them?