Skip to content

Commit

Permalink
feat: add naive broker
Browse files Browse the repository at this point in the history
  • Loading branch information
gusye1234 committed Sep 10, 2024
1 parent 2343e3b commit 924c577
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 80 deletions.
11 changes: 11 additions & 0 deletions drive_events/broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any
from .types import BaseEvent, EventInput, Task, GroupEventReturns
from .utils import generate_uuid


class BaseBroker:
async def append(self, event: BaseEvent, event_input: EventInput) -> Task:
raise NotImplementedError()

async def callback_after_run_done(self) -> tuple[BaseEvent, Any]:
raise NotImplementedError()
92 changes: 73 additions & 19 deletions drive_events/core.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
import inspect
from typing import Callable, Optional
from .types import BaseEvent, EventFunction, EventGroup
from .utils import logger
import asyncio
from typing import Callable, Optional, Union, Any, Tuple
from .types import BaseEvent, EventFunction, EventGroup, EventInput
from .broker import BaseBroker
from .utils import (
logger,
string_to_md5_hash,
)


class EventEngineCls:
def __init__(self):
self.__function_maps: dict[str, EventFunction] = {}
def __init__(self, name="default", broker: Optional[BaseBroker] = None):
self.name = name
self.broker = broker or BaseBroker()
self.__event_maps: dict[str, BaseEvent] = {}
self.__max_group_size = 0

def reset(self):
self.__function_maps = {}
self.__event_maps = {}

def make_event(self, func: Union[EventFunction, BaseEvent]) -> BaseEvent:
if isinstance(func, BaseEvent):
self.__event_maps[func.id] = func
return func
assert inspect.iscoroutinefunction(
func
), "Event function must be a coroutine function"
event = BaseEvent(func)
self.__event_maps[event.id] = event
return event

def listen_groups(
self, group_markers: list[BaseEvent], group_name: Optional[str] = None
) -> Callable[[BaseEvent], BaseEvent]:
Expand All @@ -21,14 +38,23 @@ def listen_groups(
), "group_markers must be a list of BaseEvent"
assert all(
[m.id in self.__event_maps for m in group_markers]
), "group_markers must be registered in the same event engine"
group_markers = list(set(group_markers))
), f"group_markers must be registered in the same event engine, current event engine is {self.name}"
group_markers_in_dict = {event.id: event for event in group_markers}

def decorator(func: BaseEvent) -> BaseEvent:
if not isinstance(func, BaseEvent):
func = self.make_event(func)
assert (
func.id in self.__event_maps
), f"Event function must be registered in the same event engine, current event engine is {self.name}"
this_group_name = group_name or f"{len(func.parent_groups)}"
new_group = EventGroup(this_group_name, group_markers)
this_group_hash = string_to_md5_hash(":".join(group_markers_in_dict.keys()))
new_group = EventGroup(
this_group_name, this_group_hash, group_markers_in_dict
)
self.__max_group_size = max(
self.__max_group_size, len(group_markers_in_dict)
)
if new_group.hash() in func.parent_groups:
logger.warning(f"Group {group_markers} already listened by {func}")
return func
Expand All @@ -40,13 +66,41 @@ def decorator(func: BaseEvent) -> BaseEvent:
def goto(self, group_markers: list[BaseEvent], *args):
raise NotImplementedError()

def make_event(self, func: EventFunction) -> BaseEvent:
if isinstance(func, BaseEvent):
return func
assert inspect.iscoroutinefunction(
func
), "Event function must be a coroutine function"
event = BaseEvent(func)
self.__function_maps[event.id] = func
self.__event_maps[event.id] = event
return event
async def invoke_event(
self,
event: BaseEvent,
event_input: Optional[EventInput] = None,
global_ctx: Any = None,
max_async_events: Optional[int] = None,
) -> dict[str, Any]:
this_run_ctx = {}
queue: list[Tuple[BaseEvent, EventInput]] = [(event, event_input)]

async def run_event(current_event, current_event_input):
result = await current_event.solo_run(current_event_input, global_ctx)
this_run_ctx[current_event.id] = result
for cand_event in self.__event_maps.values():
cand_event_parents = cand_event.parent_groups
for group_hash, group in cand_event_parents.items():
if current_event.id in group.events and all(
[event_id in this_run_ctx for event_id in group.events]
):
this_group_returns = {
event_id: this_run_ctx[event_id]
for event_id in group.events
}
build_input = EventInput(
group_name=group.name, results=this_group_returns
)
queue.append((cand_event, build_input))

while len(queue):
this_batch_events = queue[:max_async_events] if max_async_events else queue
queue = queue[max_async_events:] if max_async_events else []
logger.debug(
f"Running a turn with {len(this_batch_events)} event tasks, left {len(queue)} event tasks in queue"
)
await asyncio.gather(
*[run_event(*run_event_input) for run_event_input in this_batch_events]
)
return this_run_ctx
66 changes: 48 additions & 18 deletions drive_events/types.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,56 @@
from copy import copy
from enum import Enum
from dataclasses import dataclass
from typing import Callable, Any, Awaitable, Optional
from dataclasses import dataclass, field
from datetime import datetime
from typing import Callable, Any, Awaitable, Optional, TypeVar, Generic

from .utils import (
string_to_md5_hash,
function_or_method_to_string,
function_or_method_to_repr,
)

GroupEventReturns = dict["BaseEvent", Any]
EventInput = tuple[str, GroupEventReturns]

EventFunction = Callable[[EventInput], Awaitable[Any]]
class ReturnBehavior(Enum):
DISPATCH = "dispatch"
GOTO = "goto"
ABORT = "abort"


class TaskStatus(Enum):
RUNNING = "running"
SUCCESS = "success"
FAILURE = "failure"
PENDING = "pending"


GroupEventReturns = dict[str, Any]


@dataclass
class EventGroupInput:
group_name: str
results: GroupEventReturns
behavior: ReturnBehavior = ReturnBehavior.DISPATCH


@dataclass
class EventInput(EventGroupInput):
pass


# (group_event_results, global ctx set by user) -> result
EventFunction = Callable[[Optional[EventInput], Optional[Any]], Awaitable[Any]]


@dataclass
class EventGroup:
name: str
events: list["BaseEvent"]

def __post_init__(self):
self.events = sorted(self.events, key=lambda e: e.id)
self._hash = string_to_md5_hash(":".join([e.id for e in self.events]))
events_hash: str
events: dict[str, "BaseEvent"]

def hash(self) -> str:
return self._hash
return self.events_hash


class BaseEvent:
Expand All @@ -43,6 +68,7 @@ def __init__(
self.func_inst = func_inst
self.id = string_to_md5_hash(function_or_method_to_string(self.func_inst))
self.repr_name = function_or_method_to_repr(self.func_inst)
self.meta = {"func_body": function_or_method_to_string(self.func_inst)}

def debug_string(self, exclude_events: Optional[set[str]] = None) -> str:
exclude_events = exclude_events or set([self.id])
Expand All @@ -52,14 +78,18 @@ def debug_string(self, exclude_events: Optional[set[str]] = None) -> str:
def __repr__(self) -> str:
return f"Node(source={self.repr_name})"

async def solo_run(self, event_input: EventInput) -> Awaitable[Any]:
return await self.func_inst(event_input)
async def solo_run(
self, event_input: EventInput, global_ctx: Any = None
) -> Awaitable[Any]:
return await self.func_inst(event_input, global_ctx)


class ReturnBehavior(Enum):
DISPATCH = "dispatch"
GOTO = "goto"
ABORT = "abort"
@dataclass
class Task:
task_id: str
status: TaskStatus = TaskStatus.PENDING
created_at: datetime = field(default_factory=datetime.now)
upated_at: datetime = field(default_factory=datetime.now)


def format_parents(parents: dict[str, EventGroup], exclude_events: set[str], indent=""):
Expand All @@ -70,7 +100,7 @@ def format_parents(parents: dict[str, EventGroup], exclude_events: set[str], ind
is_last_group = i == len(parents) - 1
group_prefix = "└─ " if is_last_group else "├─ "
result.append(indent + group_prefix + f"<{parent_group.name}>")
for j, parent in enumerate(parent_group.events):
for j, parent in enumerate(parent_group.events.values()):
root_events = copy(exclude_events)
is_last = j == len(parent_group.events) - 1
child_indent = indent + (" " if is_last_group else "│ ")
Expand Down
13 changes: 4 additions & 9 deletions drive_events/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import uuid
import logging
import asyncio
import inspect
import hashlib
from typing import Callable, Union
from types import MethodType
from typing import Callable

logger = logging.getLogger("drive-events")


def always_get_a_event_loop():
try:
return asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def generate_uuid() -> str:
return str(uuid.uuid4())


def function_or_method_to_repr(func_or_method: Callable) -> str:
Expand Down
Loading

0 comments on commit 924c577

Please sign in to comment.