From a9148a1352e9f0c8441d9bb130f27239ea99b5c8 Mon Sep 17 00:00:00 2001 From: Chris O'Hara <cohara87@gmail.com> Date: Mon, 4 Mar 2024 10:06:49 +1000 Subject: [PATCH] Extract a base class to better handle primitive functions --- src/dispatch/function.py | 146 +++++++++++++++++++++------------------ 1 file changed, 78 insertions(+), 68 deletions(-) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index a04a3f1a..03df30b2 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -33,16 +33,8 @@ """ -P = ParamSpec("P") -T = TypeVar("T") - - -class Function(Generic[P, T]): - """Callable wrapper around a function meant to be used throughout the - Dispatch Python SDK. - """ - - __slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func") +class PrimitiveFunction: + __slots__ = ("_endpoint", "_client", "_name", "_primitive_func") def __init__( self, @@ -50,23 +42,11 @@ def __init__( client: Client, name: str, primitive_func: PrimitiveFunctionType, - func: Callable[..., Any] | None, ): self._endpoint = endpoint self._client = client self._name = name self._primitive_func = primitive_func - self._func: Callable[P, Coroutine[Any, Any, T]] | None = ( - durable(self._call_async) if func else None - ) - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: - if self._func is None: - raise ValueError("cannot call a primitive function directly") - return self._func(*args, **kwargs) - - def _primitive_call(self, input: Input) -> Output: - return self._primitive_func(input) @property def endpoint(self) -> str: @@ -76,8 +56,62 @@ def endpoint(self) -> str: def name(self) -> str: return self._name + def _primitive_call(self, input: Input) -> Output: + return self._primitive_func(input) + + def _primitive_dispatch(self, input: Any = None) -> DispatchID: + [dispatch_id] = self._client.dispatch([self._build_primitive_call(input)]) + return dispatch_id + + def _build_primitive_call( + self, input: Any, correlation_id: int | None = None + ) -> Call: + return Call( + correlation_id=correlation_id, + endpoint=self.endpoint, + function=self.name, + input=input, + ) + + +P = ParamSpec("P") +T = TypeVar("T") + + +class Function(PrimitiveFunction, Generic[P, T]): + """Callable wrapper around a function meant to be used throughout the + Dispatch Python SDK. + """ + + __slots__ = ("_func_indirect",) + + def __init__( + self, + endpoint: str, + client: Client, + name: str, + primitive_func: PrimitiveFunctionType, + func: Callable, + ): + PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func) + + self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable( + self._call_async + ) + + async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T: + return await dispatch.coroutine.call( + self.build_call(*args, **kwargs, correlation_id=None) + ) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]: + """Call the function asynchronously (through Dispatch), and return a + coroutine that can be awaited to retrieve the call result.""" + return self._func_indirect(*args, **kwargs) + def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: - """Dispatch a call to the function. + """Dispatch an asynchronous call to the function without + waiting for a result. The Registry this function was registered with must be initialized with a Client / api_key for this call facility to be available. @@ -94,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: """ return self._primitive_dispatch(Arguments(args, kwargs)) - def _primitive_dispatch(self, input: Any = None) -> DispatchID: - [dispatch_id] = self._client.dispatch([self._build_primitive_call(input)]) - return dispatch_id - - async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T: - """Asynchronously call the function from a @dispatch.function.""" - return await dispatch.coroutine.call( - self.build_call(*args, **kwargs, correlation_id=None) - ) - def build_call( self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs ) -> Call: @@ -123,16 +147,6 @@ def build_call( Arguments(args, kwargs), correlation_id=correlation_id ) - def _build_primitive_call( - self, input: Any, correlation_id: int | None = None - ) -> Call: - return Call( - correlation_id=correlation_id, - endpoint=self.endpoint, - function=self.name, - input=input, - ) - class Registry: """Registry of local functions.""" @@ -147,7 +161,7 @@ def __init__(self, endpoint: str, client: Client): client: Client for the Dispatch API. Used to dispatch calls to local functions. """ - self._functions: Dict[str, Function] = {} + self._functions: Dict[str, PrimitiveFunction] = {} self._endpoint = endpoint self._client = client @@ -166,10 +180,6 @@ def function(self, func): logger.info("registering coroutine: %s", func.__qualname__) return self._register_coroutine(func) - def primitive_function(self, func: PrimitiveFunctionType) -> Function: - """Decorator that registers primitive functions.""" - return self._register_primitive_function(func) - def _register_function(self, func: Callable[P, T]) -> Function[P, T]: func = durable(func) @@ -184,7 +194,8 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def _register_coroutine( self, func: Callable[P, Coroutine[Any, Any, T]] ) -> Function[P, T]: - logger.info("registering coroutine: %s", func.__qualname__) + name = func.__qualname__ + logger.info("registering coroutine: %s", name) func = durable(func) @@ -192,32 +203,31 @@ def _register_coroutine( def primitive_func(input: Input) -> Output: return OneShotScheduler(func).run(input) - primitive_func.__qualname__ = f"{func.__qualname__}_primitive" + primitive_func.__qualname__ = f"{name}_primitive" primitive_func = durable(primitive_func) - return self._register(primitive_func, func) + wrapped_func = Function[P, T]( + self._endpoint, self._client, name, primitive_func, func + ) + self._register(name, wrapped_func) + return wrapped_func - def _register_primitive_function( + def primitive_function( self, primitive_func: PrimitiveFunctionType - ) -> Function[P, T]: - logger.info("registering primitive function: %s", primitive_func.__qualname__) - return self._register(primitive_func, func=None) + ) -> PrimitiveFunction: + """Decorator that registers primitive functions.""" + name = primitive_func.__qualname__ + logger.info("registering primitive function: %s", name) + wrapped_func = PrimitiveFunction( + self._endpoint, self._client, name, primitive_func + ) + self._register(name, wrapped_func) + return wrapped_func - def _register( - self, - primitive_func: PrimitiveFunctionType, - func: Callable[P, Coroutine[Any, Any, T]] | None, - ) -> Function[P, T]: - name = func.__qualname__ if func else primitive_func.__qualname__ + def _register(self, name: str, wrapped_func: PrimitiveFunction): if name in self._functions: - raise ValueError( - f"function or coroutine already registered with name '{name}'" - ) - wrapped_func = Function[P, T]( - self._endpoint, self._client, name, primitive_func, func - ) + raise ValueError(f"function already registered with name '{name}'") self._functions[name] = wrapped_func - return wrapped_func def set_client(self, client: Client): """Set the Client instance used to dispatch calls to local functions."""