Skip to content

Commit

Permalink
Experiment: Only allow Async.Spawnable to spawn runnable futures
Browse files Browse the repository at this point in the history
The goal is to disallow spawning dangling Futures from `using Async` functions.
`Async.Spawnable` is an opaque alias of `Async`, defined as a subtype of `Async`,
obtained by explicitly "upgrading" it through `Async.spawning` (which works similarly
to `Async.group`, so futures are all cancelled after) -- or automatically given
through `Async.blocking` or `Future.apply`.

The `Async.Spawnable`-taking functions (signalling usage of dangling futures)
should follow the hacky signature of `Future.apply`:

> def apply[T](body: Async.Spawnable ?=> T)
>   (using async: Async, spawn: Async.Spawnable & async.type): T

to ensure that the given `Async` instance (which is usually synthesized
to be the innermost context) is the same instance as the `Async.Spawnable`
instance. It happens quite often (especially when nesting `Async` contexts)
that these don't match:

> extension[T] (s: Seq[T])
>   def parallelMap[U](f: T => Async ?=> U)(using Async): Seq[U]
>
> Async.blocking: // Async.Spawnable here...
>   val seq = Seq(1, 2, 3, 4, 5)
>     .parallelMap: n => // Async here...
>        Async.select(
>           Future(doSomethingAsync(n)) handle Some(_),     // oops, spawned by Async.blocking
>           Future(Async.sleep(1.minute)) handle _ => None, // oops, spawned by Async.blocking
>        )
>   // oops, leaking all the futures...

with the `Future.apply` signature as above, this does not happen and will
give a compile time error.
  • Loading branch information
natsukagami committed Feb 27, 2024
1 parent 01b7d49 commit f084e16
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 27 deletions.
21 changes: 18 additions & 3 deletions shared/src/main/scala/async/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ trait Async(using val support: AsyncSupport, val scheduler: support.Scheduler):
def withGroup(group: CompletionGroup): Async

object Async:

private class Blocking(val group: CompletionGroup)(using support: AsyncSupport, scheduler: support.Scheduler)
extends Async(using support, scheduler):
private val lock = ReentrantLock()
Expand Down Expand Up @@ -51,19 +50,35 @@ object Async:
/** Execute asynchronous computation `body` on currently running thread. The thread will suspend when the computation
* waits.
*/
def blocking[T](body: Async ?=> T)(using support: AsyncSupport, scheduler: support.Scheduler): T =
def blocking[T](body: Async.Spawnable ?=> T)(using support: AsyncSupport, scheduler: support.Scheduler): T =
group(body)(using Blocking(CompletionGroup.Unlinked))

/** The currently executing Async context */
inline def current(using async: Async): Async = async

/** [[Async.Spawnable]] is a special subtype of [[Async]], also capable of spawning runnable [[Future]]s.
*
* Most functions should not take [[Spawnable]] as a parameter, unless the function explicitly wants to spawn
* "dangling" runnable [[Future]]s. Instead, functions should take [[Async]] and spawn scoped futures within
* [[Async.spawning]].
*/
opaque type Spawnable <: Async = Async

/** Runs [[body]] inside a spawnable context where it is allowed to spawning concurrently runnable [[Future]]s. When
* the body returns, all spawned futures are cancelled and waited for (similar to [[Async.group]]).
*/
inline def spawning[T](inline body: Async.Spawnable ?=> T)(using Async): T =
Async.group(body)

def group[T](body: Async ?=> T)(using async: Async): T =
withNewCompletionGroup(CompletionGroup().link())(body)

/** Runs a body within another completion group. When the body returns, the group is cancelled and its completion
* awaited with the `Unlinked` group.
*/
private[async] def withNewCompletionGroup[T](group: CompletionGroup)(body: Async ?=> T)(using async: Async): T =
private[async] def withNewCompletionGroup[T](group: CompletionGroup)(body: Async.Spawnable ?=> T)(using
async: Async
): T =
val completionAsync =
if CompletionGroup.Unlinked == async.group
then async
Expand Down
11 changes: 5 additions & 6 deletions shared/src/main/scala/async/futures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object Future:

/** A future that is completed by evaluating `body` as a separate asynchronous operation in the given `scheduler`
*/
private class RunnableFuture[+T](body: Async ?=> T)(using ac: Async) extends CoreFuture[T]:
private class RunnableFuture[+T](body: Async.Spawnable ?=> T)(using ac: Async) extends CoreFuture[T]:

private var innerGroup: CompletionGroup = CompletionGroup()

Expand Down Expand Up @@ -150,11 +150,10 @@ object Future:

end RunnableFuture

/** Create a future that asynchronously executes `body` that defines its result value in a Try or returns failure if
* an exception was thrown. If the future is created in an Async context, it is added to the children of that
* context's root.
/** Create a future that asynchronously executes [[body]] that defines its result value in a [[Try]] or returns
* [[Failure]] if an exception was thrown.
*/
def apply[T](body: Async ?=> T)(using Async): Future[T] =
def apply[T](body: Async.Spawnable ?=> T)(using async: Async, spawnable: Async.Spawnable & async.type): Future[T] =
RunnableFuture(body)

/** A future that immediately terminates with the given result */
Expand Down Expand Up @@ -363,7 +362,7 @@ enum TaskSchedule:
class Task[+T](val body: (Async, AsyncOperations) ?=> T):

/** Start a future computed from the `body` of this task */
def run(using Async, AsyncOperations) = Future(body)
def run(using Async.Spawnable, AsyncOperations) = Future(body)

def schedule(s: TaskSchedule): Task[T] =
s match {
Expand Down
12 changes: 6 additions & 6 deletions shared/src/test/scala/CancellationBehavior.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class CancellationBehavior extends munit.FunSuite:
state = State.Running(f)
case _ => fail(s"initializing failed, state is $state")

private def startFuture(info: Info, body: Async ?=> Unit = {})(using Async) =
private def startFuture(info: Info, body: Async ?=> Unit = {})(using a: Async, s: Async.Spawnable & a.type) =
val f = Future:
info.run()
try
Expand Down Expand Up @@ -65,7 +65,7 @@ class CancellationBehavior extends munit.FunSuite:
test("group cancel"):
var x = 0
Async.blocking:
Async.group:
Async.spawning:
Future:
sleep(400)
x = 1
Expand All @@ -75,7 +75,7 @@ class CancellationBehavior extends munit.FunSuite:
val info = Info()
Async.blocking:
val promise = Future.Promise[Unit]()
Async.group:
Async.spawning:
startFuture(info, promise.complete(Success(())))
promise.await
info.assertCancelled()
Expand All @@ -84,10 +84,10 @@ class CancellationBehavior extends munit.FunSuite:
val (info1, info2) = (Info(), Info())
val (promise1, promise2) = (Future.Promise[Unit](), Future.Promise[Unit]())
Async.blocking:
Async.group:
Async.spawning:
startFuture(
info1, {
Async.group:
Async.spawning:
startFuture(info2, promise2.complete(Success(())))
promise2.await
info2.assertCancelled()
Expand Down Expand Up @@ -120,7 +120,7 @@ class CancellationBehavior extends munit.FunSuite:
val info = Info()
Async.blocking:
val promise = Future.Promise[Unit]()
Async.group:
Async.spawning:
Async.current.group.cancel() // cancel now
val f = startFuture(info, promise.complete(Success(())))
promise.awaitResult
Expand Down
25 changes: 13 additions & 12 deletions shared/src/test/scala/TimerBehavior.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import scala.concurrent.duration._
import scala.util.{Success, Failure}
import java.util.concurrent.TimeoutException
import java.util.concurrent.CancellationException
import scala.util.Try

class TimerBehavior extends munit.FunSuite {
import gears.async.default.given
Expand All @@ -23,17 +24,17 @@ class TimerBehavior extends munit.FunSuite {
assert(timer.src.awaitResult == timer.TimerEvent.Tick)
}

def `cancel future after timeout`[T](d: Duration, f: Future[T])(using Async, AsyncOperations): Future[T] =
val t = Future { sleep(d.toMillis) }
Future:
val g = Async.either(t, f).awaitResult
g match
case Left(_) =>
f.cancel()
throw TimeoutException()
case Right(v) =>
t.cancel()
v.get
def `cancel future after timeout`[T](d: Duration, f: Future[T])(using Async, AsyncOperations): Try[T] =
Async.spawning:
f.link()
val t = Future { sleep(d.toMillis) }
Try:
Async.select(
t handle: _ =>
throw TimeoutException(),
f handle: v =>
v.get
)

test("racing with a sleeping future") {
var touched = false
Expand All @@ -44,7 +45,7 @@ class TimerBehavior extends munit.FunSuite {
sleep(1000)
touched = true
)
assert(t.awaitResult.isFailure)
assert(t.isFailure)
assert(!touched)
sleep(2000)
assert(!touched)
Expand Down

0 comments on commit f084e16

Please sign in to comment.