Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical committed Jan 10, 2024
1 parent 0faa28b commit 43ea98a
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class Input(BaseModel):
wait_time: int = 0


class StatefulInput(BaseModel):
value: int


class Output(BaseModel):
result: int

Expand All @@ -36,6 +40,28 @@ def addition_app(input: Input) -> Output:
return Output(result=input.lhs + input.rhs)


class StatefulAdditionApp(fal.App, keep_alive=300, max_concurrency=1):
machine_type = "S"

def setup(self):
self.counter = 0

@fal.endpoint("/reset")
def reset(self) -> Output:
self.counter = 0
return Output(result=self.counter)

@fal.endpoint("/increment")
def increment(self, input: StatefulInput) -> Output:
self.counter += input.value
return Output(result=self.counter)

@fal.endpoint("/decrement")
def decrement(self, input: StatefulInput) -> Output:
self.counter -= input.value
return Output(result=self.counter)


@pytest.fixture(scope="module")
def aliased_app() -> Generator[tuple[str, str], None, None]:
# Create a temporary app, register it, and return the ID of it.
Expand Down Expand Up @@ -67,6 +93,21 @@ def test_app():
yield f"{user_id}-{app_revision}"


@pytest.fixture(scope="module")
def test_stateful_app():
# Create a temporary app, register it, and return the ID of it.

from fal.cli import _get_user_id

app = fal.wrap_app(StatefulAdditionApp)
app_revision = app.host.register(
func=app.func,
options=app.options,
)
user_id = _get_user_id()
yield f"{user_id}-{app_revision}"


def test_app_client(test_app: str):
response = apps.run(test_app, arguments={"lhs": 1, "rhs": 2})
assert response["result"] == 3
Expand All @@ -75,6 +116,23 @@ def test_app_client(test_app: str):
assert response["result"] == 5


def test_stateful_app_client(test_stateful_app: str):
response = apps.run(test_stateful_app, path="/reset")
assert response["result"] == 0

response = apps.run(test_stateful_app, arguments={"value": 1}, path="/increment")
assert response["result"] == 1

response = apps.run(test_stateful_app, arguments={"value": 2}, path="/increment")
assert response["result"] == 3

response = apps.run(test_stateful_app, arguments={"value": 1}, path="/decrement")
assert response["result"] == 2

response = apps.run(test_stateful_app, arguments={"value": 2}, path="/decrement")
assert response["result"] == 0


def test_app_client_async(test_app: str):
request_handle = apps.submit(test_app, arguments={"lhs": 1, "rhs": 2})
assert request_handle.get() == {"result": 3}
Expand Down

0 comments on commit 43ea98a

Please sign in to comment.