From 6eef516b8a82b369d05023bd6d9469ebf55e6ce7 Mon Sep 17 00:00:00 2001 From: Simon Vergauwen Date: Mon, 26 Jul 2021 17:43:22 +0200 Subject: [PATCH] Flow parMap & parMapUnordered (#2453) --- .../kotlin/arrow/fx/coroutines/predef-test.kt | 6 + .../kotlin/arrow/fx/coroutines/flow.kt | 102 ++++++++ .../kotlin/arrow/fx/coroutines/FlowTest.kt | 222 +++++++++++++++++- .../kotlin/arrow/fx/coroutines/FlowJvmTest.kt | 20 ++ 4 files changed, 346 insertions(+), 4 deletions(-) diff --git a/arrow-libs/fx/arrow-fx-coroutines-test/src/commonMain/kotlin/arrow/fx/coroutines/predef-test.kt b/arrow-libs/fx/arrow-fx-coroutines-test/src/commonMain/kotlin/arrow/fx/coroutines/predef-test.kt index 54eac08d0da..26d93567cd2 100644 --- a/arrow-libs/fx/arrow-fx-coroutines-test/src/commonMain/kotlin/arrow/fx/coroutines/predef-test.kt +++ b/arrow-libs/fx/arrow-fx-coroutines-test/src/commonMain/kotlin/arrow/fx/coroutines/predef-test.kt @@ -20,10 +20,13 @@ import io.kotest.property.arbitrary.char import io.kotest.property.arbitrary.choice import io.kotest.property.arbitrary.constant import io.kotest.property.arbitrary.int +import io.kotest.property.arbitrary.list import io.kotest.property.arbitrary.long import io.kotest.property.arbitrary.map import io.kotest.property.arbitrary.string import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow import kotlin.coroutines.Continuation import kotlin.coroutines.CoroutineContext import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED @@ -38,6 +41,9 @@ public data class SideEffect(var counter: Int = 0) { } } +public fun Arb.Companion.flow(arbA: Arb): Arb> = + Arb.list(arbA).map { it.asFlow() } + public fun Arb.Companion.throwable(): Arb = Arb.string().map(::RuntimeException) diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/flow.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/flow.kt index 4f899ece410..952de1c6241 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/flow.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/flow.kt @@ -3,11 +3,25 @@ package arrow.fx.coroutines +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.FlowPreview import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.DEFAULT_CONCURRENCY import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.buffer +import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.flattenMerge import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.retryWhen +import kotlin.coroutines.CoroutineContext import kotlin.jvm.JvmMultifileClass import kotlin.jvm.JvmName @@ -57,3 +71,91 @@ public fun Flow.retry(schedule: Schedule): Flow = flo emit(it) } } + +/** + * Like [Flow.map], but will evaluate effects in parallel, emitting the results + * downstream in the same order as the input stream. The number of concurrent effects + * is limited by [concurrency]. + * + * See [parMapUnordered] if there is no requirement to retain the order of the original stream. + * + * ```kotlin:ank:playground + * import kotlinx.coroutines.delay + * import kotlinx.coroutines.flow.flowOf + * import kotlinx.coroutines.flow.toList + * import kotlinx.coroutines.flow.collect + * import arrow.fx.coroutines.parMap + * + * //sampleStart + * suspend fun main(): Unit { + * flowOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + * .parMap { a -> + * delay(100) + * a + * }.toList() // [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + * } + * //sampleEnd + * ``` + */ +@FlowPreview +@ExperimentalCoroutinesApi +public inline fun Flow.parMap( + context: CoroutineContext = Dispatchers.Default, + concurrency: Int = DEFAULT_CONCURRENCY, + crossinline transform: suspend CoroutineScope.(a: A) -> B +): Flow = + channelFlow> { + map { a -> + val deferred = CompletableDeferred() + send(deferred) + flow { + try { + val b = transform(a) + deferred.complete(b) + } catch (e: Throwable) { + require(deferred.completeExceptionally(e)) + throw e + } + }.flowOn(context) + } + .flattenMerge(concurrency) + .launchIn(this) + } + .buffer(concurrency) + .map(Deferred::await) + +/** + * Like [map], but will evaluate effects in parallel, emitting the results downstream. + * The number of concurrent effects is limited by [concurrency]. + * + * See [parMap] if retaining the original order of the stream is required. + * + * ```kotlin:ank:playground + * import kotlinx.coroutines.delay + * import kotlinx.coroutines.flow.flowOf + * import kotlinx.coroutines.flow.toList + * import kotlinx.coroutines.flow.collect + * import arrow.fx.coroutines.parMapUnordered + * + * //sampleStart + * suspend fun main(): Unit { + * flowOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + * .parMapUnordered { a -> + * delay(100) + * a + * }.toList() // [3, 5, 4, 6, 2, 8, 7, 1, 9, 10] + * } + * //sampleEnd + * ``` + */ +@FlowPreview +public inline fun Flow.parMapUnordered( + ctx: CoroutineContext = Dispatchers.Default, + concurrency: Int = DEFAULT_CONCURRENCY, + crossinline transform: suspend (a: A) -> B +): Flow = + map { o -> + flow { + emit(transform(o)) + }.flowOn(ctx) + }.flattenMerge(concurrency) diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/FlowTest.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/FlowTest.kt index 129d0f25994..e3ecd82e0c3 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/FlowTest.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/FlowTest.kt @@ -1,13 +1,20 @@ package arrow.fx.coroutines -import io.kotest.assertions.throwables.shouldThrow +import io.kotest.assertions.fail import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeTypeOf import io.kotest.property.Arb import io.kotest.property.arbitrary.int import io.kotest.property.arbitrary.positiveInts +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.reduce +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.toSet +import kotlinx.coroutines.launch import kotlin.time.ExperimentalTime @ExperimentalTime @@ -15,16 +22,18 @@ class FlowTest : ArrowFxSpec( spec = { "Retry - flow fails" { + val bang = RuntimeException("Bang!") + checkAll(Arb.int(), Arb.positiveInts(10)) { a, n -> var counter = 0 - val e = shouldThrow { + val e = assertThrowable { flow { emit(a) - if (++counter <= 11) throw RuntimeException("Bang!") + if (++counter <= 11) throw bang }.retry(Schedule.recurs(n)) .collect() } - e.message shouldBe "Bang!" + e shouldBe bang } } @@ -40,5 +49,210 @@ class FlowTest : ArrowFxSpec( sum shouldBe a * 6 } } + + "parMap - concurrency = 1 - identity" { + checkAll(Arb.flow(Arb.int())) { flow -> + flow.parMap(concurrency = 1) { it } + .toList() shouldBe flow.toList() + } + } + + "parMap - runs in parallel" { + checkAll(Arb.int(), Arb.int(1..2)) { i, n -> + val latch = CompletableDeferred() + flowOf(1, 2).parMap { index -> + if (index == n) latch.await() + else { + latch.complete(i) + null + } + }.toList().filterNotNull() shouldBe listOf(i) + } + } + + "parMap - triggers cancel signal" { + checkAll(Arb.int(), Arb.int(1..2)) { i, n -> + val latch = CompletableDeferred() + val exit = CompletableDeferred>() + + assertThrowable { + flowOf(1, 2).parMap { index -> + if (index == n) { + guaranteeCase({ + latch.complete(Unit) + never() + }, { ex -> exit.complete(Pair(i, ex)) }) + } else { + latch.await() + throw CancellationException(null, null) + } + }.collect() + fail("Cannot reach here. CancellationException should be thrown.") + }.shouldBeTypeOf() + + val (ii, ex) = exit.await() + ii shouldBe i + ex.shouldBeTypeOf() + } + } + + "parMap - exception in parMap cancels all running tasks" { + checkAll(Arb.int(), Arb.throwable(), Arb.int(1..2)) { i, e, n -> + val latch = CompletableDeferred() + val exit = CompletableDeferred>() + + assertThrowable { + flowOf(1, 2).parMap { index -> + if (index == n) { + guaranteeCase({ + latch.complete(Unit) + never() + }, { ex -> exit.complete(Pair(i, ex)) }) + } else { + latch.await() + throw e + } + }.collect() + fail("Cannot reach here. $e should be thrown.") + } shouldBe e + + val (ii, ex) = exit.await() + ii shouldBe i + ex.shouldBeTypeOf() + } + } + + "parMap - Cancelling parMap cancels all running jobs" { + checkAll(Arb.int(), Arb.int()) { i, i2 -> + val latch = CompletableDeferred() + val exitA = CompletableDeferred>() + val exitB = CompletableDeferred>() + + val job = launch { + flowOf(1, 2).parMap { index -> + guaranteeCase({ + if (index == 2) latch.complete(Unit) + never() + }, { ex -> + if (index == 1) exitA.complete(Pair(i, ex)) + else exitB.complete(Pair(i2, ex)) + }) + }.collect() + } + + latch.await() + job.cancel() + + val (ii, ex) = exitA.await() + ii shouldBe i + ex.shouldBeTypeOf() + + val (ii2, ex2) = exitB.await() + ii2 shouldBe i2 + ex2.shouldBeTypeOf() + } + } + "parMapUnordered - concurrency = 1 - identity" { + checkAll(Arb.flow(Arb.int())) { flow -> + flow.parMapUnordered(concurrency = 1) { it } + .toSet() shouldBe flow.toSet() + } + } + + "parMapUnordered - runs in parallel" { + checkAll(Arb.int(), Arb.int(1..2)) { i, n -> + val latch = CompletableDeferred() + flowOf(1, 2).parMapUnordered { index -> + if (index == n) latch.await() + else { + latch.complete(i) + null + } + }.toSet().filterNotNull() shouldBe setOf(i) + } + } + + "parMapUnordered - triggers cancel signal" { + checkAll(Arb.int(), Arb.int(1..2)) { i, n -> + val latch = CompletableDeferred() + val exit = CompletableDeferred>() + + assertThrowable { + flowOf(1, 2).parMapUnordered { index -> + if (index == n) { + guaranteeCase({ + latch.complete(Unit) + never() + }, { ex -> exit.complete(Pair(i, ex)) }) + } else { + latch.await() + throw CancellationException(null, null) + } + }.collect() + fail("Cannot reach here. CancellationException should be thrown.") + }.shouldBeTypeOf() + + val (ii, ex) = exit.await() + ii shouldBe i + ex.shouldBeTypeOf() + } + } + + "parMapUnordered - exception in parMap cancels all running tasks" { + checkAll(Arb.int(), Arb.throwable(), Arb.int(1..2)) { i, e, n -> + val latch = CompletableDeferred() + val exit = CompletableDeferred>() + + assertThrowable { + flowOf(1, 2).parMapUnordered { index -> + if (index == n) { + guaranteeCase({ + latch.complete(Unit) + never() + }, { ex -> exit.complete(Pair(i, ex)) }) + } else { + latch.await() + throw e + } + }.collect() + fail("Cannot reach here. $e should be thrown.") + } shouldBe e + + val (ii, ex) = exit.await() + ii shouldBe i + ex.shouldBeTypeOf() + } + } + + "parMapUnordered - Cancelling parMap cancels all running jobs" { + checkAll(Arb.int(), Arb.int()) { i, i2 -> + val latch = CompletableDeferred() + val exitA = CompletableDeferred>() + val exitB = CompletableDeferred>() + + val job = launch { + flowOf(1, 2).parMapUnordered { index -> + guaranteeCase({ + if (index == 2) latch.complete(Unit) + never() + }, { ex -> + if (index == 1) exitA.complete(Pair(i, ex)) + else exitB.complete(Pair(i2, ex)) + }) + }.collect() + } + + latch.await() + job.cancel() + + val (ii, ex) = exitA.await() + ii shouldBe i + ex.shouldBeTypeOf() + + val (ii2, ex2) = exitB.await() + ii2 shouldBe i2 + ex2.shouldBeTypeOf() + } + } } ) diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/jvmTest/kotlin/arrow/fx/coroutines/FlowJvmTest.kt b/arrow-libs/fx/arrow-fx-coroutines/src/jvmTest/kotlin/arrow/fx/coroutines/FlowJvmTest.kt index 7802c2f93b4..1af72548ee5 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/jvmTest/kotlin/arrow/fx/coroutines/FlowJvmTest.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/jvmTest/kotlin/arrow/fx/coroutines/FlowJvmTest.kt @@ -8,6 +8,8 @@ import io.kotest.property.Arb import io.kotest.property.arbitrary.int import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.toSet import kotlinx.coroutines.test.runBlockingTest import kotlin.time.ExperimentalTime import kotlin.time.milliseconds @@ -39,4 +41,22 @@ class FlowJvmTest : ArrowFxSpec(spec = { } } } + + "parMap - single thread - identity" { + single.use { ctx -> + checkAll(Arb.flow(Arb.int())) { flow -> + flow.parMap(ctx) { it } + .toList() shouldBe flow.toList() + } + } + } + + "parMapUnordered - single thread - identity" { + single.use { ctx -> + checkAll(Arb.flow(Arb.int())) { flow -> + flow.parMapUnordered(ctx) { it } + .toSet() shouldBe flow.toSet() + } + } + } })