Skip to content

Commit

Permalink
Merge pull request #127 from stealthrocket/batch-submit
Browse files Browse the repository at this point in the history
Batch submit
  • Loading branch information
chriso authored Mar 18, 2024
2 parents e0a1b1d + ff4be5e commit 8aa6795
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 128 deletions.
2 changes: 1 addition & 1 deletion examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from fastapi.testclient import TestClient

from dispatch.client import Client
from dispatch.function import Client
from dispatch.test import DispatchServer, DispatchService, EndpointClient


Expand Down
2 changes: 1 addition & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

import dispatch.integrations
from dispatch.client import DEFAULT_API_URL, Client
from dispatch.coroutine import call, gather
from dispatch.function import DEFAULT_API_URL, Client
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
from dispatch.status import Status
Expand Down
118 changes: 0 additions & 118 deletions src/dispatch/client.py

This file was deleted.

14 changes: 9 additions & 5 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def read_root():
import fastapi.responses
from http_message_signatures import InvalidSignature

from dispatch.client import Client
from dispatch.function import Registry
from dispatch.function import Batch, Client, Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
Expand All @@ -47,7 +46,7 @@ def read_root():
class Dispatch(Registry):
"""A Dispatch programmable endpoint, powered by FastAPI."""

__slots__ = ()
__slots__ = ("client",)

def __init__(
self,
Expand Down Expand Up @@ -116,12 +115,17 @@ def __init__(
"request verification is disabled because DISPATCH_VERIFICATION_KEY is not set"
)

client = Client(api_key=api_key, api_url=api_url)
super().__init__(endpoint, client)
self.client = Client(api_key=api_key, api_url=api_url)
super().__init__(endpoint, self.client)

function_service = _new_app(self, verification_key)
app.mount("/dispatch.sdk.v1.FunctionService", function_service)

def batch(self) -> Batch:
"""Returns a Batch instance that can be used to build
a set of calls to dispatch."""
return self.client.batch()


def parse_verification_key(
verification_key: Ed25519PublicKey | str | bytes | None,
Expand Down
155 changes: 154 additions & 1 deletion src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import logging
import os
from functools import wraps
from types import CoroutineType
from typing import (
Expand All @@ -10,14 +11,19 @@
Coroutine,
Dict,
Generic,
Iterable,
ParamSpec,
TypeAlias,
TypeVar,
overload,
)
from urllib.parse import urlparse

import grpc

import dispatch.coroutine
from dispatch.client import Client
import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb
import dispatch.sdk.v1.dispatch_pb2_grpc as dispatch_grpc
from dispatch.experimental.durable import durable
from dispatch.id import DispatchID
from dispatch.proto import Arguments, Call, Error, Input, Output
Expand All @@ -33,6 +39,9 @@
"""


DEFAULT_API_URL = "https://api.dispatch.run"


class PrimitiveFunction:
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func")

Expand Down Expand Up @@ -234,3 +243,147 @@ def set_client(self, client: Client):
self._client = client
for fn in self._functions.values():
fn._client = client


class Client:
"""Client for the Dispatch API."""

__slots__ = ("api_url", "api_key", "_stub", "api_key_from")

def __init__(self, api_key: None | str = None, api_url: None | str = None):
"""Create a new Dispatch client.
Args:
api_key: Dispatch API key to use for authentication. Uses the value of
the DISPATCH_API_KEY environment variable by default.
api_url: The URL of the Dispatch API to use. Uses the value of the
DISPATCH_API_URL environment variable if set, otherwise
defaults to the public Dispatch API (DEFAULT_API_URL).
Raises:
ValueError: if the API key is missing.
"""

if api_key:
self.api_key_from = "api_key"
else:
self.api_key_from = "DISPATCH_API_KEY"
api_key = os.environ.get("DISPATCH_API_KEY")
if not api_key:
raise ValueError(
"missing API key: set it with the DISPATCH_API_KEY environment variable"
)

if not api_url:
api_url = os.environ.get("DISPATCH_API_URL", DEFAULT_API_URL)
if not api_url:
raise ValueError(
"missing API URL: set it with the DISPATCH_API_URL environment variable"
)

logger.debug("initializing client for Dispatch API at URL %s", api_url)
self.api_url = api_url
self.api_key = api_key
self._init_stub()

def __getstate__(self):
return {"api_url": self.api_url, "api_key": self.api_key}

def __setstate__(self, state):
self.api_url = state["api_url"]
self.api_key = state["api_key"]
self._init_stub()

def _init_stub(self):
result = urlparse(self.api_url)
match result.scheme:
case "http":
creds = grpc.local_channel_credentials()
case "https":
creds = grpc.ssl_channel_credentials()
case _:
raise ValueError(f"Invalid API scheme: '{result.scheme}'")

call_creds = grpc.access_token_call_credentials(self.api_key)
creds = grpc.composite_channel_credentials(creds, call_creds)
channel = grpc.secure_channel(result.netloc, creds)

self._stub = dispatch_grpc.DispatchServiceStub(channel)

def batch(self) -> Batch:
"""Returns a Batch instance that can be used to build
a set of calls to dispatch."""
return Batch(self)

def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]:
"""Dispatch function calls.
Args:
calls: Calls to dispatch.
Returns:
Identifiers for the function calls, in the same order as the inputs.
"""
calls_proto = [c._as_proto() for c in calls]
logger.debug("dispatching %d function call(s)", len(calls_proto))
req = dispatch_pb.DispatchRequest(calls=calls_proto)

try:
resp = self._stub.Dispatch(req)
except grpc.RpcError as e:
status_code = e.code()
match status_code:
case grpc.StatusCode.UNAUTHENTICATED:
raise PermissionError(
f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)"
) from e
raise

dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids]
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"dispatched %d function call(s): %s",
len(calls_proto),
", ".join(dispatch_ids),
)
return dispatch_ids


class Batch:
"""A batch of calls to dispatch."""

__slots__ = ("client", "calls")

def __init__(self, client: Client):
self.client = client
self.calls: list[Call] = []

def add(self, func: Function[P, T], *args: P.args, **kwargs: P.kwargs) -> Batch:
"""Add a call to the specified function to the batch."""
return self.add_call(func.build_call(*args, correlation_id=None, **kwargs))

def add_call(self, call: Call) -> Batch:
"""Add a Call to the batch."""
self.calls.append(call)
return self

def dispatch(self) -> list[DispatchID]:
"""Dispatch dispatches the calls asynchronously.
The batch is reset when the calls are dispatched successfully.
Returns:
Identifiers for the function calls, in the same order they
were added.
"""
if not self.calls:
return []

dispatch_ids = self.client.dispatch(self.calls)
self.reset()
return dispatch_ids

def reset(self):
"""Reset the batch."""
self.calls = []
3 changes: 1 addition & 2 deletions tests/dispatch/test_function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pickle
import unittest

from dispatch.client import Client
from dispatch.function import Registry
from dispatch.function import Client, Registry


class TestFunction(unittest.TestCase):
Expand Down
Loading

0 comments on commit 8aa6795

Please sign in to comment.