From a9148a1352e9f0c8441d9bb130f27239ea99b5c8 Mon Sep 17 00:00:00 2001
From: Chris O'Hara <cohara87@gmail.com>
Date: Mon, 4 Mar 2024 10:06:49 +1000
Subject: [PATCH] Extract a base class to better handle primitive functions

---
 src/dispatch/function.py | 146 +++++++++++++++++++++------------------
 1 file changed, 78 insertions(+), 68 deletions(-)

diff --git a/src/dispatch/function.py b/src/dispatch/function.py
index a04a3f1a..03df30b2 100644
--- a/src/dispatch/function.py
+++ b/src/dispatch/function.py
@@ -33,16 +33,8 @@
 """
 
 
-P = ParamSpec("P")
-T = TypeVar("T")
-
-
-class Function(Generic[P, T]):
-    """Callable wrapper around a function meant to be used throughout the
-    Dispatch Python SDK.
-    """
-
-    __slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func")
+class PrimitiveFunction:
+    __slots__ = ("_endpoint", "_client", "_name", "_primitive_func")
 
     def __init__(
         self,
@@ -50,23 +42,11 @@ def __init__(
         client: Client,
         name: str,
         primitive_func: PrimitiveFunctionType,
-        func: Callable[..., Any] | None,
     ):
         self._endpoint = endpoint
         self._client = client
         self._name = name
         self._primitive_func = primitive_func
-        self._func: Callable[P, Coroutine[Any, Any, T]] | None = (
-            durable(self._call_async) if func else None
-        )
-
-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
-        if self._func is None:
-            raise ValueError("cannot call a primitive function directly")
-        return self._func(*args, **kwargs)
-
-    def _primitive_call(self, input: Input) -> Output:
-        return self._primitive_func(input)
 
     @property
     def endpoint(self) -> str:
@@ -76,8 +56,62 @@ def endpoint(self) -> str:
     def name(self) -> str:
         return self._name
 
+    def _primitive_call(self, input: Input) -> Output:
+        return self._primitive_func(input)
+
+    def _primitive_dispatch(self, input: Any = None) -> DispatchID:
+        [dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
+        return dispatch_id
+
+    def _build_primitive_call(
+        self, input: Any, correlation_id: int | None = None
+    ) -> Call:
+        return Call(
+            correlation_id=correlation_id,
+            endpoint=self.endpoint,
+            function=self.name,
+            input=input,
+        )
+
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+class Function(PrimitiveFunction, Generic[P, T]):
+    """Callable wrapper around a function meant to be used throughout the
+    Dispatch Python SDK.
+    """
+
+    __slots__ = ("_func_indirect",)
+
+    def __init__(
+        self,
+        endpoint: str,
+        client: Client,
+        name: str,
+        primitive_func: PrimitiveFunctionType,
+        func: Callable,
+    ):
+        PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)
+
+        self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable(
+            self._call_async
+        )
+
+    async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
+        return await dispatch.coroutine.call(
+            self.build_call(*args, **kwargs, correlation_id=None)
+        )
+
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
+        """Call the function asynchronously (through Dispatch), and return a
+        coroutine that can be awaited to retrieve the call result."""
+        return self._func_indirect(*args, **kwargs)
+
     def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
-        """Dispatch a call to the function.
+        """Dispatch an asynchronous call to the function without
+        waiting for a result.
 
         The Registry this function was registered with must be initialized
         with a Client / api_key for this call facility to be available.
@@ -94,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
         """
         return self._primitive_dispatch(Arguments(args, kwargs))
 
-    def _primitive_dispatch(self, input: Any = None) -> DispatchID:
-        [dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
-        return dispatch_id
-
-    async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
-        """Asynchronously call the function from a @dispatch.function."""
-        return await dispatch.coroutine.call(
-            self.build_call(*args, **kwargs, correlation_id=None)
-        )
-
     def build_call(
         self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
     ) -> Call:
@@ -123,16 +147,6 @@ def build_call(
             Arguments(args, kwargs), correlation_id=correlation_id
         )
 
-    def _build_primitive_call(
-        self, input: Any, correlation_id: int | None = None
-    ) -> Call:
-        return Call(
-            correlation_id=correlation_id,
-            endpoint=self.endpoint,
-            function=self.name,
-            input=input,
-        )
-
 
 class Registry:
     """Registry of local functions."""
@@ -147,7 +161,7 @@ def __init__(self, endpoint: str, client: Client):
             client: Client for the Dispatch API. Used to dispatch calls to
                 local functions.
         """
-        self._functions: Dict[str, Function] = {}
+        self._functions: Dict[str, PrimitiveFunction] = {}
         self._endpoint = endpoint
         self._client = client
 
@@ -166,10 +180,6 @@ def function(self, func):
         logger.info("registering coroutine: %s", func.__qualname__)
         return self._register_coroutine(func)
 
-    def primitive_function(self, func: PrimitiveFunctionType) -> Function:
-        """Decorator that registers primitive functions."""
-        return self._register_primitive_function(func)
-
     def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
         func = durable(func)
 
@@ -184,7 +194,8 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
     def _register_coroutine(
         self, func: Callable[P, Coroutine[Any, Any, T]]
     ) -> Function[P, T]:
-        logger.info("registering coroutine: %s", func.__qualname__)
+        name = func.__qualname__
+        logger.info("registering coroutine: %s", name)
 
         func = durable(func)
 
@@ -192,32 +203,31 @@ def _register_coroutine(
         def primitive_func(input: Input) -> Output:
             return OneShotScheduler(func).run(input)
 
-        primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
+        primitive_func.__qualname__ = f"{name}_primitive"
         primitive_func = durable(primitive_func)
 
-        return self._register(primitive_func, func)
+        wrapped_func = Function[P, T](
+            self._endpoint, self._client, name, primitive_func, func
+        )
+        self._register(name, wrapped_func)
+        return wrapped_func
 
-    def _register_primitive_function(
+    def primitive_function(
         self, primitive_func: PrimitiveFunctionType
-    ) -> Function[P, T]:
-        logger.info("registering primitive function: %s", primitive_func.__qualname__)
-        return self._register(primitive_func, func=None)
+    ) -> PrimitiveFunction:
+        """Decorator that registers primitive functions."""
+        name = primitive_func.__qualname__
+        logger.info("registering primitive function: %s", name)
+        wrapped_func = PrimitiveFunction(
+            self._endpoint, self._client, name, primitive_func
+        )
+        self._register(name, wrapped_func)
+        return wrapped_func
 
-    def _register(
-        self,
-        primitive_func: PrimitiveFunctionType,
-        func: Callable[P, Coroutine[Any, Any, T]] | None,
-    ) -> Function[P, T]:
-        name = func.__qualname__ if func else primitive_func.__qualname__
+    def _register(self, name: str, wrapped_func: PrimitiveFunction):
         if name in self._functions:
-            raise ValueError(
-                f"function or coroutine already registered with name '{name}'"
-            )
-        wrapped_func = Function[P, T](
-            self._endpoint, self._client, name, primitive_func, func
-        )
+            raise ValueError(f"function already registered with name '{name}'")
         self._functions[name] = wrapped_func
-        return wrapped_func
 
     def set_client(self, client: Client):
         """Set the Client instance used to dispatch calls to local functions."""