Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fal.App for multiple endpoints #27

Merged
merged 3 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/fal/src/fal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fal.api import FalServerlessHost, LocalHost, cached
from fal.api import function
from fal.api import function as isolated
from fal.app import App, endpoint, wrap_app
from fal.sdk import FalServerlessKeyCredentials
from fal.sync import sync_dir

Expand Down
162 changes: 162 additions & 0 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from __future__ import annotations

import inspect
import os
import fal.api
from fal.toolkit import mainify
from fastapi import FastAPI
from typing import Any, NamedTuple, Callable, TypeVar, ClassVar
from fal.logging import get_logger

EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
logger = get_logger(__name__)


def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
def initialize_and_serve():
app = cls()
app.serve()

try:
app = cls(_allow_init=True)
metadata = app.openapi()
except Exception as exc:
logger.warning("Failed to build OpenAPI specification for %s", cls.__name__)
metadata = {}

wrapper = fal.api.function(
"virtualenv",
requirements=cls.requirements,
machine_type=cls.machine_type,
**cls.host_kwargs,
**kwargs,
metadata=metadata,
serve=True,
)
return wrapper(initialize_and_serve).on(
serve=False,
exposed_port=8080,
)


@mainify
class RouteSignature(NamedTuple):
path: str


@mainify
class App:
requirements: ClassVar[list[str]] = []
machine_type: ClassVar[str] = "S"
host_kwargs: ClassVar[dict[str, Any]] = {}

def __init_subclass__(cls, **kwargs):
cls.host_kwargs = kwargs

if cls.__init__ is not App.__init__:
raise ValueError(
"App classes should not override __init__ directly. "
"Use setup() instead."
)

def __init__(self, *, _allow_init: bool = False):
if not _allow_init and not os.getenv("IS_ISOLATE_AGENT"):
raise NotImplementedError(
"Running apps through SDK is not implemented yet."
)

def setup(self):
"""Setup the application before serving."""

def serve(self) -> None:
import uvicorn

app = self._build_app()
self.setup()
uvicorn.run(app, host="0.0.0.0", port=8080)

def _build_app(self) -> FastAPI:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

_app = FastAPI()

_app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_headers=("*"),
allow_methods=("*"),
allow_origins=("*"),
)

routes: dict[RouteSignature, Callable[..., Any]] = {
signature: endpoint
for _, endpoint in inspect.getmembers(self, inspect.ismethod)
if (signature := getattr(endpoint, "route_signature", None))
}
if not routes:
raise ValueError("An application must have at least one route!")

for signature, endpoint in routes.items():
_app.add_api_route(
signature.path,
endpoint,
name=endpoint.__name__,
methods=["POST"],
)

return _app

def openapi(self) -> dict[str, Any]:
"""
Build the OpenAPI specification for the served function.
Attach needed metadata for a better integration to fal.
"""
app = self._build_app()
spec = app.openapi()
self._mark_order_openapi(spec)
return spec

def _mark_order_openapi(self, spec: dict[str, Any]):
"""
Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.

NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
"""

def mark_order(obj: dict[str, Any], key: str):
obj[f"x-fal-order-{key}"] = list(obj[key].keys())

mark_order(spec, "paths")

def order_schema_object(schema: dict[str, Any]):
"""
Mark the order of properties in the schema object.
They can have 'allOf', 'properties' or '$ref' key.
"""
if "allOf" in schema:
for sub_schema in schema["allOf"]:
order_schema_object(sub_schema)
if "properties" in schema:
mark_order(schema, "properties")

for key in spec["components"].get("schemas") or {}:
order_schema_object(spec["components"]["schemas"][key])

return spec
Comment on lines +78 to +146
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this whole thing already existed for serve functions, can we merge them?



@mainify
def endpoint(path: str) -> Callable[[EndpointT], EndpointT]:
"""Designate the decorated function as an application endpoint."""

def marker_fn(callable: EndpointT) -> EndpointT:
if hasattr(callable, "route_signature"):
raise ValueError(
f"Can't set multiple routes for the same function: {callable.__name__}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't handle it yet, but in the future I want to be able to support multiple endpoints for a single function by stacking @fal.endpoint. this is done to reserve that use case (instead of treating it as an override which we would need to break)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aha ok 👍

)

callable.route_signature = RouteSignature(path=path) # type: ignore
return callable

return marker_fn
40 changes: 33 additions & 7 deletions projects/fal/src/fal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import click
import fal.auth as auth
import grpc
import fal
from fal import api, sdk
from fal.console import console
from fal.exceptions import ApplicationExceptionHandler
Expand Down Expand Up @@ -244,6 +245,28 @@ def function_cli(ctx, host: str, port: str):
ctx.obj = api.FalServerlessHost(f"{host}:{port}")


def load_function_from(
host: api.FalServerlessHost,
file_path: str,
function_name: str,
) -> api.IsolatedFunction:
import runpy

module = runpy.run_path(file_path)
if function_name not in module:
raise api.FalServerlessError(f"Function '{function_name}' not found in module")

target = module[function_name]
if issubclass(target, fal.App):
target = fal.wrap_app(target, host=host)

if not isinstance(target, api.IsolatedFunction):
raise api.FalServerlessError(
f"Function '{function_name}' is not a fal.function or a fal.App"
)
return target


@function_cli.command("serve")
@click.option("--alias", default=None)
@click.option(
Expand All @@ -262,15 +285,9 @@ def register_application(
alias: str | None,
auth_mode: ALIAS_AUTH_TYPE,
):
import runpy

user_id = _get_user_id()

module = runpy.run_path(file_path)
if function_name not in module:
raise api.FalServerlessError(f"Function '{function_name}' not found in module")

isolated_function: api.IsolatedFunction = module[function_name]
isolated_function = load_function_from(host, file_path, function_name)
gateway_options = isolated_function.options.gateway
if "serve" not in gateway_options and "exposed_port" not in gateway_options:
raise api.FalServerlessError(
Expand Down Expand Up @@ -307,6 +324,15 @@ def register_application(
console.print(f"URL: https://{user_id}-{id}.{gateway_host}")


@function_cli.command("run")
@click.argument("file_path", required=True)
@click.argument("function_name", required=True)
@click.pass_obj
def run(host: api.FalServerlessHost, file_path: str, function_name: str):
isolated_function = load_function_from(host, file_path, function_name)
isolated_function()


@function_cli.command("logs")
@click.option("--lines", default=100)
@click.option("--url", default=None)
Expand Down
58 changes: 58 additions & 0 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class Input(BaseModel):
wait_time: int = 0


class StatefulInput(BaseModel):
value: int


class Output(BaseModel):
result: int

Expand Down Expand Up @@ -77,6 +81,28 @@ def subtract(input: Input) -> Output:
run(app, host="0.0.0.0", port=8080)


class StatefulAdditionApp(fal.App, keep_alive=300, max_concurrency=1):
machine_type = "S"

def setup(self):
self.counter = 0

@fal.endpoint("/reset")
def reset(self) -> Output:
self.counter = 0
return Output(result=self.counter)

@fal.endpoint("/increment")
def increment(self, input: StatefulInput) -> Output:
self.counter += input.value
return Output(result=self.counter)

@fal.endpoint("/decrement")
def decrement(self, input: StatefulInput) -> Output:
self.counter -= input.value
return Output(result=self.counter)


@pytest.fixture(scope="module")
def aliased_app() -> Generator[tuple[str, str], None, None]:
# Create a temporary app, register it, and return the ID of it.
Expand Down Expand Up @@ -122,6 +148,21 @@ def test_fastapi_app():
yield f"{user_id}-{app_revision}"


@pytest.fixture(scope="module")
def test_stateful_app():
# Create a temporary app, register it, and return the ID of it.

from fal.cli import _get_user_id

app = fal.wrap_app(StatefulAdditionApp)
app_revision = app.host.register(
func=app.func,
options=app.options,
)
user_id = _get_user_id()
yield f"{user_id}-{app_revision}"


def test_app_client(test_app: str):
response = apps.run(test_app, arguments={"lhs": 1, "rhs": 2})
assert response["result"] == 3
Expand All @@ -130,6 +171,23 @@ def test_app_client(test_app: str):
assert response["result"] == 5


def test_stateful_app_client(test_stateful_app: str):
response = apps.run(test_stateful_app, arguments={}, path="/reset")
assert response["result"] == 0

response = apps.run(test_stateful_app, arguments={"value": 1}, path="/increment")
assert response["result"] == 1

response = apps.run(test_stateful_app, arguments={"value": 2}, path="/increment")
assert response["result"] == 3

response = apps.run(test_stateful_app, arguments={"value": 1}, path="/decrement")
assert response["result"] == 2

response = apps.run(test_stateful_app, arguments={"value": 2}, path="/decrement")
assert response["result"] == 0


def test_app_client_async(test_app: str):
request_handle = apps.submit(test_app, arguments={"lhs": 1, "rhs": 2})
assert request_handle.get() == {"result": 3}
Expand Down
5 changes: 5 additions & 0 deletions projects/fal/tests/test_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@


def test_missing_dependencies_nested_server_error(isolated_client):
from fal import _serialization

_serialization._PACKAGES.clear()
_serialization._MODULES.clear()

@isolated_client()
def test1():
return "hello"
Expand Down