Skip to content

Commit

Permalink
Merge pull request #81 from stealthrocket/fix-coroutine-decoration
Browse files Browse the repository at this point in the history
fix coroutine decoration
  • Loading branch information
achille-roussel authored Feb 20, 2024
2 parents 46b4bf9 + fe5e04a commit 20b8f6c
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,15 @@ def __init__(
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable,
coroutine: bool = False,
):
self._endpoint = endpoint
self._client = client
self._name = name
self._primitive_func = primitive_func
# FIXME: is there a way to decorate the function at the definition
# without making it a class method?
if inspect.iscoroutinefunction(func):
self._func = durable(self._call_async)
else:
self._func = func
self._func = durable(self._call_async) if coroutine else func

def __call__(self, *args, **kwargs):
return self._func(*args, **kwargs)
Expand Down Expand Up @@ -204,7 +202,7 @@ def primitive_func(input: Input) -> Output:
primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
primitive_func = durable(primitive_func)

return self._register(func, primitive_func)
return self._register(primitive_func, func, coroutine=False)

def _register_coroutine(self, func: Callable) -> Function:
logger.info("registering coroutine: %s", func.__qualname__)
Expand All @@ -218,22 +216,22 @@ def primitive_func(input: Input) -> Output:
primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
primitive_func = durable(primitive_func)

return self._register(func, primitive_func)
return self._register(primitive_func, func, coroutine=True)

def _register_primitive_function(self, func: PrimitiveFunctionType) -> Function:
logger.info("registering primitive function: %s", func.__qualname__)
return self._register(func, func)
return self._register(func, func, coroutine=inspect.iscoroutinefunction(func))

def _register(
self, func: Callable, primitive_func: PrimitiveFunctionType
self, primitive_func: PrimitiveFunctionType, func: Callable, coroutine: bool
) -> Function:
name = func.__qualname__
if name in self._functions:
raise ValueError(
f"function or coroutine already registered with name '{name}'"
)
wrapped_func = Function(
self._endpoint, self._client, name, primitive_func, func
self._endpoint, self._client, name, primitive_func, func, coroutine
)
self._functions[name] = wrapped_func
return wrapped_func

0 comments on commit 20b8f6c

Please sign in to comment.