Skip to content

Commit

Permalink
Merge pull request #90 from stealthrocket/poll-result-error
Browse files Browse the repository at this point in the history
Handle PollResult errors
  • Loading branch information
chriso authored Feb 24, 2024
2 parents 0a4391d + bba1353 commit 5aa45e0
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 26 deletions.
4 changes: 3 additions & 1 deletion src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ async def execute(request: fastapi.Request):
)

logger.debug("finished handling run request with status %s", status.name)
return fastapi.Response(content=response.SerializeToString())
return fastapi.Response(
content=response.SerializeToString(), media_type="application/proto"
)

return app
28 changes: 25 additions & 3 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ class Input:
This class is intended to be used as read-only.
"""

__slots__ = ("_has_input", "_input", "_coroutine_state", "_call_results")
__slots__ = (
"_has_input",
"_input",
"_coroutine_state",
"_call_results",
"_poll_error",
)

def __init__(self, req: function_pb.RunRequest):
self._has_input = req.HasField("input")
Expand All @@ -54,6 +60,11 @@ def __init__(self, req: function_pb.RunRequest):
self._call_results = [
CallResult._from_proto(r) for r in req.poll_result.results
]
self._poll_error = (
Error._from_proto(req.poll_result.error)
if req.poll_result.HasField("error")
else None
)

@property
def is_first_call(self) -> bool:
Expand Down Expand Up @@ -85,6 +96,11 @@ def call_results(self) -> list[CallResult]:
self._assert_resume()
return self._call_results

@property
def poll_error(self) -> Error | None:
self._assert_resume()
return self._poll_error

def _assert_first_call(self):
if self.is_resume:
raise ValueError("This input is for a resumed coroutine")
Expand All @@ -105,14 +121,20 @@ def from_input_arguments(cls, function: str, *args, **kwargs):

@classmethod
def from_poll_results(
cls, function: str, coroutine_state: Any, call_results: list[CallResult]
cls,
function: str,
coroutine_state: Any,
call_results: list[CallResult],
error: Error | None = None,
):
return Input(
req=function_pb.RunRequest(
function=function,
poll_result=poll_pb.PollResult(
coroutine_state=coroutine_state,
results=[result._as_proto() for result in call_results],
)
error=error._as_proto() if error else None,
),
)
)

Expand Down
61 changes: 48 additions & 13 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class CallResult:


class Future(Protocol):
def add(self, result: CallResult | CoroutineResult): ...
def add_result(self, result: CallResult | CoroutineResult): ...
def add_error(self, error: Exception): ...
def ready(self) -> bool: ...
def error(self) -> Exception | None: ...
def value(self) -> Any: ...
Expand All @@ -48,17 +49,25 @@ class CallFuture:
"""A future result of a dispatch.coroutine.call() operation."""

result: CallResult | None = None
first_error: Exception | None = None

def add(self, result: CallResult | CoroutineResult):
def add_result(self, result: CallResult | CoroutineResult):
assert isinstance(result, CallResult)
self.result = result
if self.result is None:
self.result = result
if result.error is not None and self.first_error is None:
self.first_error = result.error

def add_error(self, error: Exception):
    """Attach an error to this call future, keeping only the earliest one.

    Args:
        error: The exception to record. It is ignored when an error has
            already been stored in ``first_error``.
    """
    if self.first_error is not None:
        # An earlier error already claimed the slot; later ones are dropped.
        return
    self.first_error = error

def ready(self) -> bool:
return self.result is not None
return self.first_error is not None or self.result is not None

def error(self) -> Exception | None:
assert self.result is not None
return self.result.error
assert self.ready()
return self.first_error

def value(self) -> Any:
assert self.result is not None
Expand All @@ -74,7 +83,7 @@ class GatherFuture:
results: dict[CoroutineID, CoroutineResult]
first_error: Exception | None = None

def add(self, result: CallResult | CoroutineResult):
def add_result(self, result: CallResult | CoroutineResult):
assert isinstance(result, CoroutineResult)

try:
Expand All @@ -87,6 +96,10 @@ def add(self, result: CallResult | CoroutineResult):

self.results[result.coroutine_id] = result

def add_error(self, error: Exception):
    """Record an error against this gather future, keeping the first one.

    ``ready()`` reports the future as ready as soon as ``first_error`` is
    set, so storing the error here is what lets the waiting coroutine be
    resumed with the failure.

    Args:
        error: The exception to record. It is ignored when an error has
            already been stored in ``first_error``.
    """
    # The guard must be `is None`: with the inverted `is not None` check the
    # error was never stored when the slot was empty, and an existing first
    # error was overwritten by any later one.
    if self.first_error is None:
        self.first_error = error

def ready(self) -> bool:
    """Tell whether the gather operation can be resumed.

    Returns:
        True once an error has been recorded in ``first_error``, or when
        ``waiting`` is empty; False otherwise.
    """
    if self.first_error is not None:
        # A recorded error short-circuits the wait for remaining results.
        return True
    return len(self.waiting) == 0

Expand Down Expand Up @@ -134,6 +147,8 @@ class State:
next_coroutine_id: int
next_call_id: int

prev_calls: list[Coroutine]


class OneShotScheduler:
"""Scheduler for local coroutines.
Expand Down Expand Up @@ -183,6 +198,7 @@ def _init_state(self, input: Input) -> State:
ready=[Coroutine(id=0, parent_id=None, coroutine=main)],
next_coroutine_id=1,
next_call_id=1,
prev_calls=[],
)

def _rebuild_state(self, input: Input):
Expand All @@ -203,19 +219,37 @@ def _rebuild_state(self, input: Input):
raise IncompatibleStateError from e

def _run(self, input: Input) -> Output:

if input.is_first_call:
state = self._init_state(input)
else:
state = self._rebuild_state(input)

poll_error = input.poll_error
if poll_error is not None:
error = poll_error.to_exception()
logger.debug("dispatching poll error: %s", error)
for coroutine in state.prev_calls:
future = coroutine.result
assert future is not None
future.add_error(error)
if future.ready() and coroutine.id in state.suspended:
state.ready.append(coroutine)
del state.suspended[coroutine.id]
logger.debug("coroutine %s is now ready", coroutine)

state.prev_calls = []

logger.debug("dispatching %d call result(s)", len(input.call_results))
for cr in input.call_results:
assert cr.correlation_id is not None
coroutine_id = correlation_coroutine_id(cr.correlation_id)
call_id = correlation_call_id(cr.correlation_id)

error = cr.error.to_exception() if cr.error is not None else None
call_result = CallResult(call_id=call_id, value=cr.output, error=error)
call_error = cr.error.to_exception() if cr.error is not None else None
call_result = CallResult(
call_id=call_id, value=cr.output, error=call_error
)

try:
owner = state.suspended[coroutine_id]
Expand All @@ -226,8 +260,8 @@ def _run(self, input: Input) -> Output:
continue

logger.debug("dispatching %s to %s", call_result, owner)
future.add(call_result)
if future.ready():
future.add_result(call_result)
if future.ready() and owner.id in state.suspended:
state.ready.append(owner)
del state.suspended[owner.id]
logger.debug("owner %s is now ready", owner)
Expand Down Expand Up @@ -284,8 +318,8 @@ def _run(self, input: Input) -> Output:
except (KeyError, AssertionError):
logger.warning("discarding %s", coroutine_result)
else:
future.add(coroutine_result)
if future.ready():
future.add_result(coroutine_result)
if future.ready() and parent.id in state.suspended:
state.ready.insert(0, parent)
del state.suspended[parent.id]
logger.debug("parent %s is now ready", parent)
Expand All @@ -308,6 +342,7 @@ def _run(self, input: Input) -> Output:
pending_calls.append(call)
coroutine.result = CallFuture()
state.suspended[coroutine.id] = coroutine
state.prev_calls.append(coroutine)

case Gather():
gather = coroutine_yield
Expand Down
11 changes: 6 additions & 5 deletions src/dispatch/sdk/v1/poll_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion src/dispatch/sdk/v1/poll_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ from google.protobuf.internal import containers as _containers

from buf.validate import validate_pb2 as _validate_pb2
from dispatch.sdk.v1 import call_pb2 as _call_pb2
from dispatch.sdk.v1 import error_pb2 as _error_pb2

DESCRIPTOR: _descriptor.FileDescriptor

Expand All @@ -33,13 +34,16 @@ class Poll(_message.Message):
) -> None: ...

class PollResult(_message.Message):
__slots__ = ("coroutine_state", "results")
__slots__ = ("coroutine_state", "results", "error")
COROUTINE_STATE_FIELD_NUMBER: _ClassVar[int]
RESULTS_FIELD_NUMBER: _ClassVar[int]
ERROR_FIELD_NUMBER: _ClassVar[int]
coroutine_state: bytes
results: _containers.RepeatedCompositeFieldContainer[_call_pb2.CallResult]
error: _error_pb2.Error
def __init__(
self,
coroutine_state: _Optional[bytes] = ...,
results: _Optional[_Iterable[_Union[_call_pb2.CallResult, _Mapping]]] = ...,
error: _Optional[_Union[_error_pb2.Error, _Mapping]] = ...,
) -> None: ...
62 changes: 59 additions & 3 deletions tests/dispatch/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest
from pprint import pprint
from typing import Any, Callable

from dispatch.coroutine import call, gather
Expand Down Expand Up @@ -221,6 +220,56 @@ async def main():

self.assertEqual(len(correlation_ids), 8)

def test_poll_error(self):
# The purpose of this test is to ensure that when a poll error occurs,
# we only abort the calls that were made on the most recent yield. Any
# other in-flight calls from earlier yields are not affected.

@durable
async def c_then_d():
c_result = await call_one("c")
try:
# The poll error will affect this call only.
d_result = await call_one("d")
except RuntimeError as e:
assert str(e) == "too many calls"
d_result = 100
return c_result + d_result

@durable
async def main(c_then_d):
return await gather(
call_concurrently("a", "b"),
c_then_d(),
)

output = self.start(main, c_then_d)
calls = self.assert_poll_call_functions(output, ["a", "b", "c"])

call_a, call_b, call_c = calls
a_result, b_result, c_result = 10, 20, 30
output = self.resume(
main,
output,
[CallResult.from_value(c_result, correlation_id=call_c.correlation_id)],
)
self.assert_poll_call_functions(output, ["d"])

output = self.resume(
main, output, [], poll_error=RuntimeError("too many calls")
)
self.assert_poll_call_functions(output, [])
output = self.resume(
main,
output,
[
CallResult.from_value(a_result, correlation_id=call_a.correlation_id),
CallResult.from_value(b_result, correlation_id=call_b.correlation_id),
],
)

self.assert_exit_result_value(output, [[a_result, b_result], c_result + 100])

def test_raise_indirect(self):
@durable
async def main():
Expand All @@ -234,11 +283,18 @@ def start(self, main: Callable, *args: Any, **kwargs: Any) -> Output:
return OneShotScheduler(main).run(input)

def resume(
self, main: Callable, prev_output: Output, call_results: list[CallResult]
self,
main: Callable,
prev_output: Output,
call_results: list[CallResult],
poll_error: Exception | None = None,
):
poll = self.assert_poll(prev_output)
input = Input.from_poll_results(
main.__qualname__, poll.coroutine_state, call_results
main.__qualname__,
poll.coroutine_state,
call_results,
Error.from_exception(poll_error) if poll_error else None,
)
return OneShotScheduler(main).run(input)

Expand Down

0 comments on commit 5aa45e0

Please sign in to comment.