diff --git a/jvm/src/main/scala/PosixLikeIO/PIO.scala b/jvm/src/main/scala/PosixLikeIO/PIO.scala index 83bb5fec..edca40d5 100644 --- a/jvm/src/main/scala/PosixLikeIO/PIO.scala +++ b/jvm/src/main/scala/PosixLikeIO/PIO.scala @@ -12,6 +12,8 @@ import java.nio.file.{Path, StandardOpenOption} import scala.Tuple.Union import scala.concurrent.ExecutionContext import scala.util.{Failure, Success, Try} +import gears.async.Scheduler +import java.util.concurrent.CancellationException object File: extension (resolver: Future.Resolver[Int]) @@ -88,6 +90,7 @@ class File(val path: String) { } class SocketUDP() { + import SocketUDP._ private var socket: Option[DatagramSocket] = None def isOpened: Boolean = socket.isDefined && !socket.get.isClosed @@ -110,8 +113,8 @@ class SocketUDP() { def send(data: ByteBuffer, address: String, port: Int): Future[Unit] = assert(socket.isDefined) - Async.blocking: - Future: + Future.withResolver: resolver => + resolver.spawn: val packet: DatagramPacket = new DatagramPacket(data.array(), data.limit(), InetAddress.getByName(address), port) socket.get.send(packet) @@ -119,8 +122,8 @@ class SocketUDP() { def receive(): Future[DatagramPacket] = assert(socket.isDefined) - Async.blocking: - Future[DatagramPacket]: + Future.withResolver: resolver => + resolver.spawn: val buffer = Array.fill[Byte](10 * 1024)(0) val packet: DatagramPacket = DatagramPacket(buffer, 10 * 1024) socket.get.receive(packet) @@ -133,6 +136,15 @@ class SocketUDP() { } } +object SocketUDP: + extension [T](resolver: Future.Resolver[T]) + private[SocketUDP] inline def spawn(body: => T)(using s: Scheduler) = + s.execute(() => + resolver.complete(Try(body).recover { case _: InterruptedException => + throw CancellationException() + }) + ) + object PIOHelper { def withFile[T](path: String, options: StandardOpenOption*)(f: File => T): T = val file = File(path).open(options*) diff --git a/shared/src/main/scala/async/Async.scala b/shared/src/main/scala/async/Async.scala index 2dc2d90d..5ae129da 100644 --- a/shared/src/main/scala/async/Async.scala +++ b/shared/src/main/scala/async/Async.scala @@ -20,7 +20,6 @@ trait Async(using val support: AsyncSupport, val scheduler: support.Scheduler): def withGroup(group: CompletionGroup): Async object Async: - private class Blocking(val group: CompletionGroup)(using support: AsyncSupport, scheduler: support.Scheduler) extends Async(using support, scheduler): private val lock = ReentrantLock() @@ -51,19 +50,31 @@ object Async: /** Execute asynchronous computation `body` on currently running thread. The thread will suspend when the computation * waits. */ - def blocking[T](body: Async ?=> T)(using support: AsyncSupport, scheduler: support.Scheduler): T = + def blocking[T](body: Async.Spawn ?=> T)(using support: AsyncSupport, scheduler: support.Scheduler): T = group(body)(using Blocking(CompletionGroup.Unlinked)) /** The currently executing Async context */ inline def current(using async: Async): Async = async - def group[T](body: Async ?=> T)(using async: Async): T = + /** [[Async.Spawn]] is a special subtype of [[Async]], also capable of spawning runnable [[Future]]s. + * + * Most functions should not take [[Spawn]] as a parameter, unless the function explicitly wants to spawn "dangling" + * runnable [[Future]]s. Instead, functions should take [[Async]] and spawn scoped futures within [[Async.group]]. + */ + opaque type Spawn <: Async = Async + + /** Runs [[body]] inside a spawnable context where it is allowed to spawning concurrently runnable [[Future]]s. When + * the body returns, all spawned futures are cancelled and waited for. + */ + def group[T](body: Async.Spawn ?=> T)(using Async): T = withNewCompletionGroup(CompletionGroup().link())(body) /** Runs a body within another completion group. When the body returns, the group is cancelled and its completion * awaited with the `Unlinked` group. */ - private[async] def withNewCompletionGroup[T](group: CompletionGroup)(body: Async ?=> T)(using async: Async): T = + private[async] def withNewCompletionGroup[T](group: CompletionGroup)(body: Async.Spawn ?=> T)(using + async: Async + ): T = val completionAsync = if CompletionGroup.Unlinked == async.group then async diff --git a/shared/src/main/scala/async/futures.scala b/shared/src/main/scala/async/futures.scala index e2b98f74..0babec88 100644 --- a/shared/src/main/scala/async/futures.scala +++ b/shared/src/main/scala/async/futures.scala @@ -87,7 +87,7 @@ object Future: /** A future that is completed by evaluating `body` as a separate asynchronous operation in the given `scheduler` */ - private class RunnableFuture[+T](body: Async ?=> T)(using ac: Async) extends CoreFuture[T]: + private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T]: private var innerGroup: CompletionGroup = CompletionGroup() @@ -150,11 +150,10 @@ object Future: end RunnableFuture - /** Create a future that asynchronously executes `body` that defines its result value in a Try or returns failure if - * an exception was thrown. If the future is created in an Async context, it is added to the children of that - * context's root. + /** Create a future that asynchronously executes [[body]] that defines its result value in a [[Try]] or returns + * [[Failure]] if an exception was thrown. */ - def apply[T](body: Async ?=> T)(using Async): Future[T] = + def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn & async.type): Future[T] = RunnableFuture(body) /** A future that immediately terminates with the given result */ @@ -362,8 +361,12 @@ enum TaskSchedule: */ class Task[+T](val body: (Async, AsyncOperations) ?=> T): + /** Run the current task and returns the result. */ + def run()(using Async, AsyncOperations): T = body + /** Start a future computed from the `body` of this task */ - def run(using Async, AsyncOperations) = Future(body) + def start()(using async: Async, spawn: Async.Spawn & async.type, asyncOps: AsyncOperations) = + Future(body)(using async, spawn) def schedule(s: TaskSchedule): Task[T] = s match { diff --git a/shared/src/test/scala/CancellationBehavior.scala b/shared/src/test/scala/CancellationBehavior.scala index 4df7e435..d9f6d233 100644 --- a/shared/src/test/scala/CancellationBehavior.scala +++ b/shared/src/test/scala/CancellationBehavior.scala @@ -37,7 +37,7 @@ class CancellationBehavior extends munit.FunSuite: state = State.Running(f) case _ => fail(s"initializing failed, state is $state") - private def startFuture(info: Info, body: Async ?=> Unit = {})(using Async) = + private def startFuture(info: Info, body: Async ?=> Unit = {})(using a: Async, s: Async.Spawn & a.type) = val f = Future: info.run() try diff --git a/shared/src/test/scala/TaskScheduleBehavior.scala b/shared/src/test/scala/TaskScheduleBehavior.scala index e7c2b71d..5247af49 100644 --- a/shared/src/test/scala/TaskScheduleBehavior.scala +++ b/shared/src/test/scala/TaskScheduleBehavior.scala @@ -15,7 +15,7 @@ class TaskScheduleBehavior extends munit.FunSuite { var i = 0 val f = Task { i += 1 - }.schedule(TaskSchedule.Every(100, 3)).run + }.schedule(TaskSchedule.Every(100, 3)).start() f.awaitResult assertEquals(i, 3) val end = System.currentTimeMillis() @@ -29,7 +29,7 @@ class TaskScheduleBehavior extends munit.FunSuite { var i = 0 val f = Task { i += 1 - }.schedule(TaskSchedule.ExponentialBackoff(50, 2, 5)).run + }.schedule(TaskSchedule.ExponentialBackoff(50, 2, 5)).start() f.awaitResult assertEquals(i, 5) val end = System.currentTimeMillis() @@ -43,7 +43,7 @@ class TaskScheduleBehavior extends munit.FunSuite { var i = 0 val f = Task { i += 1 - }.schedule(TaskSchedule.FibonacciBackoff(10, 6)).run + }.schedule(TaskSchedule.FibonacciBackoff(10, 6)).start() f.awaitResult assertEquals(i, 6) val end = System.currentTimeMillis() @@ -61,7 +61,7 @@ class TaskScheduleBehavior extends munit.FunSuite { Failure(AssertionError()) } else Success(i) } - val ret = t.schedule(TaskSchedule.RepeatUntilSuccess(150)).run.awaitResult + val ret = t.schedule(TaskSchedule.RepeatUntilSuccess(150)).start().awaitResult assertEquals(ret.get.get, 4) val end = System.currentTimeMillis() assert(end - start >= 4 * 150) @@ -79,7 +79,7 @@ class TaskScheduleBehavior extends munit.FunSuite { Success(i) } else Failure(ex) } - val ret = t.schedule(TaskSchedule.RepeatUntilFailure(150)).run.awaitResult + val ret = t.schedule(TaskSchedule.RepeatUntilFailure(150)).start().awaitResult assertEquals(ret.get, Failure(ex)) val end = System.currentTimeMillis() assert(end - start >= 4 * 150) diff --git a/shared/src/test/scala/TimerBehavior.scala b/shared/src/test/scala/TimerBehavior.scala index d124833e..5a35ba35 100644 --- a/shared/src/test/scala/TimerBehavior.scala +++ b/shared/src/test/scala/TimerBehavior.scala @@ -4,6 +4,7 @@ import scala.concurrent.duration._ import scala.util.{Success, Failure} import java.util.concurrent.TimeoutException import java.util.concurrent.CancellationException +import scala.util.Try class TimerBehavior extends munit.FunSuite { import gears.async.default.given @@ -23,17 +24,17 @@ class TimerBehavior extends munit.FunSuite { assert(timer.src.awaitResult == timer.TimerEvent.Tick) } - def `cancel future after timeout`[T](d: Duration, f: Future[T])(using Async, AsyncOperations): Future[T] = - val t = Future { sleep(d.toMillis) } - Future: - val g = Async.either(t, f).awaitResult - g match - case Left(_) => - f.cancel() - throw TimeoutException() - case Right(v) => - t.cancel() - v.get + def `cancel future after timeout`[T](d: Duration, f: Future[T])(using Async, AsyncOperations): Try[T] = + Async.group: + f.link() + val t = Future { sleep(d.toMillis) } + Try: + Async.select( + t handle: _ => + throw TimeoutException(), + f handle: v => + v.get + ) test("racing with a sleeping future") { var touched = false @@ -44,7 +45,7 @@ class TimerBehavior extends munit.FunSuite { sleep(1000) touched = true ) - assert(t.awaitResult.isFailure) + assert(t.isFailure) assert(!touched) sleep(2000) assert(!touched)