diff --git a/projects/fal/pyproject.toml b/projects/fal/pyproject.toml index 6394f75b..d2c9cdf6 100644 --- a/projects/fal/pyproject.toml +++ b/projects/fal/pyproject.toml @@ -23,7 +23,7 @@ readme = "README.md" requires-python = ">=3.8" dependencies = [ "isolate[build]>=0.15.0,<0.16.0", - "isolate-proto>=0.6.0,<0.7.0", + "isolate-proto>=0.6.4,<0.7.0", "grpcio==1.64.0", "dill==0.3.7", "cloudpickle==3.0.0", diff --git a/projects/fal/src/fal/api.py b/projects/fal/src/fal/api.py index 5005bd4b..96d88ac1 100644 --- a/projects/fal/src/fal/api.py +++ b/projects/fal/src/fal/api.py @@ -403,6 +403,7 @@ class FalServerlessHost(Host): "setup_function", "metadata", "request_timeout", + "startup_timeout", "_base_image", "_scheduler", "_scheduler_options", @@ -452,6 +453,7 @@ def register( max_multiplexing = options.host.get("max_multiplexing") exposed_port = options.get_exposed_port() request_timeout = options.host.get("request_timeout") + startup_timeout = options.host.get("startup_timeout") machine_requirements = MachineRequirements( machine_types=machine_type, # type: ignore num_gpus=options.host.get("num_gpus"), @@ -464,6 +466,7 @@ def register( max_concurrency=max_concurrency, min_concurrency=min_concurrency, request_timeout=request_timeout, + startup_timeout=startup_timeout, ) partial_func = _prepare_partial_func(func) @@ -526,7 +529,7 @@ def _run( exposed_port = options.get_exposed_port() setup_function = options.host.get("setup_function", None) request_timeout = options.host.get("request_timeout") - + startup_timeout = options.host.get("startup_timeout") machine_requirements = MachineRequirements( machine_types=machine_type, # type: ignore num_gpus=options.host.get("num_gpus"), @@ -539,6 +542,7 @@ def _run( max_concurrency=max_concurrency, min_concurrency=min_concurrency, request_timeout=request_timeout, + startup_timeout=startup_timeout, ) return_value = _UNSET @@ -705,6 +709,7 @@ def function( max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING, min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY, request_timeout: int | None = None, + startup_timeout: int | None = None, setup_function: Callable[..., None] | None = None, _base_image: str | None = None, _scheduler: str | None = None, @@ -732,6 +737,7 @@ def function( max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING, min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY, request_timeout: int | None = None, + startup_timeout: int | None = None, setup_function: Callable[..., None] | None = None, _base_image: str | None = None, _scheduler: str | None = None, @@ -809,6 +815,7 @@ def function( max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING, min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY, request_timeout: int | None = None, + startup_timeout: int | None = None, setup_function: Callable[..., None] | None = None, _base_image: str | None = None, _scheduler: str | None = None, @@ -841,6 +848,7 @@ def function( max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING, min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY, request_timeout: int | None = None, + startup_timeout: int | None = None, setup_function: Callable[..., None] | None = None, _base_image: str | None = None, _scheduler: str | None = None, @@ -867,6 +875,7 @@ def function( max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING, min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY, request_timeout: int | None = None, + startup_timeout: int | None = None, setup_function: Callable[..., None] | None = None, _base_image: str | None = None, _scheduler: str | None = None, @@ -893,6 +902,7 @@ def function( max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING, min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY, request_timeout: int | None = None, + startup_timeout: int | None = None, setup_function: Callable[..., None] | None = None, _base_image: str | None = None, _scheduler: str | None = None, diff --git a/projects/fal/src/fal/app.py b/projects/fal/src/fal/app.py index 25934ead..f8c4d84d 100644 --- a/projects/fal/src/fal/app.py +++ b/projects/fal/src/fal/app.py @@ -271,6 +271,7 @@ class App(fal.api.BaseServable): app_name: ClassVar[str] app_auth: ClassVar[Literal["private", "public", "shared"]] = "private" request_timeout: ClassVar[int | None] = None + startup_timeout: ClassVar[int | None] = None isolate_channel: async_grpc.Channel | None = None @@ -282,6 +283,9 @@ def __init_subclass__(cls, **kwargs): if cls.request_timeout is not None: cls.host_kwargs["request_timeout"] = cls.request_timeout + if cls.startup_timeout is not None: + cls.host_kwargs["startup_timeout"] = cls.startup_timeout + cls.app_name = getattr(cls, "app_name", app_name) if cls.__init__ is not App.__init__: diff --git a/projects/fal/src/fal/sdk.py b/projects/fal/src/fal/sdk.py index 6a631e82..2f6ed7af 100644 --- a/projects/fal/src/fal/sdk.py +++ b/projects/fal/src/fal/sdk.py @@ -402,6 +402,7 @@ class MachineRequirements: max_multiplexing: int | None = None min_concurrency: int | None = None request_timeout: int | None = None + startup_timeout: int | None = None def __post_init__(self): if isinstance(self.machine_types, str): @@ -519,6 +520,7 @@ def register( min_concurrency=machine_requirements.min_concurrency, max_multiplexing=machine_requirements.max_multiplexing, request_timeout=machine_requirements.request_timeout, + startup_timeout=machine_requirements.startup_timeout, ) else: wrapped_requirements = None @@ -614,6 +616,7 @@ def run( max_multiplexing=machine_requirements.max_multiplexing, min_concurrency=machine_requirements.min_concurrency, request_timeout=machine_requirements.request_timeout, + startup_timeout=machine_requirements.startup_timeout, ) else: wrapped_requirements = None