Skip to content

Commit

Permalink
Extract a base class to better handle primitive functions
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Mar 4, 2024
1 parent 8fe9486 commit a9148a1
Showing 1 changed file with 78 additions and 68 deletions.
146 changes: 78 additions & 68 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,20 @@
"""


P = ParamSpec("P")
T = TypeVar("T")


class Function(Generic[P, T]):
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""

__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func")
class PrimitiveFunction:
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func")

def __init__(
self,
endpoint: str,
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable[..., Any] | None,
):
self._endpoint = endpoint
self._client = client
self._name = name
self._primitive_func = primitive_func
self._func: Callable[P, Coroutine[Any, Any, T]] | None = (
durable(self._call_async) if func else None
)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
if self._func is None:
raise ValueError("cannot call a primitive function directly")
return self._func(*args, **kwargs)

def _primitive_call(self, input: Input) -> Output:
return self._primitive_func(input)

@property
def endpoint(self) -> str:
Expand All @@ -76,8 +56,62 @@ def endpoint(self) -> str:
def name(self) -> str:
return self._name

def _primitive_call(self, input: Input) -> Output:
return self._primitive_func(input)

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
return dispatch_id

def _build_primitive_call(
self, input: Any, correlation_id: int | None = None
) -> Call:
return Call(
correlation_id=correlation_id,
endpoint=self.endpoint,
function=self.name,
input=input,
)


P = ParamSpec("P")
T = TypeVar("T")


class Function(PrimitiveFunction, Generic[P, T]):
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""

__slots__ = ("_func_indirect",)

def __init__(
self,
endpoint: str,
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable,
):
PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)

self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable(
self._call_async
)

async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
return await dispatch.coroutine.call(
self.build_call(*args, **kwargs, correlation_id=None)
)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
"""Call the function asynchronously (through Dispatch), and return a
coroutine that can be awaited to retrieve the call result."""
return self._func_indirect(*args, **kwargs)

def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""Dispatch a call to the function.
"""Dispatch an asynchronous call to the function without
waiting for a result.
The Registry this function was registered with must be initialized
with a Client / api_key for this call facility to be available.
Expand All @@ -94,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""
return self._primitive_dispatch(Arguments(args, kwargs))

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
return dispatch_id

async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""Asynchronously call the function from a @dispatch.function."""
return await dispatch.coroutine.call(
self.build_call(*args, **kwargs, correlation_id=None)
)

def build_call(
self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
) -> Call:
Expand All @@ -123,16 +147,6 @@ def build_call(
Arguments(args, kwargs), correlation_id=correlation_id
)

def _build_primitive_call(
self, input: Any, correlation_id: int | None = None
) -> Call:
return Call(
correlation_id=correlation_id,
endpoint=self.endpoint,
function=self.name,
input=input,
)


class Registry:
"""Registry of local functions."""
Expand All @@ -147,7 +161,7 @@ def __init__(self, endpoint: str, client: Client):
client: Client for the Dispatch API. Used to dispatch calls to
local functions.
"""
self._functions: Dict[str, Function] = {}
self._functions: Dict[str, PrimitiveFunction] = {}
self._endpoint = endpoint
self._client = client

Expand All @@ -166,10 +180,6 @@ def function(self, func):
logger.info("registering coroutine: %s", func.__qualname__)
return self._register_coroutine(func)

def primitive_function(self, func: PrimitiveFunctionType) -> Function:
"""Decorator that registers primitive functions."""
return self._register_primitive_function(func)

def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
func = durable(func)

Expand All @@ -184,40 +194,40 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
def _register_coroutine(
self, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
logger.info("registering coroutine: %s", func.__qualname__)
name = func.__qualname__
logger.info("registering coroutine: %s", name)

func = durable(func)

@wraps(func)
def primitive_func(input: Input) -> Output:
return OneShotScheduler(func).run(input)

primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
primitive_func.__qualname__ = f"{name}_primitive"
primitive_func = durable(primitive_func)

return self._register(primitive_func, func)
wrapped_func = Function[P, T](
self._endpoint, self._client, name, primitive_func, func
)
self._register(name, wrapped_func)
return wrapped_func

def _register_primitive_function(
def primitive_function(
self, primitive_func: PrimitiveFunctionType
) -> Function[P, T]:
logger.info("registering primitive function: %s", primitive_func.__qualname__)
return self._register(primitive_func, func=None)
) -> PrimitiveFunction:
"""Decorator that registers primitive functions."""
name = primitive_func.__qualname__
logger.info("registering primitive function: %s", name)
wrapped_func = PrimitiveFunction(
self._endpoint, self._client, name, primitive_func
)
self._register(name, wrapped_func)
return wrapped_func

def _register(
self,
primitive_func: PrimitiveFunctionType,
func: Callable[P, Coroutine[Any, Any, T]] | None,
) -> Function[P, T]:
name = func.__qualname__ if func else primitive_func.__qualname__
def _register(self, name: str, wrapped_func: PrimitiveFunction):
if name in self._functions:
raise ValueError(
f"function or coroutine already registered with name '{name}'"
)
wrapped_func = Function[P, T](
self._endpoint, self._client, name, primitive_func, func
)
raise ValueError(f"function already registered with name '{name}'")
self._functions[name] = wrapped_func
return wrapped_func

def set_client(self, client: Client):
"""Set the Client instance used to dispatch calls to local functions."""
Expand Down

0 comments on commit a9148a1

Please sign in to comment.