Skip to content

Commit

Permalink
Remove inner classes and rewrite ChannelMultiplexer
Browse files Browse the repository at this point in the history
... so that it no longer creates "spooky actions at a distance"
  • Loading branch information
natsukagami committed Dec 8, 2023
1 parent d9c5d48 commit 8f41bc3
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 141 deletions.
54 changes: 45 additions & 9 deletions shared/src/main/scala/async/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ object Async:
import scala.collection.JavaConverters._
val q = java.util.concurrent.ConcurrentLinkedQueue[T]()
q.addAll(values.asJavaCollection)
new Source[T]:
new Source[T]:
override def poll(k: Listener[T]): Boolean =
if q.isEmpty() then return false
else
Expand Down Expand Up @@ -174,25 +174,35 @@ object Async:
def dropListener(k: Listener[U]): Unit =
src.dropListener(transform(k))

def race[T](sources: Source[T]*) = raceImpl[T, T]((v, _) => v)(sources*)
def raceWithOrigin[T](sources: Source[T]*) =
raceImpl[(T, Source[T]), T]((v, src) => (v, src))(sources*)

/** Pass first result from any of `sources` to the continuation */
def race[T](sources: Source[T]*): Source[T] =
private def raceImpl[T, U](map: (U, Source[U]) => T)(sources: Source[U]*): Source[T] =
new Source[T] { selfSrc =>
def poll(k: Listener[T]): Boolean =
val it = sources.iterator
var found = false

while it.hasNext && !found do found = it.next.poll(k)
while it.hasNext && !found do
found = it.next.poll(
new Listener.ForwardingListener[U](this, k):
val lock = withLock(k) { inner => new ListenerLockWrapper(inner, selfSrc) }
def complete(data: U, source: Async.Source[U]) =
k.complete(map(data, source), selfSrc)
)
found

def onComplete(k: Listener[T]): Unit =
val listener = new Listener.ForwardingListener[T](this, k)
val listener = new Listener.ForwardingListener[U](this, k)
with NumberedLock
with Listener.ListenerLock
with Listener.PartialLock { self =>
val lock = self

var found = false
inline def heldLock = if k.lock == null then Listener.Locked else this
def heldLock = if k.lock == null then Listener.Locked else this

/* == PartialLock implementation == */
// Note that this is bogus if k.lock is null, but we'll never use it if it is.
Expand All @@ -219,21 +229,47 @@ object Async:
self.releaseLock()
if until == heldLock then null else k.lock

def complete(item: T, src: Async.Source[T]) =
def complete(item: U, src: Async.Source[U]) =
found = true
self.releaseLock()
sources.foreach(s => if s != src then s.dropListener(self))
k.complete(item, selfSrc)
k.complete(map(item, src), selfSrc)
} // end listener

sources.foreach(_.onComplete(listener))

def dropListener(k: Listener[T]): Unit =
val listener = Listener.ForwardingListener.empty(this, k)
val listener = Listener.ForwardingListener.empty[U](this, k)
sources.foreach(_.dropListener(listener))

}
end race
end raceImpl

/** Cases for handling async sources in a [[select]]. [[SelectCase]] can be constructed by extension methods `handle`
* and `handleVal` of [[Source]].
*/
opaque type SelectCase[T] = (Source[?], Nothing => T)
// ^ unsafe types, but we only construct SelectCase from `handle` and `handleVal` which are safe

extension [T](src: Source[T])
/** Attach a handler to [[src]], creating a [[SelectCase]]. */
inline def handle[U](f: T => U): SelectCase[U] = (src, f)

inline def ~~>[U](f: T => U): SelectCase[U] = src.handle(f)

// /** Attach a handler to [[src]] that takes a [[T]] and throws if [[Failure]] was returned from the source, creating
// * a [[SelectCase]].
// */
// inline def handleVal[U](f: T => U): SelectCase[U] = (src, t => f(t.get))

/** Race a list of sources with the corresponding handler functions, once an item has come back. Like [[race]],
* [[select]] guarantees exactly one of the sources are polled. Unlike `map`ping a [[Source]], the handler in
* [[select]] is run in the same async context as the calling context of [[select]].
*/
def select[T](cases: SelectCase[T]*)(using Async) =
val (input, which) = raceWithOrigin(cases.map(_._1)*).await
val (_, handler) = cases.find(_._1 == which).get
handler.asInstanceOf[input.type => T](input)

/** If left (respectively, right) source succeeds with `x`, pass `Left(x)`, (respectively, Right(x)) on to the
* continuation.
Expand Down
4 changes: 2 additions & 2 deletions shared/src/main/scala/async/Listener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ object Listener:
* [[Async.Source.dropListener]] these listeners are compared for equality by the hash of the source and the inner
* listener.
*/
abstract case class ForwardingListener[T](src: Async.Source[?], inner: Listener[T]) extends Listener[T]
abstract case class ForwardingListener[T](src: Async.Source[?], inner: Listener[?]) extends Listener[T]

object ForwardingListener:
/** Create an empty [[ForwardingListener]] for equality comparison. */
def empty[T](src: Async.Source[?], inner: Listener[T]) = new ForwardingListener(src, inner):
def empty[T](src: Async.Source[?], inner: Listener[?]) = new ForwardingListener[T](src, inner):
val lock = null
override def complete(data: T, source: Async.Source[T]) = ???

Expand Down
Loading

0 comments on commit 8f41bc3

Please sign in to comment.