diff --git a/projects/fal/tests/test_apps.py b/projects/fal/tests/test_apps.py index 22f683c6..04448731 100644 --- a/projects/fal/tests/test_apps.py +++ b/projects/fal/tests/test_apps.py @@ -17,6 +17,10 @@ class Input(BaseModel): wait_time: int = 0 +class StatefulInput(BaseModel): + value: int + + class Output(BaseModel): result: int @@ -36,6 +40,28 @@ def addition_app(input: Input) -> Output: return Output(result=input.lhs + input.rhs) +class StatefulAdditionApp(fal.App, keep_alive=300, max_concurrency=1): + machine_type = "S" + + def setup(self): + self.counter = 0 + + @fal.endpoint("/reset") + def reset(self) -> Output: + self.counter = 0 + return Output(result=self.counter) + + @fal.endpoint("/increment") + def increment(self, input: StatefulInput) -> Output: + self.counter += input.value + return Output(result=self.counter) + + @fal.endpoint("/decrement") + def decrement(self, input: StatefulInput) -> Output: + self.counter -= input.value + return Output(result=self.counter) + + @pytest.fixture(scope="module") def aliased_app() -> Generator[tuple[str, str], None, None]: # Create a temporary app, register it, and return the ID of it. @@ -67,6 +93,21 @@ def test_app(): yield f"{user_id}-{app_revision}" +@pytest.fixture(scope="module") +def test_stateful_app(): + # Create a temporary app, register it, and return the ID of it. + + from fal.cli import _get_user_id + + app = fal.wrap_app(StatefulAdditionApp) + app_revision = app.host.register( + func=app.func, + options=app.options, + ) + user_id = _get_user_id() + yield f"{user_id}-{app_revision}" + + def test_app_client(test_app: str): response = apps.run(test_app, arguments={"lhs": 1, "rhs": 2}) assert response["result"] == 3 @@ -75,6 +116,23 @@ def test_app_client(test_app: str): assert response["result"] == 5 +def test_stateful_app_client(test_stateful_app: str): + response = apps.run(test_stateful_app, path="/reset") + assert response["result"] == 0 + + response = apps.run(test_stateful_app, arguments={"value": 1}, path="/increment") + assert response["result"] == 1 + + response = apps.run(test_stateful_app, arguments={"value": 2}, path="/increment") + assert response["result"] == 3 + + response = apps.run(test_stateful_app, arguments={"value": 1}, path="/decrement") + assert response["result"] == 2 + + response = apps.run(test_stateful_app, arguments={"value": 2}, path="/decrement") + assert response["result"] == 0 + + def test_app_client_async(test_app: str): request_handle = apps.submit(test_app, arguments={"lhs": 1, "rhs": 2}) assert request_handle.get() == {"result": 3}