Skip to content

Commit

Permalink
[WIP] add waitable-set
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewagner committed Dec 31, 2024
1 parent e716109 commit 44eea5b
Show file tree
Hide file tree
Showing 2 changed files with 346 additions and 183 deletions.
217 changes: 146 additions & 71 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class CanonicalOptions:
class ComponentInstance:
resources: Table[ResourceHandle]
waitables: Table[Waitable]
waitable_sets: Table[WaitableSet]
error_contexts: Table[ErrorContext]
num_tasks: int
may_leave: bool
Expand All @@ -217,6 +218,7 @@ class ComponentInstance:
def __init__(self):
self.resources = Table[ResourceHandle]()
self.waitables = Table[Waitable]()
self.waitable_sets = Table[WaitableSet]()
self.error_contexts = Table[ErrorContext]()
self.num_tasks = 0
self.may_leave = True
Expand Down Expand Up @@ -292,36 +294,37 @@ def __init__(self, impl, dtor = None, dtor_sync = True, dtor_callback = None):
#### Waitable State

class CallState(IntEnum):
STARTING = 0
STARTED = 1
RETURNED = 2
STARTING = 1
STARTED = 2
RETURNED = 3

class EventCode(IntEnum):
NONE = 0
CALL_STARTING = CallState.STARTING
CALL_STARTED = CallState.STARTED
CALL_RETURNED = CallState.RETURNED
YIELDED = 3
STREAM_READ = 4
STREAM_WRITE = 5
FUTURE_READ = 6
FUTURE_WRITE = 7
YIELDED = 4
STREAM_READ = 5
STREAM_WRITE = 6
FUTURE_READ = 7
FUTURE_WRITE = 8

EventTuple = tuple[EventCode, int, int]
GetEventCallback = Callable[[], EventTuple]

class Waitable:
maybe_task: Optional[Task]
waitable_set: Optional[WaitableSet]
get_event_cb: Optional[GetEventCallback]
event: asyncio.Event

def __init__(self):
self.maybe_task = None
self.waitable_set = None
self.get_event_cb = None
self.event = asyncio.Event()

def set_event(self, get_event_cb: GetEventCallback):
assert(self.maybe_task)
self.maybe_task.maybe_has_event.set()
if self.waitable_set:
self.waitable_set.maybe_has_event.set()
self.get_event_cb = get_event_cb
self.event.set()

Expand All @@ -338,16 +341,44 @@ def get_event(self) -> EventTuple:
self.event.clear()
return get_event_cb()

def set_task(self, maybe_task):
if self.maybe_task:
self.maybe_task.waitables.remove(self)
self.maybe_task = maybe_task
if self.maybe_task:
self.maybe_task.waitables.append(self)
def set_waitable_set(self, s):
if self.waitable_set:
self.waitable_set.elems.remove(self)
self.waitable_set = s
if s:
s.elems.append(self)
if self.has_event():
s.maybe_has_event.set()

def drop(self):
assert(not self.has_event())
self.set_task(None)
self.set_waitable_set(None)

class WaitableSet:
elems: list[Waitable]
maybe_has_event: asyncio.Event

def __init__(self):
self.elems = []
self.maybe_has_event = asyncio.Event()

async def wait(self) -> EventTuple:
while True:
await self.maybe_has_event.wait()
if (e := self.poll()):
return e

def poll(self) -> Optional[EventTuple]:
random.shuffle(self.elems)
for w in self.elems:
if w.has_event():
assert(self.maybe_has_event.is_set())
return w.get_event()
self.maybe_has_event.clear()
return None

def drop(self):
trap_if(len(self.elems) > 0)

#### Task State

Expand All @@ -358,10 +389,11 @@ class Task:
caller: Optional[Task]
on_return: Optional[Callable]
on_block: Callable[[Awaitable], Awaitable]
waitables: list[Waitables]
maybe_has_event: asyncio.Event
num_subtasks: int
num_borrows: int
context: list[int]

NUM_CONTEXT = 2

def __init__(self, opts, inst, ft, caller, on_return, on_block):
self.opts = opts
Expand All @@ -370,10 +402,9 @@ def __init__(self, opts, inst, ft, caller, on_return, on_block):
self.caller = caller
self.on_return = on_return
self.on_block = on_block if on_block else Task.sync_on_block
self.waitables = []
self.maybe_has_event = asyncio.Event()
self.num_subtasks = 0
self.num_borrows = 0
self.context = [0] * Task.NUM_CONTEXT

current = asyncio.Lock()
async def sync_on_block(a: Awaitable):
Expand Down Expand Up @@ -432,28 +463,9 @@ def maybe_start_pending_task(self):
self.inst.starting_pending_task = True
pending_future.set_result(None)

async def wait(self, sync) -> EventTuple:
while True:
await self.wait_on(sync, self.maybe_has_event.wait())
if (e := self.maybe_next_event()):
return e

def maybe_next_event(self) -> Optional[EventTuple]:
random.shuffle(self.waitables)
for w in self.waitables:
if w.has_event():
assert(self.maybe_has_event.is_set())
return w.get_event()
self.maybe_has_event.clear()
return None

async def yield_(self, sync):
await self.wait_on(sync, asyncio.sleep(0))

async def poll(self, sync) -> Optional[EventTuple]:
await self.yield_(sync)
return self.maybe_next_event()

async def call_sync(self, callee, *args):
if self.inst.interruptible.is_set():
self.inst.interruptible.clear()
Expand Down Expand Up @@ -497,7 +509,6 @@ def return_(self, flat_results):
def exit(self):
assert(Task.current.locked())
trap_if(self.num_subtasks > 0)
trap_if(len(self.waitables) > 0)
trap_if(self.on_return)
assert(self.num_borrows == 0)
trap_if(self.inst.num_tasks == 1 and self.inst.backpressure)
Expand Down Expand Up @@ -527,7 +538,6 @@ def add_waitable(self, task):
self.supertask = task
self.supertask.num_subtasks += 1
Waitable.__init__(self)
Waitable.set_task(self, task)
return task.inst.waitables.add(self)

def add_lender(self, lending_handle):
Expand Down Expand Up @@ -1712,19 +1722,42 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_blo
[] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, []))
else:
[packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, [packed_ctx]))
while packed_ctx != 0:
is_yield = bool(packed_ctx & 1)
ctx = packed_ctx & ~1
if is_yield:
await task.yield_(sync = False)
event, p1, p2 = (EventCode.YIELDED, 0, 0)
[packed] = await call_and_trap_on_throw(callee, task, flat_args)
while True:
code,si = unpack_callback_result(packed)
if si != 0:
s = task.inst.waitable_sets.get(si)
match code:
case CallbackCode.EXIT:
break
case CallbackCode.WAIT:
e = await task.wait_on(opts.sync, s.wait())
case CallbackCode.POLL:
await task.yield_(opts.sync)
e = s.poll()
case CallbackCode.YIELD:
await task.yield_(opts.sync)
e = None
if e:
event, p1, p2 = e
else:
event, p1, p2 = await task.wait(sync = False)
[packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, p1, p2])
event, p1, p2 = (EventCode.YIELDED, 0, 0)
[packed] = await call_and_trap_on_throw(opts.callback, task, [event, p1, p2])
task.exit()

class CallbackCode(IntEnum):
EXIT = 0
WAIT = 1
POLL = 2
YIELD = 3

def unpack_callback_result(packed):
code = packed & 3
assert(packed < 2**32)
assert(Table.MAX_LENGTH < 2**30)
waitable_set_index = packed >> 2
return (CallbackCode(code), waitable_set_index)

async def call_and_trap_on_throw(callee, task, args):
try:
return await callee(task, args)
Expand Down Expand Up @@ -1826,6 +1859,20 @@ async def canon_resource_rep(rt, task, i):
trap_if(h.rt is not rt)
return [h.rep]

### 🔀 `canon task.set-context`

async def canon_task_set_context(i, task, v):
assert(types_match_values(['i32'], [v]))
trap_if(i >= Task.NUM_CONTEXT)
task.context[i] = v
return []

### 🔀 `canon task.get-context`

async def canon_task_get_context(i, task):
trap_if(i >= Task.NUM_CONTEXT)
return [task.context[i]]

### 🔀 `canon task.backpressure`

async def canon_task_backpressure(task, flat_args):
Expand All @@ -1843,35 +1890,65 @@ async def canon_task_return(task, result_type, opts, flat_args):
task.return_(flat_args)
return []

### 🔀 `canon task.wait`
### 🔀 `canon task.yield`

async def canon_task_yield(sync, task):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
await task.yield_(sync)
return []

### 🔀 `canon waitable-set.new`

async def canon_task_wait(sync, mem, task, ptr):
async def canon_waitable_set_new(task):
trap_if(not task.inst.may_leave)
return [ task.inst.waitable_sets.add(WaitableSet()) ]

### 🔀 `canon waitable-set.wait`

async def canon_waitable_set_wait(sync, mem, task, si, ptr):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
event, p1, p2 = await task.wait(sync)
s = task.inst.waitable_sets.get(si)
e = await task.wait_on(sync, s.wait())
return unpack_event(mem, task, ptr, e)

def unpack_event(mem, task, ptr, e: EventTuple):
event, p1, p2 = e
cx = LiftLowerContext(CanonicalOptions(memory = mem), task.inst)
store(cx, p1, U32Type(), ptr)
store(cx, p2, U32Type(), ptr + 4)
return [event]

### 🔀 `canon task.poll`
### 🔀 `canon waitable-set.poll`

async def canon_task_poll(sync, mem, task, ptr):
async def canon_waitable_set_poll(sync, mem, task, si, ptr):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
ret = await task.poll(sync)
if ret is None:
return [0]
cx = LiftLowerContext(CanonicalOptions(memory = mem), task.inst)
store(cx, ret, TupleType([U32Type(), U32Type(), U32Type()]), ptr)
return [1]
s = task.inst.waitable_sets.get(si)
await task.yield_(sync)
if (e := s.poll()):
return unpack_event(mem, task, ptr, e)
return [EventCode.NONE]

### 🔀 `canon task.yield`
### 🔀 `canon waitable-set.drop`

async def canon_task_yield(sync, task):
async def canon_waitable_set_drop(task, i):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
await task.yield_(sync)
s = task.inst.waitable_sets.remove(i)
s.drop()
return []

### 🔀 `canon waitable.set`

async def canon_waitable_set(task, wi, si):
trap_if(not task.inst.may_leave)
w = task.inst.waitables.get(wi)
if si == 0:
w.set_waitable_set(None)
else:
s = task.inst.waitable_sets.get(si)
w.set_waitable_set(s)
return []

### 🔀 `canon subtask.drop`
Expand Down Expand Up @@ -1938,15 +2015,13 @@ def on_async_partial(revoke_buffer):
def copy_event(revoke_buffer):
revoke_buffer()
h.copying = False
h.set_task(None)
return (event_code, i, pack_copy_result(task, buffer, h))
def on_async_partial(revoke_buffer):
h.set_event(partial(copy_event, revoke_buffer))
def on_async_done():
h.set_event(partial(copy_event, revoke_buffer = lambda:()))
if h.copy(buffer, on_async_partial, on_async_done) != 'done':
h.copying = True
h.set_task(task)
return [BLOCKED]
return [pack_copy_result(task, buffer, h)]

Expand Down
Loading

0 comments on commit 44eea5b

Please sign in to comment.