From 44eea5b1a212a800eb7052649c6578ef7625dd63 Mon Sep 17 00:00:00 2001
From: Luke Wagner <mail@lukewagner.name>
Date: Mon, 9 Dec 2024 17:19:50 -0600
Subject: [PATCH] [WIP] add waitable-set

---
 design/mvp/canonical-abi/definitions.py | 217 ++++++++++------
 design/mvp/canonical-abi/run_tests.py   | 312 +++++++++++++++---------
 2 files changed, 346 insertions(+), 183 deletions(-)

diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py
index 3a7a39b0..dbd89297 100644
--- a/design/mvp/canonical-abi/definitions.py
+++ b/design/mvp/canonical-abi/definitions.py
@@ -206,6 +206,7 @@ class CanonicalOptions:
 class ComponentInstance:
   resources: Table[ResourceHandle]
   waitables: Table[Waitable]
+  waitable_sets: Table[WaitableSet]
   error_contexts: Table[ErrorContext]
   num_tasks: int
   may_leave: bool
@@ -217,6 +218,7 @@ class ComponentInstance:
   def __init__(self):
     self.resources = Table[ResourceHandle]()
     self.waitables = Table[Waitable]()
+    self.waitable_sets = Table[WaitableSet]()
     self.error_contexts = Table[ErrorContext]()
     self.num_tasks = 0
     self.may_leave = True
@@ -292,36 +294,37 @@ def __init__(self, impl, dtor = None, dtor_sync = True, dtor_callback = None):
 #### Waitable State
 
 class CallState(IntEnum):
-  STARTING = 0
-  STARTED = 1
-  RETURNED = 2
+  STARTING = 1
+  STARTED = 2
+  RETURNED = 3
 
 class EventCode(IntEnum):
+  NONE = 0
   CALL_STARTING = CallState.STARTING
   CALL_STARTED = CallState.STARTED
   CALL_RETURNED = CallState.RETURNED
-  YIELDED = 3
-  STREAM_READ = 4
-  STREAM_WRITE = 5
-  FUTURE_READ = 6
-  FUTURE_WRITE = 7
+  YIELDED = 4
+  STREAM_READ = 5
+  STREAM_WRITE = 6
+  FUTURE_READ = 7
+  FUTURE_WRITE = 8
 
 EventTuple = tuple[EventCode, int, int]
 GetEventCallback = Callable[[], EventTuple]
 
 class Waitable:
-  maybe_task: Optional[Task]
+  waitable_set: Optional[WaitableSet]
   get_event_cb: Optional[GetEventCallback]
   event: asyncio.Event
 
   def __init__(self):
-    self.maybe_task = None
+    self.waitable_set = None
     self.get_event_cb = None
     self.event = asyncio.Event()
 
   def set_event(self, get_event_cb: GetEventCallback):
-    assert(self.maybe_task)
-    self.maybe_task.maybe_has_event.set()
+    if self.waitable_set:
+      self.waitable_set.maybe_has_event.set()
     self.get_event_cb = get_event_cb
     self.event.set()
 
@@ -338,16 +341,44 @@ def get_event(self) -> EventTuple:
     self.event.clear()
     return get_event_cb()
 
-  def set_task(self, maybe_task):
-    if self.maybe_task:
-      self.maybe_task.waitables.remove(self)
-    self.maybe_task = maybe_task
-    if self.maybe_task:
-      self.maybe_task.waitables.append(self)
+  def set_waitable_set(self, s):
+    if self.waitable_set:
+      self.waitable_set.elems.remove(self)
+    self.waitable_set = s
+    if s:
+      s.elems.append(self)
+      if self.has_event():
+        s.maybe_has_event.set()
 
   def drop(self):
     assert(not self.has_event())
-    self.set_task(None)
+    self.set_waitable_set(None)
+
+class WaitableSet:
+  elems: list[Waitable]  # waitables currently joined to this set (maintained by Waitable.set_waitable_set)
+  maybe_has_event: asyncio.Event  # hint that some member may have an event; may be stale until poll() clears it
+
+  def __init__(self):
+    self.elems = []
+    self.maybe_has_event = asyncio.Event()
+
+  async def wait(self) -> EventTuple:  # block until poll() produces an event
+    while True:
+      await self.maybe_has_event.wait()
+      if (e := self.poll()):  # the hint can be stale, so re-check and loop on a miss
+        return e
+
+  def poll(self) -> Optional[EventTuple]:  # non-blocking: one pending event, or None
+    random.shuffle(self.elems)  # randomize which ready member is reported first
+    for w in self.elems:
+      if w.has_event():
+        assert(self.maybe_has_event.is_set())
+        return w.get_event()
+    self.maybe_has_event.clear()  # no member has an event: reset the hint so wait() blocks again
+    return None
+
+  def drop(self):  # traps if any waitable is still a member of this set
+    trap_if(len(self.elems) > 0)
 
 #### Task State
 
@@ -358,10 +389,11 @@ class Task:
   caller: Optional[Task]
   on_return: Optional[Callable]
   on_block: Callable[[Awaitable], Awaitable]
-  waitables: list[Waitables]
-  maybe_has_event: asyncio.Event
   num_subtasks: int
   num_borrows: int
+  context: list[int]
+
+  NUM_CONTEXT = 2
 
   def __init__(self, opts, inst, ft, caller, on_return, on_block):
     self.opts = opts
@@ -370,10 +402,9 @@ def __init__(self, opts, inst, ft, caller, on_return, on_block):
     self.caller = caller
     self.on_return = on_return
     self.on_block = on_block if on_block else Task.sync_on_block
-    self.waitables = []
-    self.maybe_has_event = asyncio.Event()
     self.num_subtasks = 0
     self.num_borrows = 0
+    self.context = [0] * Task.NUM_CONTEXT
 
   current = asyncio.Lock()
   async def sync_on_block(a: Awaitable):
@@ -432,28 +463,9 @@ def maybe_start_pending_task(self):
         self.inst.starting_pending_task = True
         pending_future.set_result(None)
 
-  async def wait(self, sync) -> EventTuple:
-    while True:
-      await self.wait_on(sync, self.maybe_has_event.wait())
-      if (e := self.maybe_next_event()):
-        return e
-
-  def maybe_next_event(self) -> Optional[EventTuple]:
-    random.shuffle(self.waitables)
-    for w in self.waitables:
-      if w.has_event():
-        assert(self.maybe_has_event.is_set())
-        return w.get_event()
-    self.maybe_has_event.clear()
-    return None
-
   async def yield_(self, sync):
     await self.wait_on(sync, asyncio.sleep(0))
 
-  async def poll(self, sync) -> Optional[EventTuple]:
-    await self.yield_(sync)
-    return self.maybe_next_event()
-
   async def call_sync(self, callee, *args):
     if self.inst.interruptible.is_set():
       self.inst.interruptible.clear()
@@ -497,7 +509,6 @@ def return_(self, flat_results):
   def exit(self):
     assert(Task.current.locked())
     trap_if(self.num_subtasks > 0)
-    trap_if(len(self.waitables) > 0)
     trap_if(self.on_return)
     assert(self.num_borrows == 0)
     trap_if(self.inst.num_tasks == 1 and self.inst.backpressure)
@@ -527,7 +538,6 @@ def add_waitable(self, task):
     self.supertask = task
     self.supertask.num_subtasks += 1
     Waitable.__init__(self)
-    Waitable.set_task(self, task)
     return task.inst.waitables.add(self)
 
   def add_lender(self, lending_handle):
@@ -1712,19 +1722,42 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_blo
       [] = await call_and_trap_on_throw(callee, task, flat_args)
       assert(types_match_values(flat_ft.results, []))
     else:
-      [packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args)
-      assert(types_match_values(flat_ft.results, [packed_ctx]))
-      while packed_ctx != 0:
-        is_yield = bool(packed_ctx & 1)
-        ctx = packed_ctx & ~1
-        if is_yield:
-          await task.yield_(sync = False)
-          event, p1, p2 = (EventCode.YIELDED, 0, 0)
+      [packed] = await call_and_trap_on_throw(callee, task, flat_args)
+      while True:
+        code,si = unpack_callback_result(packed)
+        if si != 0:
+          s = task.inst.waitable_sets.get(si)
+        match code:
+          case CallbackCode.EXIT:
+            break
+          case CallbackCode.WAIT:
+            e = await task.wait_on(opts.sync, s.wait())
+          case CallbackCode.POLL:
+            await task.yield_(opts.sync)
+            e = s.poll()
+          case CallbackCode.YIELD:
+            await task.yield_(opts.sync)
+            e = None
+        if e:
+          event, p1, p2 = e
         else:
-          event, p1, p2 = await task.wait(sync = False)
-        [packed_ctx] = await call_and_trap_on_throw(opts.callback, task, [ctx, event, p1, p2])
+          event, p1, p2 = (EventCode.YIELDED, 0, 0)
+        [packed] = await call_and_trap_on_throw(opts.callback, task, [event, p1, p2])
   task.exit()
 
+class CallbackCode(IntEnum):  # action requested by a callback's packed result (see unpack_callback_result)
+  EXIT = 0
+  WAIT = 1
+  POLL = 2
+  YIELD = 3
+
+def unpack_callback_result(packed):  # returns (CallbackCode, waitable-set index)
+  code = packed & 3  # low two bits select the CallbackCode
+  assert(packed < 2**32)
+  assert(Table.MAX_LENGTH < 2**30)  # presumably guarantees a table index fits in the remaining 30 bits
+  waitable_set_index = packed >> 2  # everything above the two code bits
+  return (CallbackCode(code), waitable_set_index)
+
 async def call_and_trap_on_throw(callee, task, args):
   try:
     return await callee(task, args)
@@ -1826,6 +1859,20 @@ async def canon_resource_rep(rt, task, i):
   trap_if(h.rt is not rt)
   return [h.rep]
 
+### 🔀 `canon task.set-context`
+
+async def canon_task_set_context(i, task, v):
+  assert(types_match_values(['i32'], [v]))
+  trap_if(i >= Task.NUM_CONTEXT)
+  task.context[i] = v
+  return []
+
+### 🔀 `canon task.get-context`
+
+async def canon_task_get_context(i, task):
+  trap_if(i >= Task.NUM_CONTEXT)
+  return [task.context[i]]
+
 ### 🔀 `canon task.backpressure`
 
 async def canon_task_backpressure(task, flat_args):
@@ -1843,35 +1890,65 @@ async def canon_task_return(task, result_type, opts, flat_args):
   task.return_(flat_args)
   return []
 
-### 🔀 `canon task.wait`
+### 🔀 `canon task.yield`
+
+async def canon_task_yield(sync, task):
+  trap_if(not task.inst.may_leave)
+  trap_if(task.opts.callback and not sync)
+  await task.yield_(sync)
+  return []
+
+### 🔀 `canon waitable-set.new`
 
-async def canon_task_wait(sync, mem, task, ptr):
+async def canon_waitable_set_new(task):
+  trap_if(not task.inst.may_leave)
+  return [ task.inst.waitable_sets.add(WaitableSet()) ]
+
+### 🔀 `canon waitable-set.wait`
+
+async def canon_waitable_set_wait(sync, mem, task, si, ptr):
   trap_if(not task.inst.may_leave)
   trap_if(task.opts.callback and not sync)
-  event, p1, p2 = await task.wait(sync)
+  s = task.inst.waitable_sets.get(si)
+  e = await task.wait_on(sync, s.wait())
+  return unpack_event(mem, task, ptr, e)
+
+def unpack_event(mem, task, ptr, e: EventTuple):
+  event, p1, p2 = e
   cx = LiftLowerContext(CanonicalOptions(memory = mem), task.inst)
   store(cx, p1, U32Type(), ptr)
   store(cx, p2, U32Type(), ptr + 4)
   return [event]
 
-### 🔀 `canon task.poll`
+### 🔀 `canon waitable-set.poll`
 
-async def canon_task_poll(sync, mem, task, ptr):
+async def canon_waitable_set_poll(sync, mem, task, si, ptr):
   trap_if(not task.inst.may_leave)
   trap_if(task.opts.callback and not sync)
-  ret = await task.poll(sync)
-  if ret is None:
-    return [0]
-  cx = LiftLowerContext(CanonicalOptions(memory = mem), task.inst)
-  store(cx, ret, TupleType([U32Type(), U32Type(), U32Type()]), ptr)
-  return [1]
+  s = task.inst.waitable_sets.get(si)
+  await task.yield_(sync)
+  if (e := s.poll()):
+    return unpack_event(mem, task, ptr, e)
+  return [EventCode.NONE]
 
-### 🔀 `canon task.yield`
+### 🔀 `canon waitable-set.drop`
 
-async def canon_task_yield(sync, task):
+async def canon_waitable_set_drop(task, i):
   trap_if(not task.inst.may_leave)
-  trap_if(task.opts.callback and not sync)
-  await task.yield_(sync)
+  s = task.inst.waitable_sets.remove(i)
+  s.drop()
+  return []
+
+### 🔀 `canon waitable.set`
+
+async def canon_waitable_set(task, wi, si):  # move waitable `wi` into waitable set `si`
+  trap_if(not task.inst.may_leave)
+  w = task.inst.waitables.get(wi)
+  if si == 0:
+    w.set_waitable_set(None)  # index 0 means "no set": only remove `w` from its current set, if any
+  else:
+    s = task.inst.waitable_sets.get(si)
+    w.set_waitable_set(s)  # leaves any previous set and joins `s`
   return []
 
 ### 🔀 `canon subtask.drop`
@@ -1938,7 +2015,6 @@ def on_async_partial(revoke_buffer):
     def copy_event(revoke_buffer):
       revoke_buffer()
       h.copying = False
-      h.set_task(None)
       return (event_code, i, pack_copy_result(task, buffer, h))
     def on_async_partial(revoke_buffer):
       h.set_event(partial(copy_event, revoke_buffer))
@@ -1946,7 +2022,6 @@ def on_async_done():
       h.set_event(partial(copy_event, revoke_buffer = lambda:()))
     if h.copy(buffer, on_async_partial, on_async_done) != 'done':
       h.copying = True
-      h.set_task(task)
       return [BLOCKED]
   return [pack_copy_result(task, buffer, h)]
 
diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py
index d9261b70..151ab826 100644
--- a/design/mvp/canonical-abi/run_tests.py
+++ b/design/mvp/canonical-abi/run_tests.py
@@ -554,40 +554,50 @@ async def core_blocking_producer(task, args):
     return []
   blocking_callee = partial(canon_lift, producer_opts, producer_inst, blocking_ft, core_blocking_producer)
 
-  consumer_heap = Heap(10)
+  consumer_heap = Heap(20)
   consumer_opts = mk_opts(consumer_heap.memory)
   consumer_opts.sync = False
 
   async def consumer(task, args):
     [b] = args
+    [seti] = await canon_waitable_set_new(task)
     ptr = consumer_heap.realloc(0, 0, 1, 1)
     [ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [ptr])
     assert(ret == 0)
     u8 = consumer_heap.memory[ptr]
     assert(u8 == 43)
     [ret] = await canon_lower(consumer_opts, toggle_ft, toggle_callee, task, [])
-    subi,state = unpack_lower_result(ret)
+    subi1,state = unpack_lower_result(ret)
+    assert(subi1 == 1)
     assert(state == CallState.STARTED)
+    [] = await canon_waitable_set(task, subi1, seti)
     retp = ptr
     consumer_heap.memory[retp] = 13
     [ret] = await canon_lower(consumer_opts, blocking_ft, blocking_callee, task, [83, retp])
-    assert(ret == (2 | (CallState.STARTING << 30)))
+    subi2,state = unpack_lower_result(ret)
+    assert(subi2 == 2)
+    assert(state == CallState.STARTING)
     assert(consumer_heap.memory[retp] == 13)
+    [] = await canon_waitable_set(task, subi2, seti)
     fut1.set_result(None)
-    event, callidx, _ = await task.wait(sync = False)
+
+    waitretp = consumer_heap.realloc(0, 0, 8, 4)
+    [event] = await canon_waitable_set_wait(False, consumer_heap.memory, task, seti, waitretp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 1)
-    [] = await canon_subtask_drop(task, callidx)
-    event, callidx, _ = await task.wait(sync = True)
+    assert(consumer_heap.memory[waitretp] == subi1)
+    [] = await canon_subtask_drop(task, subi1)
+
+    [event] = await canon_waitable_set_wait(True, consumer_heap.memory, task, seti, waitretp)
     assert(event == EventCode.CALL_STARTED)
-    assert(callidx == 2)
+    assert(consumer_heap.memory[waitretp] == subi2)
     assert(consumer_heap.memory[retp] == 13)
     fut2.set_result(None)
-    event, callidx, _ = await task.wait(sync = False)
+
+    [event] = await canon_waitable_set_wait(False, consumer_heap.memory, task, seti, waitretp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 2)
+    assert(consumer_heap.memory[waitretp] == subi2)
     assert(consumer_heap.memory[retp] == 44)
-    [] = await canon_subtask_drop(task, callidx)
+    [] = await canon_subtask_drop(task, subi2)
     fut3.set_result(None)
     assert(await task.on_block(fut4) == "done")
 
@@ -605,13 +615,19 @@ async def dtor(task, args):
     assert(i == 1)
     assert(dtor_value is None)
     [ret] = await canon_resource_drop(rt, False, task, 1)
-    assert(ret == (2 | (CallState.STARTED << 30)))
+    dtorsubi,state = unpack_lower_result(ret)
+    assert(dtorsubi == 2)
+    assert(state == CallState.STARTED)
     assert(dtor_value is None)
     dtor_fut.set_result(None)
-    event, callidx, _ = await task.wait(sync = False)
+
+    [] = await canon_waitable_set(task, dtorsubi, seti)
+    [event] = await canon_waitable_set_wait(False, consumer_heap.memory, task, seti, waitretp)
     assert(event == CallState.RETURNED)
-    assert(callidx == 2)
-    [] = await canon_subtask_drop(task, callidx)
+    assert(consumer_heap.memory[waitretp] == dtorsubi)
+    assert(dtor_value == 50)
+    [] = await canon_subtask_drop(task, dtorsubi)
+    [] = await canon_waitable_set_drop(task, seti)
 
     [] = await canon_task_return(task, [U8Type()], consumer_opts, [42])
     return []
@@ -655,36 +671,53 @@ async def consumer(task, args):
     assert(len(args) == 0)
 
     [ret] = await canon_lower(opts, producer_ft, producer1, task, [])
-    assert(ret == (1 | (CallState.STARTED << 30)))
+    subi1,state = unpack_lower_result(ret)
+    assert(subi1 == 1)
+    assert(state == CallState.STARTED)
 
     [ret] = await canon_lower(opts, producer_ft, producer2, task, [])
     assert(ret == (2 | (CallState.STARTED << 30)))
+    subi2,state = unpack_lower_result(ret)
+    assert(subi2 == 2)
+    assert(state == CallState.STARTED)
+
+    [seti] = await canon_waitable_set_new(task)
+    assert(seti == 1)
+    [] = await canon_waitable_set(task, subi1, seti)
+    [] = await canon_waitable_set(task, subi2, seti)
 
     fut1.set_result(None)
-    return [42]
+    [] = await canon_task_set_context(0, task, 42)
+    return [definitions.CallbackCode.WAIT|(seti << 2)]
 
   async def callback(task, args):
-    assert(len(args) == 4)
-    if args[0] == 42:
-      assert(args[1] == EventCode.CALL_RETURNED)
-      assert(args[2] == 1)
-      assert(args[3] == 0)
-      await canon_subtask_drop(task, 1)
-      return [53]
-    elif args[0] == 52:
-      assert(args[1] == EventCode.YIELDED)
-      assert(args[2] == 0)
-      assert(args[3] == 0)
-      fut2.set_result(None)
-      return [62]
-    else:
-      assert(args[0] == 62)
-      assert(args[1] == EventCode.CALL_RETURNED)
-      assert(args[2] == 2)
-      assert(args[3] == 0)
-      await canon_subtask_drop(task, 2)
-      [] = await canon_task_return(task, [U32Type()], opts, [83])
-      return [0]
+    assert(len(args) == 3)
+    seti = 1  # the set created in consumer(), which asserts its index is 1
+    [ctx] = await canon_task_get_context(0, task)
+    match ctx:
+      case 42:
+        assert(args[0] == EventCode.CALL_RETURNED)
+        assert(args[1] == 1)
+        assert(args[2] == 0)
+        await canon_subtask_drop(task, 1)
+        [] = await canon_task_set_context(0, task, 52)
+        return [definitions.CallbackCode.YIELD]
+      case 52:
+        assert(args[0] == EventCode.YIELDED)
+        assert(args[1] == 0)
+        assert(args[2] == 0)
+        fut2.set_result(None)
+        [] = await canon_task_set_context(0, task, 62)
+        return [definitions.CallbackCode.WAIT|(seti << 2)]  # name the set explicitly; index 0 only worked via a stale local in canon_lift
+      case 62:
+        assert(args[0] == EventCode.CALL_RETURNED)
+        assert(args[1] == 2)
+        assert(args[2] == 0)
+        await canon_subtask_drop(task, 2)
+        [] = await canon_task_return(task, [U32Type()], opts, [83])
+        return [definitions.CallbackCode.EXIT]
+      case _:
+        assert(False)
 
   consumer_inst = ComponentInstance()
   def on_start(): return []
@@ -727,7 +760,8 @@ async def producer2_core(task, args):
   producer1 = partial(canon_lift, producer_opts, producer_inst, producer_ft, producer1_core)
   producer2 = partial(canon_lift, producer_opts, producer_inst, producer_ft, producer2_core)
 
-  consumer_opts = mk_opts()
+  consumer_heap = Heap(20)
+  consumer_opts = mk_opts(consumer_heap.memory)
   consumer_opts.sync = False
 
   consumer_ft = FuncType([],[U8Type()])
@@ -735,31 +769,40 @@ async def consumer(task, args):
     assert(len(args) == 0)
 
     [ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [])
-    assert(ret == (1 | (CallState.STARTED << 30)))
+    subi1,state = unpack_lower_result(ret)
+    assert(subi1 == 1)
+    assert(state == CallState.STARTED)
 
     [ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [])
-    assert(ret == (2 | (CallState.STARTING << 30)))
+    subi2,state = unpack_lower_result(ret)
+    assert(subi2 == 2)
+    assert(state == CallState.STARTING)
 
-    assert(await task.poll(sync = False) is None)
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, subi1, seti)
+    [] = await canon_waitable_set(task, subi2, seti)
 
     fut.set_result(None)
     assert(producer1_done == False)
-    event, callidx, _ = await task.wait(sync = False)
+
+    retp = consumer_heap.realloc(0,0,8,4)
+    [event] = await canon_waitable_set_wait(False, consumer_heap.memory, task, seti, retp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 1)
-    await canon_subtask_drop(task, callidx)
+    assert(consumer_heap.memory[retp] == subi1)
+    await canon_subtask_drop(task, subi1)
     assert(producer1_done == True)
 
     assert(producer2_done == False)
     await canon_task_yield(False, task)
     assert(producer2_done == True)
-    event, callidx, _ = await task.poll(sync = False)
+
+    [event] = await canon_waitable_set_poll(False, consumer_heap.memory, task, seti, retp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 2)
-    await canon_subtask_drop(task, callidx)
+    assert(consumer_heap.memory[retp] == subi2)
+    await canon_subtask_drop(task, subi2)
     assert(producer2_done == True)
 
-    assert(await task.poll(sync = True) is None)
+    [] = await canon_waitable_set_drop(task, seti)
 
     await canon_task_return(task, [U8Type()], consumer_opts, [83])
     return []
@@ -804,37 +847,46 @@ async def producer2_core(task, args):
   producer1 = partial(canon_lift, producer_opts, producer_inst, producer_ft, producer1_core)
   producer2 = partial(canon_lift, producer_opts, producer_inst, producer_ft, producer2_core)
 
-  consumer_opts = CanonicalOptions()
-  consumer_opts.sync = False
+  consumer_heap = Heap(20)
+  consumer_opts = mk_opts(consumer_heap.memory, sync = False)
 
   consumer_ft = FuncType([],[U8Type()])
   async def consumer(task, args):
     assert(len(args) == 0)
 
     [ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [])
-    assert(ret == (1 | (CallState.STARTED << 30)))
+    subi1,state = unpack_lower_result(ret)
+    assert(subi1 == 1)
+    assert(state == CallState.STARTED)
 
     [ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [])
-    assert(ret == (2 | (CallState.STARTING << 30)))
+    subi2,state = unpack_lower_result(ret)
+    assert(subi2 == 2)
+    assert(state == CallState.STARTING)
 
-    assert(await task.poll(sync = False) is None)
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, subi1, seti)
+    [] = await canon_waitable_set(task, subi2, seti)
 
     fut.set_result(None)
     assert(producer1_done == False)
     assert(producer2_done == False)
-    event, callidx, _ = await task.wait(sync = False)
+
+    retp = consumer_heap.realloc(0,0,8,4)
+    [event] = await canon_waitable_set_wait(False, consumer_heap.memory, task, seti, retp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 1)
+    assert(consumer_heap.memory[retp] == subi1)
     assert(producer1_done == True)
-    event, callidx, _ = await task.poll(sync = False)
+
+    [event] = await canon_waitable_set_poll(False, consumer_heap.memory, task, seti, retp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 2)
+    assert(consumer_heap.memory[retp] == subi2)
     assert(producer2_done == True)
 
-    await canon_subtask_drop(task, 1)
-    await canon_subtask_drop(task, 2)
+    await canon_subtask_drop(task, subi1)
+    await canon_subtask_drop(task, subi2)
 
-    assert(await task.poll(sync = False) is None)
+    [] = await canon_waitable_set_drop(task, seti)
 
     await canon_task_return(task, [U8Type()], consumer_opts, [84])
     return []
@@ -868,26 +920,40 @@ async def core_hostcall_pre(fut, task, args):
   core_hostcall2 = partial(core_hostcall_pre, fut2)
   hostcall2 = partial(canon_lift, hostcall_opts, hostcall_inst, ft, core_hostcall2)
 
-  lower_opts = mk_opts()
+  lower_heap = Heap(20)
+  lower_opts = mk_opts(lower_heap.memory)
   lower_opts.sync = False
 
   async def core_func(task, args):
     [ret] = await canon_lower(lower_opts, ft, hostcall1, task, [])
-    assert(ret == (1 | (CallState.STARTED << 30)))
+    subi1,state = unpack_lower_result(ret)
+    assert(subi1 == 1)
+    assert(state == CallState.STARTED)
     [ret] = await canon_lower(lower_opts, ft, hostcall2, task, [])
-    assert(ret == (2 | (CallState.STARTED << 30)))
+    subi2,state = unpack_lower_result(ret)
+    assert(subi2 == 2)
+    assert(state == CallState.STARTED)
+
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, subi1, seti)
+    [] = await canon_waitable_set(task, subi2, seti)
 
     fut1.set_result(None)
-    event, callidx, _ = await task.wait(sync = False)
+
+    retp = lower_heap.realloc(0,0,8,4)
+    [event] = await canon_waitable_set_wait(False, lower_heap.memory, task, seti, retp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 1)
+    assert(lower_heap.memory[retp] == subi1)
+
     fut2.set_result(None)
-    event, callidx, _ = await task.wait(sync = False)
+
+    [event] = await canon_waitable_set_wait(False, lower_heap.memory, task, seti, retp)
     assert(event == EventCode.CALL_RETURNED)
-    assert(callidx == 2)
+    assert(lower_heap.memory[retp] == subi2)
 
-    await canon_subtask_drop(task, 1)
-    await canon_subtask_drop(task, 2)
+    await canon_subtask_drop(task, subi1)
+    await canon_subtask_drop(task, subi2)
+    await canon_waitable_set_drop(task, seti)
 
     return []
 
@@ -1117,7 +1183,7 @@ async def core_func(task, args):
 async def test_async_stream_ops():
   ft = FuncType([StreamType(U8Type())], [StreamType(U8Type())])
   inst = ComponentInstance()
-  mem = bytearray(20)
+  mem = bytearray(24)
   opts = mk_opts(memory=mem, sync=False)
   sync_opts = mk_opts(memory=mem, sync=True)
 
@@ -1158,13 +1224,15 @@ async def core_func(task, args):
     [ret] = await canon_stream_read(U8Type(), opts, task, rsi1, 0, 4)
     assert(ret == definitions.BLOCKED)
     src_stream.write([1,2,3,4])
-    event, p1, p2 = await task.wait(sync = False)
+    retp = 16
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, rsi1, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)  # wait on the set index, not the stream handle
     assert(event == EventCode.STREAM_READ)
-    assert(p1 == rsi1)
-    assert(p2 == 4)
+    assert(mem[retp+0] == rsi1)
+    assert(mem[retp+4] == 4)
     assert(mem[0:4] == b'\x01\x02\x03\x04')
     [wsi2] = await canon_stream_new(U8Type(), task)
-    retp = 16
     [ret] = await canon_lower(opts, ft, host_import, task, [wsi2, retp])
     assert(ret == 0)
     rsi2 = mem[16]
@@ -1172,19 +1240,21 @@ async def core_func(task, args):
     [ret] = await canon_stream_write(U8Type(), opts, task, wsi2, 0, 4)
     assert(ret == definitions.BLOCKED)
     host_import_incoming.set_remain(100)
-    event, p1, p2 = await task.wait(sync = False)
+    [] = await canon_waitable_set(task, wsi2, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)
     assert(event == EventCode.STREAM_WRITE)
-    assert(p1 == wsi2)
-    assert(p2 == 4)
+    assert(mem[retp+0] == wsi2)
+    assert(mem[retp+4] == 4)
     [ret] = await canon_stream_read(U8Type(), sync_opts, task, rsi2, 0, 4)
     assert(ret == 4)
     [ret] = await canon_stream_write(U8Type(), opts, task, wsi1, 0, 4)
     assert(ret == definitions.BLOCKED)
     dst_stream.set_remain(100)
-    event, p1, p2 = await task.wait(sync = False)
+    [] = await canon_waitable_set(task, wsi1, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)
     assert(event == EventCode.STREAM_WRITE)
-    assert(p1 == wsi1)
-    assert(p2 == 4)
+    assert(mem[retp+0] == wsi1)
+    assert(mem[retp+4] == 4)
     src_stream.write([5,6,7,8])
     src_stream.destroy_once_empty()
     [ret] = await canon_stream_read(U8Type(), opts, task, rsi1, 0, 4)
@@ -1198,16 +1268,18 @@ async def core_func(task, args):
     [] = await canon_stream_close_writable(U8Type(), task, wsi2, 0)
     [ret] = await canon_stream_read(U8Type(), opts, task, rsi2, 0, 4)
     assert(ret == definitions.BLOCKED)
-    event, p1, p2 = await task.wait(sync = False)
+    [] = await canon_waitable_set(task, rsi2, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)
     assert(event == EventCode.STREAM_READ)
-    assert(p1 == rsi2)
-    assert(p2 == 4)
+    assert(mem[retp+0] == rsi2)
+    assert(mem[retp+4] == 4)
     [ret] = await canon_stream_read(U8Type(), opts, task, rsi2, 0, 4)
     assert(ret == definitions.CLOSED)
     [] = await canon_stream_close_readable(U8Type(), task, rsi2)
     [ret] = await canon_stream_write(U8Type(), sync_opts, task, wsi1, 0, 4)
     assert(ret == 4)
     [] = await canon_stream_close_writable(U8Type(), task, wsi1, 0)
+    [] = await canon_waitable_set_drop(task, seti)
     return []
 
   await canon_lift(opts, inst, ft, core_func, None, on_start, on_return)
@@ -1310,10 +1382,13 @@ async def core_func(task, args):
     [ret] = await canon_stream_read(U8Type(), opts, task, rsi, 0, 4)
     assert(ret == definitions.BLOCKED)
     src.write([5,6])
-    event, p1, p2 = await task.wait(sync = False)
+
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, rsi, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)
     assert(event == EventCode.STREAM_READ)
-    assert(p1 == rsi)
-    assert(p2 == 2)
+    assert(mem[retp+0] == rsi)
+    assert(mem[retp+4] == 2)
     [] = await canon_stream_close_readable(U8Type(), task, rsi)
 
     [wsi] = await canon_stream_new(U8Type(), task)
@@ -1326,12 +1401,14 @@ async def core_func(task, args):
     [ret] = await canon_stream_write(U8Type(), opts, task, wsi, 2, 6)
     assert(ret == definitions.BLOCKED)
     dst.set_remain(4)
-    event, p1, p2 = await task.wait(sync = False)
+    [] = await canon_waitable_set(task, wsi, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)
     assert(event == EventCode.STREAM_WRITE)
-    assert(p1 == wsi)
-    assert(p2 == 4)
+    assert(mem[retp+0] == wsi)
+    assert(mem[retp+4] == 4)
     assert(dst.received == [1,2,3,4,5,6])
     [] = await canon_stream_close_writable(U8Type(), task, wsi, 0)
+    [] = await canon_waitable_set_drop(task, seti)
     dst.set_remain(100)
     assert(await dst.consume(100) is None)
     return []
@@ -1348,7 +1425,7 @@ async def test_wasm_to_wasm_stream():
   fut1, fut2, fut3, fut4 = asyncio.Future(), asyncio.Future(), asyncio.Future(), asyncio.Future()
 
   inst1 = ComponentInstance()
-  mem1 = bytearray(10)
+  mem1 = bytearray(24)
   opts1 = mk_opts(memory=mem1, sync=False)
   ft1 = FuncType([], [StreamType(U8Type())])
   async def core_func1(task, args):
@@ -1372,22 +1449,26 @@ async def core_func1(task, args):
 
     fut3.set_result(None)
 
-    event, p1, p2 = await task.wait(sync = False)
+    retp = 16
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, wsi, seti)
+    [event] = await canon_waitable_set_wait(False, mem1, task, seti, retp)
     assert(event == EventCode.STREAM_WRITE)
-    assert(p1 == wsi)
-    assert(p2 == 4)
+    assert(mem1[retp+0] == wsi)
+    assert(mem1[retp+4] == 4)
 
     fut4.set_result(None)
 
     [errctxi] = await canon_error_context_new(opts1, task, 0, 0)
     [] = await canon_stream_close_writable(U8Type(), task, wsi, errctxi)
+    [] = await canon_waitable_set_drop(task, seti)
     [] = await canon_error_context_drop(task, errctxi)
     return []
 
   func1 = partial(canon_lift, opts1, inst1, ft1, core_func1)
 
   inst2 = ComponentInstance()
-  heap2 = Heap(10)
+  heap2 = Heap(24)
   mem2 = heap2.memory
   opts2 = mk_opts(memory=heap2.memory, realloc=heap2.realloc, sync=False)
   ft2 = FuncType([], [])
@@ -1395,10 +1476,10 @@ async def core_func2(task, args):
     assert(not args)
     [] = await canon_task_return(task, [], opts2, [])
 
-    retp = 0
+    retp = 16
     [ret] = await canon_lower(opts2, ft1, func1, task, [retp])
     assert(ret == 0)
-    rsi = mem2[0]
+    rsi = mem2[retp]
     assert(rsi == 1)
 
     [ret] = await canon_stream_read(U8Type(), opts2, task, rsi, 0, 8)
@@ -1406,10 +1487,12 @@ async def core_func2(task, args):
 
     fut1.set_result(None)
 
-    event, p1, p2 = await task.wait(sync = False)
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, rsi, seti)
+    [event] = await canon_waitable_set_wait(False, mem2, task, seti, retp)
     assert(event == EventCode.STREAM_READ)
-    assert(p1 == rsi)
-    assert(p2 == 4)
+    assert(mem2[retp+0] == rsi)
+    assert(mem2[retp+4] == 4)
     assert(mem2[0:8] == b'\x01\x02\x03\x04\x00\x00\x00\x00')
 
     fut2.set_result(None)
@@ -1429,6 +1512,7 @@ async def core_func2(task, args):
     errctxi = 1
     assert(ret == (definitions.CLOSED | errctxi))
     [] = await canon_stream_close_readable(U8Type(), task, rsi)
+    [] = await canon_waitable_set_drop(task, seti)
     [] = await canon_error_context_debug_message(opts2, task, errctxi, 0)
     [] = await canon_error_context_drop(task, errctxi)
     return []
@@ -1438,7 +1522,7 @@ async def core_func2(task, args):
 
 async def test_cancel_copy():
   inst = ComponentInstance()
-  mem = bytearray(10)
+  mem = bytearray(24)
   lower_opts = mk_opts(memory=mem, sync=False)
 
   host_ft1 = FuncType([StreamType(U8Type())],[])
@@ -1491,7 +1575,7 @@ async def core_func(task, args):
     host_sink.set_remain(100)
     assert(await host_sink.consume(100) is None)
 
-    retp = 0
+    retp = 16
     [ret] = await canon_lower(lower_opts, host_ft2, host_func2, task, [retp])
     assert(ret == 0)
     rsi = mem[retp]
@@ -1501,7 +1585,6 @@ async def core_func(task, args):
     assert(ret == 0)
     [] = await canon_stream_close_readable(U8Type(), task, rsi)
 
-    retp = 0
     [ret] = await canon_lower(lower_opts, host_ft2, host_func2, task, [retp])
     assert(ret == 0)
     rsi = mem[retp]
@@ -1512,12 +1595,15 @@ async def core_func(task, args):
     assert(ret == definitions.BLOCKED)
     host_source.write([7,8])
     await asyncio.sleep(0)
-    event,p1,p2 = await task.wait(sync = False)
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, rsi, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)
     assert(event == EventCode.STREAM_READ)
-    assert(p1 == rsi)
-    assert(p2 == 2)
+    assert(mem[retp+0] == rsi)
+    assert(mem[retp+4] == 2)
     assert(mem[0:2] == b'\x07\x08')
     [] = await canon_stream_close_readable(U8Type(), task, rsi)
+    [] = await canon_waitable_set_drop(task, seti)
 
     return []
 
@@ -1582,7 +1668,7 @@ def close(self, errctx = None):
 
 async def test_futures():
   inst = ComponentInstance()
-  mem = bytearray(10)
+  mem = bytearray(24)
   lower_opts = mk_opts(memory=mem, sync=False)
 
   host_ft1 = FuncType([FutureType(U8Type())],[FutureType(U8Type())])
@@ -1600,7 +1686,7 @@ async def host_func(task, on_start, on_return, on_block):
   async def core_func(task, args):
     assert(not args)
     [wfi] = await canon_future_new(U8Type(), task)
-    retp = 0
+    retp = 16
     [ret] = await canon_lower(lower_opts, host_ft1, host_func, task, [wfi, retp])
     assert(ret == 0)
     rfi = mem[retp]
@@ -1614,17 +1700,19 @@ async def core_func(task, args):
     [ret] = await canon_future_write(U8Type(), lower_opts, task, wfi, writep)
     assert(ret == 1)
 
-    event,p1,p2 = await task.wait(sync = False)
+    [seti] = await canon_waitable_set_new(task)
+    [] = await canon_waitable_set(task, rfi, seti)
+    [event] = await canon_waitable_set_wait(False, mem, task, seti, retp)
     assert(event == EventCode.FUTURE_READ)
-    assert(p1 == rfi)
-    assert(p2 == 1)
+    assert(mem[retp+0] == rfi)
+    assert(mem[retp+4] == 1)
     assert(mem[readp] == 43)
 
     [] = await canon_future_close_writable(U8Type(), task, wfi, 0)
     [] = await canon_future_close_readable(U8Type(), task, rfi)
+    [] = await canon_waitable_set_drop(task, seti)
 
     [wfi] = await canon_future_new(U8Type(), task)
-    retp = 0
     [ret] = await canon_lower(lower_opts, host_ft1, host_func, task, [wfi, retp])
     assert(ret == 0)
     rfi = mem[retp]