From d55dd1583d7ced17db9895308557f45698529af1 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Wed, 10 Jan 2024 00:33:28 +0300 Subject: [PATCH 1/3] wip: feat: `fal.App` for multiple endpoints --- projects/fal/src/fal/__init__.py | 1 + projects/fal/src/fal/app.py | 147 +++++++++++++++++++++++++++++++ projects/fal/src/fal/cli.py | 40 +++++++-- 3 files changed, 181 insertions(+), 7 deletions(-) create mode 100644 projects/fal/src/fal/app.py diff --git a/projects/fal/src/fal/__init__.py b/projects/fal/src/fal/__init__.py index cfc5cc49..cdbf2412 100644 --- a/projects/fal/src/fal/__init__.py +++ b/projects/fal/src/fal/__init__.py @@ -6,6 +6,7 @@ from fal.api import FalServerlessHost, LocalHost, cached from fal.api import function from fal.api import function as isolated +from fal.app import App, endpoint, wrap_app from fal.sdk import FalServerlessKeyCredentials from fal.sync import sync_dir diff --git a/projects/fal/src/fal/app.py b/projects/fal/src/fal/app.py new file mode 100644 index 00000000..4852b524 --- /dev/null +++ b/projects/fal/src/fal/app.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import inspect +import os +import fal.api +from fal.toolkit import mainify +from fastapi import FastAPI +from typing import Any, NamedTuple, Callable, TypeVar, ClassVar + + +EndpointT = TypeVar("EndpointT", bound=Callable[..., Any]) + + +def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction: + def initialize_and_serve(): + app = cls() + app.serve() + + wrapper = fal.api.function( + "virtualenv", + requirements=cls.requirements, + machine_type=cls.machine_type, + **cls.host_kwargs, + **kwargs, + serve=True, + ) + return wrapper(initialize_and_serve).on( + serve=False, + exposed_port=8080, + ) + + +@mainify +class RouteSignature(NamedTuple): + path: str + + +@mainify +class App: + requirements: ClassVar[list[str]] = [] + machine_type: ClassVar[str] = "S" + host_kwargs: ClassVar[dict[str, Any]] = {} + + def __init_subclass__(cls, **kwargs): + cls.host_kwargs = kwargs + + def __init__(self): + if not os.getenv("IS_ISOLATE_AGENT"): + raise NotImplementedError( + "Running apps through SDK is not implemented yet." + ) + + def setup(self): + """Setup the application before serving.""" + + def serve(self) -> None: + import uvicorn + + app = self._build_app() + self.setup() + uvicorn.run(app, host="0.0.0.0", port=8080) + + def _build_app(self) -> FastAPI: + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + + _app = FastAPI() + + _app.add_middleware( + CORSMiddleware, + allow_credentials=True, + allow_headers=("*"), + allow_methods=("*"), + allow_origins=("*"), + ) + + routes: dict[RouteSignature, Callable[..., Any]] = { + signature: endpoint + for _, endpoint in inspect.getmembers(self, inspect.ismethod) + if (signature := getattr(endpoint, "route_signature", None)) + } + if not routes: + raise ValueError("An application must have at least one route!") + + for signature, endpoint in routes.items(): + _app.add_api_route( + signature.path, + endpoint, + name=endpoint.__name__, + methods=["POST"], + ) + + return _app + + def openapi(self) -> dict[str, Any]: + """ + Build the OpenAPI specification for the served function. + Attach needed metadata for a better integration to fal. + """ + app = self._build_app() + spec = app.openapi() + self._mark_order_openapi(spec) + return spec + + def _mark_order_openapi(self, spec: dict[str, Any]): + """ + Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI. + + NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties. + """ + + def mark_order(obj: dict[str, Any], key: str): + obj[f"x-fal-order-{key}"] = list(obj[key].keys()) + + mark_order(spec, "paths") + + def order_schema_object(schema: dict[str, Any]): + """ + Mark the order of properties in the schema object. + They can have 'allOf', 'properties' or '$ref' key. + """ + if "allOf" in schema: + for sub_schema in schema["allOf"]: + order_schema_object(sub_schema) + if "properties" in schema: + mark_order(schema, "properties") + + for key in spec["components"].get("schemas") or {}: + order_schema_object(spec["components"]["schemas"][key]) + + return spec + + +@mainify +def endpoint(path: str) -> Callable[[EndpointT], EndpointT]: + """Designate the decorated function as an application endpoint.""" + + def marker_fn(callable: EndpointT) -> EndpointT: + if hasattr(callable, "route_signature"): + raise ValueError( + f"Can't set multiple routes for the same function: {callable.__name__}" + ) + + callable.route_signature = RouteSignature(path=path) # type: ignore + return callable + + return marker_fn diff --git a/projects/fal/src/fal/cli.py b/projects/fal/src/fal/cli.py index 0ec2feb7..e61cd358 100644 --- a/projects/fal/src/fal/cli.py +++ b/projects/fal/src/fal/cli.py @@ -10,6 +10,7 @@ import click import fal.auth as auth import grpc +import fal from fal import api, sdk from fal.console import console from fal.exceptions import ApplicationExceptionHandler @@ -244,6 +245,28 @@ def function_cli(ctx, host: str, port: str): ctx.obj = api.FalServerlessHost(f"{host}:{port}") +def load_function_from( + host: api.FalServerlessHost, + file_path: str, + function_name: str, +) -> api.IsolatedFunction: + import runpy + + module = runpy.run_path(file_path) + if function_name not in module: + raise api.FalServerlessError(f"Function '{function_name}' not found in module") + + target = module[function_name] + if issubclass(target, fal.App): + target = fal.wrap_app(target, host=host) + + if not isinstance(target, api.IsolatedFunction): + raise api.FalServerlessError( + f"Function '{function_name}' is not a fal.function or a fal.App" + ) + return target + + @function_cli.command("serve") @click.option("--alias", default=None) @click.option( @@ -262,15 +285,9 @@ def register_application( alias: str | None, auth_mode: ALIAS_AUTH_TYPE, ): - import runpy - user_id = _get_user_id() - module = runpy.run_path(file_path) - if function_name not in module: - raise api.FalServerlessError(f"Function '{function_name}' not found in module") - - isolated_function: api.IsolatedFunction = module[function_name] + isolated_function = load_function_from(host, file_path, function_name) gateway_options = isolated_function.options.gateway if "serve" not in gateway_options and "exposed_port" not in gateway_options: raise api.FalServerlessError( @@ -307,6 +324,15 @@ def register_application( console.print(f"URL: https://{user_id}-{id}.{gateway_host}") +@function_cli.command("run") +@click.argument("file_path", required=True) +@click.argument("function_name", required=True) +@click.pass_obj +def run(host: api.FalServerlessHost, file_path: str, function_name: str): + isolated_function = load_function_from(host, file_path, function_name) + isolated_function() + + @function_cli.command("logs") @click.option("--lines", default=100) @click.option("--url", default=None) From 1ed49e68ba515cc59245d8fd0d0b54912bc1f23d Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Wed, 10 Jan 2024 18:26:26 +0300 Subject: [PATCH 2/3] fix: build the metadata --- projects/fal/src/fal/app.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/projects/fal/src/fal/app.py b/projects/fal/src/fal/app.py index 4852b524..411cf286 100644 --- a/projects/fal/src/fal/app.py +++ b/projects/fal/src/fal/app.py @@ -6,9 +6,10 @@ from fal.toolkit import mainify from fastapi import FastAPI from typing import Any, NamedTuple, Callable, TypeVar, ClassVar - +from fal.logging import get_logger EndpointT = TypeVar("EndpointT", bound=Callable[..., Any]) +logger = get_logger(__name__) def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction: @@ -16,12 +17,20 @@ def initialize_and_serve(): app = cls() app.serve() + try: + app = cls(_allow_init=True) + metadata = app.openapi() + except Exception as exc: + logger.warning("Failed to build OpenAPI specification for %s", cls.__name__) + metadata = {} + wrapper = fal.api.function( "virtualenv", requirements=cls.requirements, machine_type=cls.machine_type, **cls.host_kwargs, **kwargs, + metadata=metadata, serve=True, ) return wrapper(initialize_and_serve).on( @@ -44,8 +53,14 @@ class App: def __init_subclass__(cls, **kwargs): cls.host_kwargs = kwargs - def __init__(self): - if not os.getenv("IS_ISOLATE_AGENT"): + if cls.__init__ is not App.__init__: + raise ValueError( + "App classes should not override __init__ directly. " + "Use setup() instead." + ) + + def __init__(self, *, _allow_init: bool = False): + if not _allow_init and not os.getenv("IS_ISOLATE_AGENT"): raise NotImplementedError( "Running apps through SDK is not implemented yet." ) From 6340ecbcaffebb00d0c4fbbdb62ce26537c898b6 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Wed, 10 Jan 2024 18:33:37 +0300 Subject: [PATCH 3/3] add tests --- projects/fal/tests/test_apps.py | 58 ++++++++++++++++++++++++++++ projects/fal/tests/test_stability.py | 5 +++ 2 files changed, 63 insertions(+) diff --git a/projects/fal/tests/test_apps.py b/projects/fal/tests/test_apps.py index 3424a52a..3fb4155c 100644 --- a/projects/fal/tests/test_apps.py +++ b/projects/fal/tests/test_apps.py @@ -17,6 +17,10 @@ class Input(BaseModel): wait_time: int = 0 +class StatefulInput(BaseModel): + value: int + + class Output(BaseModel): result: int @@ -77,6 +81,28 @@ def subtract(input: Input) -> Output: run(app, host="0.0.0.0", port=8080) +class StatefulAdditionApp(fal.App, keep_alive=300, max_concurrency=1): + machine_type = "S" + + def setup(self): + self.counter = 0 + + @fal.endpoint("/reset") + def reset(self) -> Output: + self.counter = 0 + return Output(result=self.counter) + + @fal.endpoint("/increment") + def increment(self, input: StatefulInput) -> Output: + self.counter += input.value + return Output(result=self.counter) + + @fal.endpoint("/decrement") + def decrement(self, input: StatefulInput) -> Output: + self.counter -= input.value + return Output(result=self.counter) + + @pytest.fixture(scope="module") def aliased_app() -> Generator[tuple[str, str], None, None]: # Create a temporary app, register it, and return the ID of it. @@ -122,6 +148,21 @@ def test_fastapi_app(): yield f"{user_id}-{app_revision}" +@pytest.fixture(scope="module") +def test_stateful_app(): + # Create a temporary app, register it, and return the ID of it. + + from fal.cli import _get_user_id + + app = fal.wrap_app(StatefulAdditionApp) + app_revision = app.host.register( + func=app.func, + options=app.options, + ) + user_id = _get_user_id() + yield f"{user_id}-{app_revision}" + + def test_app_client(test_app: str): response = apps.run(test_app, arguments={"lhs": 1, "rhs": 2}) assert response["result"] == 3 @@ -130,6 +171,23 @@ def test_app_client(test_app: str): assert response["result"] == 5 +def test_stateful_app_client(test_stateful_app: str): + response = apps.run(test_stateful_app, arguments={}, path="/reset") + assert response["result"] == 0 + + response = apps.run(test_stateful_app, arguments={"value": 1}, path="/increment") + assert response["result"] == 1 + + response = apps.run(test_stateful_app, arguments={"value": 2}, path="/increment") + assert response["result"] == 3 + + response = apps.run(test_stateful_app, arguments={"value": 1}, path="/decrement") + assert response["result"] == 2 + + response = apps.run(test_stateful_app, arguments={"value": 2}, path="/decrement") + assert response["result"] == 0 + + def test_app_client_async(test_app: str): request_handle = apps.submit(test_app, arguments={"lhs": 1, "rhs": 2}) assert request_handle.get() == {"result": 3} diff --git a/projects/fal/tests/test_stability.py b/projects/fal/tests/test_stability.py index 47100a53..c023e530 100644 --- a/projects/fal/tests/test_stability.py +++ b/projects/fal/tests/test_stability.py @@ -11,6 +11,11 @@ def test_missing_dependencies_nested_server_error(isolated_client): + from fal import _serialization + + _serialization._PACKAGES.clear() + _serialization._MODULES.clear() + @isolated_client() def test1(): return "hello"