Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment: Only allow Async.Spawn to spawn runnable futures #46

Merged
merged 6 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions jvm/src/main/scala/PosixLikeIO/PIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import java.nio.file.{Path, StandardOpenOption}
import scala.Tuple.Union
import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success, Try}
import gears.async.Scheduler
import java.util.concurrent.CancellationException

object File:
extension (resolver: Future.Resolver[Int])
Expand Down Expand Up @@ -88,6 +90,7 @@ class File(val path: String) {
}

class SocketUDP() {
import SocketUDP._
private var socket: Option[DatagramSocket] = None

def isOpened: Boolean = socket.isDefined && !socket.get.isClosed
Expand All @@ -110,17 +113,17 @@ class SocketUDP() {
def send(data: ByteBuffer, address: String, port: Int): Future[Unit] =
assert(socket.isDefined)

Async.blocking:
Future:
Future.withResolver: resolver =>
resolver.spawn:
val packet: DatagramPacket =
new DatagramPacket(data.array(), data.limit(), InetAddress.getByName(address), port)
socket.get.send(packet)

def receive(): Future[DatagramPacket] =
assert(socket.isDefined)

Async.blocking:
Future[DatagramPacket]:
Future.withResolver: resolver =>
resolver.spawn:
val buffer = Array.fill[Byte](10 * 1024)(0)
val packet: DatagramPacket = DatagramPacket(buffer, 10 * 1024)
socket.get.receive(packet)
Expand All @@ -133,6 +136,15 @@ class SocketUDP() {
}
}

object SocketUDP:
extension [T](resolver: Future.Resolver[T])
private[SocketUDP] inline def spawn(body: => T)(using s: Scheduler) =
s.execute(() =>
resolver.complete(Try(body).recover { case _: InterruptedException =>
throw CancellationException()
})
)

object PIOHelper {
def withFile[T](path: String, options: StandardOpenOption*)(f: File => T): T =
val file = File(path).open(options*)
Expand Down
19 changes: 15 additions & 4 deletions shared/src/main/scala/async/Async.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ trait Async(using val support: AsyncSupport, val scheduler: support.Scheduler):
def withGroup(group: CompletionGroup): Async

object Async:

private class Blocking(val group: CompletionGroup)(using support: AsyncSupport, scheduler: support.Scheduler)
extends Async(using support, scheduler):
private val lock = ReentrantLock()
Expand Down Expand Up @@ -51,19 +50,31 @@ object Async:
/** Execute asynchronous computation `body` on currently running thread. The thread will suspend when the computation
* waits.
*/
def blocking[T](body: Async ?=> T)(using support: AsyncSupport, scheduler: support.Scheduler): T =
def blocking[T](body: Async.Spawn ?=> T)(using support: AsyncSupport, scheduler: support.Scheduler): T =
group(body)(using Blocking(CompletionGroup.Unlinked))

/** The currently executing Async context */
inline def current(using async: Async): Async = async

def group[T](body: Async ?=> T)(using async: Async): T =
/** [[Async.Spawn]] is a special subtype of [[Async]], also capable of spawning runnable [[Future]]s.
*
* Most functions should not take [[Spawn]] as a parameter, unless the function explicitly wants to spawn "dangling"
* runnable [[Future]]s. Instead, functions should take [[Async]] and spawn scoped futures within [[Async.group]].
*/
opaque type Spawn <: Async = Async

/** Runs [[body]] inside a spawnable context where it is allowed to spawning concurrently runnable [[Future]]s. When
* the body returns, all spawned futures are cancelled and waited for.
*/
def group[T](body: Async.Spawn ?=> T)(using Async): T =
withNewCompletionGroup(CompletionGroup().link())(body)

/** Runs a body within another completion group. When the body returns, the group is cancelled and its completion
* awaited with the `Unlinked` group.
*/
private[async] def withNewCompletionGroup[T](group: CompletionGroup)(body: Async ?=> T)(using async: Async): T =
private[async] def withNewCompletionGroup[T](group: CompletionGroup)(body: Async.Spawn ?=> T)(using
async: Async
): T =
val completionAsync =
if CompletionGroup.Unlinked == async.group
then async
Expand Down
14 changes: 8 additions & 6 deletions shared/src/main/scala/async/futures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ object Future:

/** A future that is completed by evaluating `body` as a separate asynchronous operation in the given `scheduler`
*/
private class RunnableFuture[+T](body: Async ?=> T)(using ac: Async) extends CoreFuture[T]:
private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T]:

private var innerGroup: CompletionGroup = CompletionGroup()

Expand Down Expand Up @@ -150,11 +150,10 @@ object Future:

end RunnableFuture

/** Create a future that asynchronously executes `body` that defines its result value in a Try or returns failure if
* an exception was thrown. If the future is created in an Async context, it is added to the children of that
* context's root.
/** Create a future that asynchronously executes [[body]] that defines its result value in a [[Try]] or returns
* [[Failure]] if an exception was thrown.
*/
def apply[T](body: Async ?=> T)(using Async): Future[T] =
def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn & async.type): Future[T] =
RunnableFuture(body)

/** A future that immediately terminates with the given result */
Expand Down Expand Up @@ -362,8 +361,11 @@ enum TaskSchedule:
*/
class Task[+T](val body: (Async, AsyncOperations) ?=> T):
natsukagami marked this conversation as resolved.
Show resolved Hide resolved

/** Run the current task and returns the result. */
def run()(using Async, AsyncOperations): T = body

/** Start a future computed from the `body` of this task */
def run(using Async, AsyncOperations) = Future(body)
def start()(using Async.Spawn, AsyncOperations) = Future(body)
natsukagami marked this conversation as resolved.
Show resolved Hide resolved

def schedule(s: TaskSchedule): Task[T] =
s match {
Expand Down
2 changes: 1 addition & 1 deletion shared/src/test/scala/CancellationBehavior.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class CancellationBehavior extends munit.FunSuite:
state = State.Running(f)
case _ => fail(s"initializing failed, state is $state")

private def startFuture(info: Info, body: Async ?=> Unit = {})(using Async) =
private def startFuture(info: Info, body: Async ?=> Unit = {})(using a: Async, s: Async.Spawn & a.type) =
val f = Future:
info.run()
try
Expand Down
10 changes: 5 additions & 5 deletions shared/src/test/scala/TaskScheduleBehavior.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class TaskScheduleBehavior extends munit.FunSuite {
var i = 0
val f = Task {
i += 1
}.schedule(TaskSchedule.Every(100, 3)).run
}.schedule(TaskSchedule.Every(100, 3)).start()
f.awaitResult
assertEquals(i, 3)
val end = System.currentTimeMillis()
Expand All @@ -29,7 +29,7 @@ class TaskScheduleBehavior extends munit.FunSuite {
var i = 0
val f = Task {
i += 1
}.schedule(TaskSchedule.ExponentialBackoff(50, 2, 5)).run
}.schedule(TaskSchedule.ExponentialBackoff(50, 2, 5)).start()
f.awaitResult
assertEquals(i, 5)
val end = System.currentTimeMillis()
Expand All @@ -43,7 +43,7 @@ class TaskScheduleBehavior extends munit.FunSuite {
var i = 0
val f = Task {
i += 1
}.schedule(TaskSchedule.FibonacciBackoff(10, 6)).run
}.schedule(TaskSchedule.FibonacciBackoff(10, 6)).start()
f.awaitResult
assertEquals(i, 6)
val end = System.currentTimeMillis()
Expand All @@ -61,7 +61,7 @@ class TaskScheduleBehavior extends munit.FunSuite {
Failure(AssertionError())
} else Success(i)
}
val ret = t.schedule(TaskSchedule.RepeatUntilSuccess(150)).run.awaitResult
val ret = t.schedule(TaskSchedule.RepeatUntilSuccess(150)).start().awaitResult
assertEquals(ret.get.get, 4)
val end = System.currentTimeMillis()
assert(end - start >= 4 * 150)
Expand All @@ -79,7 +79,7 @@ class TaskScheduleBehavior extends munit.FunSuite {
Success(i)
} else Failure(ex)
}
val ret = t.schedule(TaskSchedule.RepeatUntilFailure(150)).run.awaitResult
val ret = t.schedule(TaskSchedule.RepeatUntilFailure(150)).start().awaitResult
assertEquals(ret.get, Failure(ex))
val end = System.currentTimeMillis()
assert(end - start >= 4 * 150)
Expand Down
25 changes: 13 additions & 12 deletions shared/src/test/scala/TimerBehavior.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import scala.concurrent.duration._
import scala.util.{Success, Failure}
import java.util.concurrent.TimeoutException
import java.util.concurrent.CancellationException
import scala.util.Try

class TimerBehavior extends munit.FunSuite {
import gears.async.default.given
Expand All @@ -23,17 +24,17 @@ class TimerBehavior extends munit.FunSuite {
assert(timer.src.awaitResult == timer.TimerEvent.Tick)
}

def `cancel future after timeout`[T](d: Duration, f: Future[T])(using Async, AsyncOperations): Future[T] =
val t = Future { sleep(d.toMillis) }
Future:
val g = Async.either(t, f).awaitResult
g match
case Left(_) =>
f.cancel()
throw TimeoutException()
case Right(v) =>
t.cancel()
v.get
def `cancel future after timeout`[T](d: Duration, f: Future[T])(using Async, AsyncOperations): Try[T] =
Async.group:
f.link()
val t = Future { sleep(d.toMillis) }
Try:
Async.select(
t handle: _ =>
throw TimeoutException(),
f handle: v =>
v.get
)

test("racing with a sleeping future") {
var touched = false
Expand All @@ -44,7 +45,7 @@ class TimerBehavior extends munit.FunSuite {
sleep(1000)
touched = true
)
assert(t.awaitResult.isFailure)
assert(t.isFailure)
assert(!touched)
sleep(2000)
assert(!touched)
Expand Down
Loading