Skip to content

Commit

Permalink
feat(fal): introduce startup_timeout (#408)
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop authored Feb 13, 2025
1 parent fca3541 commit e541561
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion projects/fal/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"isolate[build]>=0.15.0,<0.16.0",
"isolate-proto>=0.6.0,<0.7.0",
"isolate-proto>=0.6.4,<0.7.0",
"grpcio==1.64.0",
"dill==0.3.7",
"cloudpickle==3.0.0",
Expand Down
12 changes: 11 additions & 1 deletion projects/fal/src/fal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ class FalServerlessHost(Host):
"setup_function",
"metadata",
"request_timeout",
"startup_timeout",
"_base_image",
"_scheduler",
"_scheduler_options",
Expand Down Expand Up @@ -452,6 +453,7 @@ def register(
max_multiplexing = options.host.get("max_multiplexing")
exposed_port = options.get_exposed_port()
request_timeout = options.host.get("request_timeout")
startup_timeout = options.host.get("startup_timeout")
machine_requirements = MachineRequirements(
machine_types=machine_type, # type: ignore
num_gpus=options.host.get("num_gpus"),
Expand All @@ -464,6 +466,7 @@ def register(
max_concurrency=max_concurrency,
min_concurrency=min_concurrency,
request_timeout=request_timeout,
startup_timeout=startup_timeout,
)

partial_func = _prepare_partial_func(func)
Expand Down Expand Up @@ -526,7 +529,7 @@ def _run(
exposed_port = options.get_exposed_port()
setup_function = options.host.get("setup_function", None)
request_timeout = options.host.get("request_timeout")

startup_timeout = options.host.get("startup_timeout")
machine_requirements = MachineRequirements(
machine_types=machine_type, # type: ignore
num_gpus=options.host.get("num_gpus"),
Expand All @@ -539,6 +542,7 @@ def _run(
max_concurrency=max_concurrency,
min_concurrency=min_concurrency,
request_timeout=request_timeout,
startup_timeout=startup_timeout,
)

return_value = _UNSET
Expand Down Expand Up @@ -705,6 +709,7 @@ def function(
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
request_timeout: int | None = None,
startup_timeout: int | None = None,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
Expand Down Expand Up @@ -732,6 +737,7 @@ def function(
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
request_timeout: int | None = None,
startup_timeout: int | None = None,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
Expand Down Expand Up @@ -809,6 +815,7 @@ def function(
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
request_timeout: int | None = None,
startup_timeout: int | None = None,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
Expand Down Expand Up @@ -841,6 +848,7 @@ def function(
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
request_timeout: int | None = None,
startup_timeout: int | None = None,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
Expand All @@ -867,6 +875,7 @@ def function(
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
request_timeout: int | None = None,
startup_timeout: int | None = None,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
Expand All @@ -893,6 +902,7 @@ def function(
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
request_timeout: int | None = None,
startup_timeout: int | None = None,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
Expand Down
4 changes: 4 additions & 0 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ class App(fal.api.BaseServable):
app_name: ClassVar[str]
app_auth: ClassVar[Literal["private", "public", "shared"]] = "private"
request_timeout: ClassVar[int | None] = None
startup_timeout: ClassVar[int | None] = None

isolate_channel: async_grpc.Channel | None = None

Expand All @@ -282,6 +283,9 @@ def __init_subclass__(cls, **kwargs):
if cls.request_timeout is not None:
cls.host_kwargs["request_timeout"] = cls.request_timeout

if cls.startup_timeout is not None:
cls.host_kwargs["startup_timeout"] = cls.startup_timeout

cls.app_name = getattr(cls, "app_name", app_name)

if cls.__init__ is not App.__init__:
Expand Down
3 changes: 3 additions & 0 deletions projects/fal/src/fal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ class MachineRequirements:
max_multiplexing: int | None = None
min_concurrency: int | None = None
request_timeout: int | None = None
startup_timeout: int | None = None

def __post_init__(self):
if isinstance(self.machine_types, str):
Expand Down Expand Up @@ -519,6 +520,7 @@ def register(
min_concurrency=machine_requirements.min_concurrency,
max_multiplexing=machine_requirements.max_multiplexing,
request_timeout=machine_requirements.request_timeout,
startup_timeout=machine_requirements.startup_timeout,
)
else:
wrapped_requirements = None
Expand Down Expand Up @@ -614,6 +616,7 @@ def run(
max_multiplexing=machine_requirements.max_multiplexing,
min_concurrency=machine_requirements.min_concurrency,
request_timeout=machine_requirements.request_timeout,
startup_timeout=machine_requirements.startup_timeout,
)
else:
wrapped_requirements = None
Expand Down

0 comments on commit e541561

Please sign in to comment.