Skip to content

Commit

Permalink
wip: feat: fal.App for multiple endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical committed Jan 9, 2024
1 parent d943e76 commit 6724ce2
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 7 deletions.
1 change: 1 addition & 0 deletions projects/fal/src/fal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fal.api import FalServerlessHost, LocalHost, cached
from fal.api import function
from fal.api import function as isolated
from fal.app import App, endpoint, wrap_app
from fal.sdk import FalServerlessKeyCredentials
from fal.sync import sync_dir

Expand Down
147 changes: 147 additions & 0 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

import inspect
import os
import fal.api
from fal.toolkit import mainify
from fastapi import FastAPI
from typing import Any, NamedTuple, Callable, TypeVar, NoReturn, Type, ClassVar


EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])


def wrap_app(cls: Type[App], **kwargs) -> fal.api.IsolatedFunction:
def initialize_and_serve():
app = cls()
app.serve()

wrapper = fal.api.function(
"virtualenv",
requirements=cls.requirements,
machine_type=cls.machine_type,
**cls.host_kwargs,
**kwargs,
serve=True,
)
return wrapper(initialize_and_serve).on(
serve=False,
exposed_port=8080,
)


@mainify
class RouteSignature(NamedTuple):
path: str


@mainify
class App:
requirements: ClassVar[list[str]] = []
machine_type: ClassVar[str] = "S"
host_kwargs: ClassVar[dict[str, Any]] = {}

def __init_subclass__(cls, **kwargs):
cls.host_kwargs = kwargs

def __init__(self):
if not os.getenv("IS_ISOLATE_AGENT"):
raise NotImplementedError(
"Running apps through SDK is not implemented yet."
)

def setup(self):
"""Setup the application before serving."""

def serve(self) -> NoReturn:
import uvicorn

app = self._build_app()
self.setup()
uvicorn.run(app, host="0.0.0.0", port=8080)

def _build_app(self) -> FastAPI:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

_app = FastAPI()

_app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_headers=("*"),
allow_methods=("*"),
allow_origins=("*"),
)

routes: dict[RouteSignature, Callable[..., Any]] = {
signature: endpoint
for _, endpoint in inspect.getmembers(self, inspect.ismethod)
if (signature := getattr(endpoint, "route_signature", None))
}
if not routes:
raise ValueError("An application must have at least one route!")

for signature, endpoint in routes.items():
_app.add_api_route(
signature.path,
endpoint,
name=endpoint.__name__,
methods=["POST"],
)

return _app

def openapi(self) -> dict[str, Any]:
"""
Build the OpenAPI specification for the served function.
Attach needed metadata for a better integration to fal.
"""
app = self._build_app()
spec = app.openapi()
self._mark_order_openapi(spec)
return spec

def _mark_order_openapi(self, spec: dict[str, Any]):
"""
Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.
NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
"""

def mark_order(obj: dict[str, Any], key: str):
obj[f"x-fal-order-{key}"] = list(obj[key].keys())

mark_order(spec, "paths")

def order_schema_object(schema: dict[str, Any]):
"""
Mark the order of properties in the schema object.
They can have 'allOf', 'properties' or '$ref' key.
"""
if "allOf" in schema:
for sub_schema in schema["allOf"]:
order_schema_object(sub_schema)
if "properties" in schema:
mark_order(schema, "properties")

for key in spec["components"].get("schemas") or {}:
order_schema_object(spec["components"]["schemas"][key])

return spec


@mainify
def endpoint(path: str) -> Callable[[EndpointT], EndpointT]:
"""Designate the decorated function as an application endpoint."""

def marker_fn(callable: EndpointT) -> EndpointT:
if hasattr(callable, "route_signature"):
raise ValueError(
f"Can't set multiple routes for the same function: {callable.__name__}"
)

callable.route_signature = RouteSignature(path=path)
return callable

return marker_fn
40 changes: 33 additions & 7 deletions projects/fal/src/fal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import click
import fal.auth as auth
import grpc
import fal
from fal import api, sdk
from fal.console import console
from fal.exceptions import ApplicationExceptionHandler
Expand Down Expand Up @@ -244,6 +245,28 @@ def function_cli(ctx, host: str, port: str):
ctx.obj = api.FalServerlessHost(f"{host}:{port}")


def load_function_from(
host: api.FalServerlessHost,
file_path: str,
function_name: str,
) -> api.IsolatedFunction:
import runpy

module = runpy.run_path(file_path)
if function_name not in module:
raise api.FalServerlessError(f"Function '{function_name}' not found in module")

target = module[function_name]
if issubclass(target, fal.App):
target = fal.wrap_app(target, host=host)

if not isinstance(target, api.IsolatedFunction):
raise api.FalServerlessError(
f"Function '{function_name}' is not a fal.function or a fal.App"
)
return target


@function_cli.command("serve")
@click.option("--alias", default=None)
@click.option(
Expand All @@ -262,15 +285,9 @@ def register_application(
alias: str | None,
auth_mode: ALIAS_AUTH_TYPE,
):
import runpy

user_id = _get_user_id()

module = runpy.run_path(file_path)
if function_name not in module:
raise api.FalServerlessError(f"Function '{function_name}' not found in module")

isolated_function: api.IsolatedFunction = module[function_name]
isolated_function = load_function_from(host, file_path, function_name)
gateway_options = isolated_function.options.gateway
if "serve" not in gateway_options and "exposed_port" not in gateway_options:
raise api.FalServerlessError(
Expand Down Expand Up @@ -307,6 +324,15 @@ def register_application(
console.print(f"URL: https://{user_id}-{id}.{gateway_host}")


@function_cli.command("run")
@click.argument("file_path", required=True)
@click.argument("function_name", required=True)
@click.pass_obj
def run(host: api.FalServerlessHost, file_path: str, function_name: str):
isolated_function = load_function_from(host, file_path, function_name)
isolated_function()


@function_cli.command("logs")
@click.option("--lines", default=100)
@click.option("--url", default=None)
Expand Down

0 comments on commit 6724ce2

Please sign in to comment.