Skip to content

Commit

Permalink
[WIP] add waitable-set
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewagner committed Jan 10, 2025
1 parent 4f9af42 commit 72df7a9
Show file tree
Hide file tree
Showing 2 changed files with 306 additions and 157 deletions.
148 changes: 105 additions & 43 deletions design/mvp/canonical-abi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class CanonicalOptions:
class ComponentInstance:
resources: Table[ResourceHandle]
waitables: Table[Waitable]
waitable_sets: Table[WaitableSet]
error_contexts: Table[ErrorContext]
num_tasks: int
may_leave: bool
Expand All @@ -217,6 +218,7 @@ class ComponentInstance:
def __init__(self):
self.resources = Table[ResourceHandle]()
self.waitables = Table[Waitable]()
self.waitable_sets = Table[WaitableSet]()
self.error_contexts = Table[ErrorContext]()
self.num_tasks = 0
self.may_leave = True
Expand All @@ -233,7 +235,7 @@ class Table(Generic[ElemT]):
array: list[Optional[ElemT]]
free: list[int]

MAX_LENGTH = 2**30 - 1
MAX_LENGTH = 2**28 - 1

def __init__(self):
self.array = [None]
Expand Down Expand Up @@ -345,9 +347,11 @@ class Task:
caller: Optional[Task]
on_return: Optional[Callable]
on_block: Callable[[Awaitable], Awaitable]
waitable_set: WaitableSet
num_subtasks: int
num_borrows: int
context: list[int]

NUM_CONTEXT = 2

def __init__(self, opts, inst, ft, caller, on_return, on_block):
self.opts = opts
Expand All @@ -356,9 +360,9 @@ def __init__(self, opts, inst, ft, caller, on_return, on_block):
self.caller = caller
self.on_return = on_return
self.on_block = on_block
self.waitable_set = WaitableSet()
self.num_subtasks = 0
self.num_borrows = 0
self.context = [0] * Task.NUM_CONTEXT

current = asyncio.Lock()

Expand Down Expand Up @@ -418,13 +422,6 @@ def maybe_start_pending_task(self):
self.inst.starting_pending_task = True
pending_future.set_result(None)

async def wait(self, sync) -> EventTuple:
return await self.wait_on(sync, self.waitable_set.wait())

async def poll(self, sync) -> Optional[EventTuple]:
await self.yield_(sync)
return self.waitable_set.poll()

async def yield_(self, sync):
await self.wait_on(sync, asyncio.sleep(0))

Expand Down Expand Up @@ -471,7 +468,6 @@ def return_(self, flat_results):
def exit(self):
assert(Task.current.locked())
trap_if(self.num_subtasks > 0)
self.waitable_set.drop()
trap_if(self.on_return)
assert(self.num_borrows == 0)
trap_if(self.inst.num_tasks == 1 and self.inst.backpressure)
Expand All @@ -493,7 +489,7 @@ class EventCode(IntEnum):
CALL_STARTING = CallState.STARTING
CALL_STARTED = CallState.STARTED
CALL_RETURNED = CallState.RETURNED
YIELDED = 3
NONE = 3
STREAM_READ = 4
STREAM_WRITE = 5
FUTURE_READ = 6
Expand Down Expand Up @@ -590,7 +586,6 @@ def add_to_waitables(self, task):
self.supertask = task
self.supertask.num_subtasks += 1
Waitable.__init__(self)
Waitable.set_waitable_set(self, task.waitable_set)
return task.inst.waitables.add(self)

def add_lender(self, lending_handle):
Expand Down Expand Up @@ -1722,19 +1717,44 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_blo
[] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, []))
else:
[packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args)
assert(types_match_values(flat_ft.results, [packed_ctx]))
while packed_ctx != 0:
is_yield = bool(packed_ctx & 1)
ctx = packed_ctx & ~1
if is_yield:
await task.yield_(sync = False)
event, p1, p2 = (EventCode.YIELDED, 0, 0)
[packed] = await call_and_trap_on_throw(callee, task, flat_args)
while True:
code,si = unpack_callback_result(packed)
if si != 0:
s = task.inst.waitable_sets.get(si)
match code:
case CallbackCode.EXIT:
break
case CallbackCode.WAIT:
e = await task.wait_on(opts.sync, s.wait())
case CallbackCode.POLL:
await task.yield_(opts.sync)
e = s.poll()
case CallbackCode.YIELD:
await task.yield_(opts.sync)
e = None
if e:
event, p1, p2 = e
else:
event, p1, p2 = await task.wait(sync = False)
[packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, p1, p2])
event, p1, p2 = (EventCode.NONE, 0, 0)
[packed] = await call_and_trap_on_throw(opts.callback, task, [event, p1, p2])
task.exit()

class CallbackCode(IntEnum):
EXIT = 0
WAIT = 1
POLL = 2
YIELD = 3
MAX = 3

def unpack_callback_result(packed):
code = packed & 0xf
trap_if(code > CallbackCode.MAX)
assert(packed < 2**32)
assert(Table.MAX_LENGTH < 2**28)
waitable_set_index = packed >> 4
return (CallbackCode(code), waitable_set_index)

async def call_and_trap_on_throw(callee, task, args):
try:
return await callee(task, args)
Expand Down Expand Up @@ -1788,9 +1808,9 @@ def subtask_event():
subtask.finish()
return (EventCode(subtask.state), subtaski, 0)
subtask.set_event(subtask_event)
assert(0 < subtaski <= Table.MAX_LENGTH < 2**30)
assert(0 <= int(subtask.state) < 2**2)
flat_results = [subtaski | (int(subtask.state) << 30)]
assert(0 < subtaski <= Table.MAX_LENGTH < 2**28)
assert(0 <= int(subtask.state) < 2**4)
flat_results = [int(subtask.state) | (subtaski << 4)]

return flat_results

Expand Down Expand Up @@ -1836,6 +1856,20 @@ async def canon_resource_rep(rt, task, i):
trap_if(h.rt is not rt)
return [h.rep]

### 🔀 `canon task.set-context`

async def canon_task_set_context(i, task, v):
assert(types_match_values(['i32'], [v]))
trap_if(i >= Task.NUM_CONTEXT)
task.context[i] = v
return []

### 🔀 `canon task.get-context`

async def canon_task_get_context(i, task):
trap_if(i >= Task.NUM_CONTEXT)
return [task.context[i]]

### 🔀 `canon task.backpressure`

async def canon_task_backpressure(task, flat_args):
Expand All @@ -1853,35 +1887,65 @@ async def canon_task_return(task, result_type, opts, flat_args):
task.return_(flat_args)
return []

### 🔀 `canon task.wait`
### 🔀 `canon task.yield`

async def canon_task_yield(sync, task):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
await task.yield_(sync)
return []

### 🔀 `canon waitable-set.new`

async def canon_task_wait(sync, mem, task, ptr):
async def canon_waitable_set_new(task):
trap_if(not task.inst.may_leave)
return [ task.inst.waitable_sets.add(WaitableSet()) ]

### 🔀 `canon waitable-set.wait`

async def canon_waitable_set_wait(sync, mem, task, si, ptr):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
event, p1, p2 = await task.wait(sync)
s = task.inst.waitable_sets.get(si)
e = await task.wait_on(sync, s.wait())
return unpack_event(mem, task, ptr, e)

def unpack_event(mem, task, ptr, e: EventTuple):
event, p1, p2 = e
cx = LiftLowerContext(CanonicalOptions(memory = mem), task.inst)
store(cx, p1, U32Type(), ptr)
store(cx, p2, U32Type(), ptr + 4)
return [event]

### 🔀 `canon task.poll`
### 🔀 `canon waitable-set.poll`

async def canon_task_poll(sync, mem, task, ptr):
async def canon_waitable_set_poll(sync, mem, task, si, ptr):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
ret = await task.poll(sync)
if ret is None:
return [0]
cx = LiftLowerContext(CanonicalOptions(memory = mem), task.inst)
store(cx, ret, TupleType([U32Type(), U32Type(), U32Type()]), ptr)
return [1]
s = task.inst.waitable_sets.get(si)
await task.yield_(sync)
if (e := s.poll()):
return unpack_event(mem, task, ptr, e)
return [EventCode.NONE]

### 🔀 `canon task.yield`
### 🔀 `canon waitable-set.drop`

async def canon_task_yield(sync, task):
async def canon_waitable_set_drop(task, i):
trap_if(not task.inst.may_leave)
trap_if(task.opts.callback and not sync)
await task.yield_(sync)
s = task.inst.waitable_sets.remove(i)
s.drop()
return []

### 🔀 `canon waitable.set`

async def canon_waitable_set(task, wi, si):
trap_if(not task.inst.may_leave)
w = task.inst.waitables.get(wi)
if si == 0:
w.set_waitable_set(None)
else:
s = task.inst.waitable_sets.get(si)
w.set_waitable_set(s)
return []

### 🔀 `canon subtask.drop`
Expand Down Expand Up @@ -1948,15 +2012,13 @@ def on_partial_copy(revoke_buffer):
def copy_event(revoke_buffer):
revoke_buffer()
e.copying = False
e.set_waitable_set(None)
return (event_code, i, pack_copy_result(task, buffer, e))
def on_partial_copy(revoke_buffer):
e.set_event(partial(copy_event, revoke_buffer))
def on_copy_done():
e.set_event(partial(copy_event, revoke_buffer = lambda:()))
if e.copy(buffer, on_partial_copy, on_copy_done) != 'done':
e.copying = True
e.set_waitable_set(task.waitable_set)
return [BLOCKED]
return [pack_copy_result(task, buffer, e)]

Expand Down
Loading

0 comments on commit 72df7a9

Please sign in to comment.