diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 798af504..c996ff20 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -59,6 +59,7 @@ def __init__( name: str, primitive_func: PrimitiveFunctionType, func: Callable, + coroutine: bool = False, ): self._endpoint = endpoint self._client = client @@ -66,10 +67,7 @@ def __init__( self._primitive_func = primitive_func # FIXME: is there a way to decorate the function at the definition # without making it a class method? - if inspect.iscoroutinefunction(func): - self._func = durable(self._call_async) - else: - self._func = func + self._func = durable(self._call_async) if coroutine else func def __call__(self, *args, **kwargs): return self._func(*args, **kwargs) @@ -204,7 +202,7 @@ def primitive_func(input: Input) -> Output: primitive_func.__qualname__ = f"{func.__qualname__}_primitive" primitive_func = durable(primitive_func) - return self._register(func, primitive_func) + return self._register(primitive_func, func, coroutine=False) def _register_coroutine(self, func: Callable) -> Function: logger.info("registering coroutine: %s", func.__qualname__) @@ -218,14 +216,14 @@ def primitive_func(input: Input) -> Output: primitive_func.__qualname__ = f"{func.__qualname__}_primitive" primitive_func = durable(primitive_func) - return self._register(func, primitive_func) + return self._register(primitive_func, func, coroutine=True) def _register_primitive_function(self, func: PrimitiveFunctionType) -> Function: logger.info("registering primitive function: %s", func.__qualname__) - return self._register(func, func) + return self._register(func, func, coroutine=inspect.iscoroutinefunction(func)) def _register( - self, func: Callable, primitive_func: PrimitiveFunctionType + self, primitive_func: PrimitiveFunctionType, func: Callable, coroutine: bool ) -> Function: name = func.__qualname__ if name in self._functions: @@ -233,7 +231,7 @@ def _register( f"function or coroutine already registered with name '{name}'" ) wrapped_func = Function( - self._endpoint, self._client, name, primitive_func, func + self._endpoint, self._client, name, primitive_func, func, coroutine ) self._functions[name] = wrapped_func return wrapped_func