Skip to content

Commit

Permalink
Flow parMap & parMapUnordered (#2453)
Browse files Browse the repository at this point in the history
  • Loading branch information
nomisRev authored Jul 26, 2021
1 parent 71d48e1 commit 6eef516
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ import io.kotest.property.arbitrary.char
import io.kotest.property.arbitrary.choice
import io.kotest.property.arbitrary.constant
import io.kotest.property.arbitrary.int
import io.kotest.property.arbitrary.list
import io.kotest.property.arbitrary.long
import io.kotest.property.arbitrary.map
import io.kotest.property.arbitrary.string
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
Expand All @@ -38,6 +41,9 @@ public data class SideEffect(var counter: Int = 0) {
}
}

public fun <A> Arb.Companion.flow(arbA: Arb<A>): Arb<Flow<A>> =
Arb.list(arbA).map { it.asFlow() }

public fun Arb.Companion.throwable(): Arb<Throwable> =
Arb.string().map(::RuntimeException)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,25 @@

package arrow.fx.coroutines

import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.DEFAULT_CONCURRENCY
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.flattenMerge
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.retryWhen
import kotlin.coroutines.CoroutineContext
import kotlin.jvm.JvmMultifileClass
import kotlin.jvm.JvmName

Expand Down Expand Up @@ -57,3 +71,91 @@ public fun <A, B> Flow<A>.retry(schedule: Schedule<Throwable, B>): Flow<A> = flo
emit(it)
}
}

/**
* Like [Flow.map], but will evaluate effects in parallel, emitting the results
* downstream in the same order as the input stream. The number of concurrent effects
* is limited by [concurrency].
*
* See [parMapUnordered] if there is no requirement to retain the order of the original stream.
*
* ```kotlin:ank:playground
* import kotlinx.coroutines.delay
* import kotlinx.coroutines.flow.flowOf
* import kotlinx.coroutines.flow.toList
* import kotlinx.coroutines.flow.collect
* import arrow.fx.coroutines.parMap
*
* //sampleStart
* suspend fun main(): Unit {
* flowOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
* .parMap { a ->
* delay(100)
* a
* }.toList() // [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
* }
* //sampleEnd
* ```
*/
@FlowPreview
@ExperimentalCoroutinesApi
public inline fun <A, B> Flow<A>.parMap(
context: CoroutineContext = Dispatchers.Default,
concurrency: Int = DEFAULT_CONCURRENCY,
crossinline transform: suspend CoroutineScope.(a: A) -> B
): Flow<B> =
channelFlow<Deferred<B>> {
map { a ->
val deferred = CompletableDeferred<B>()
send(deferred)
flow<Unit> {
try {
val b = transform(a)
deferred.complete(b)
} catch (e: Throwable) {
require(deferred.completeExceptionally(e))
throw e
}
}.flowOn(context)
}
.flattenMerge(concurrency)
.launchIn(this)
}
.buffer(concurrency)
.map(Deferred<B>::await)

/**
* Like [map], but will evaluate effects in parallel, emitting the results downstream.
* The number of concurrent effects is limited by [concurrency].
*
* See [parMap] if retaining the original order of the stream is required.
*
* ```kotlin:ank:playground
* import kotlinx.coroutines.delay
* import kotlinx.coroutines.flow.flowOf
* import kotlinx.coroutines.flow.toList
* import kotlinx.coroutines.flow.collect
* import arrow.fx.coroutines.parMapUnordered
*
* //sampleStart
* suspend fun main(): Unit {
* flowOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
* .parMapUnordered { a ->
* delay(100)
* a
* }.toList() // [3, 5, 4, 6, 2, 8, 7, 1, 9, 10]
* }
* //sampleEnd
* ```
*/
@FlowPreview
public inline fun <A, B> Flow<A>.parMapUnordered(
ctx: CoroutineContext = Dispatchers.Default,
concurrency: Int = DEFAULT_CONCURRENCY,
crossinline transform: suspend (a: A) -> B
): Flow<B> =
map { o ->
flow {
emit(transform(o))
}.flowOn(ctx)
}.flattenMerge(concurrency)
Original file line number Diff line number Diff line change
@@ -1,30 +1,39 @@
package arrow.fx.coroutines

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.assertions.fail
import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeTypeOf
import io.kotest.property.Arb
import io.kotest.property.arbitrary.int
import io.kotest.property.arbitrary.positiveInts
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.reduce
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.flow.toSet
import kotlinx.coroutines.launch
import kotlin.time.ExperimentalTime

@ExperimentalTime
class FlowTest : ArrowFxSpec(
spec = {

"Retry - flow fails" {
val bang = RuntimeException("Bang!")

checkAll(Arb.int(), Arb.positiveInts(10)) { a, n ->
var counter = 0
val e = shouldThrow<RuntimeException> {
val e = assertThrowable {
flow {
emit(a)
if (++counter <= 11) throw RuntimeException("Bang!")
if (++counter <= 11) throw bang
}.retry(Schedule.recurs(n))
.collect()
}
e.message shouldBe "Bang!"
e shouldBe bang
}
}

Expand All @@ -40,5 +49,210 @@ class FlowTest : ArrowFxSpec(
sum shouldBe a * 6
}
}

"parMap - concurrency = 1 - identity" {
checkAll(Arb.flow(Arb.int())) { flow ->
flow.parMap(concurrency = 1) { it }
.toList() shouldBe flow.toList()
}
}

"parMap - runs in parallel" {
checkAll(Arb.int(), Arb.int(1..2)) { i, n ->
val latch = CompletableDeferred<Int>()
flowOf(1, 2).parMap { index ->
if (index == n) latch.await()
else {
latch.complete(i)
null
}
}.toList().filterNotNull() shouldBe listOf(i)
}
}

"parMap - triggers cancel signal" {
checkAll(Arb.int(), Arb.int(1..2)) { i, n ->
val latch = CompletableDeferred<Unit>()
val exit = CompletableDeferred<Pair<Int, ExitCase>>()

assertThrowable {
flowOf(1, 2).parMap { index ->
if (index == n) {
guaranteeCase({
latch.complete(Unit)
never<Unit>()
}, { ex -> exit.complete(Pair(i, ex)) })
} else {
latch.await()
throw CancellationException(null, null)
}
}.collect()
fail("Cannot reach here. CancellationException should be thrown.")
}.shouldBeTypeOf<CancellationException>()

val (ii, ex) = exit.await()
ii shouldBe i
ex.shouldBeTypeOf<ExitCase.Cancelled>()
}
}

"parMap - exception in parMap cancels all running tasks" {
checkAll(Arb.int(), Arb.throwable(), Arb.int(1..2)) { i, e, n ->
val latch = CompletableDeferred<Unit>()
val exit = CompletableDeferred<Pair<Int, ExitCase>>()

assertThrowable {
flowOf(1, 2).parMap { index ->
if (index == n) {
guaranteeCase({
latch.complete(Unit)
never<Unit>()
}, { ex -> exit.complete(Pair(i, ex)) })
} else {
latch.await()
throw e
}
}.collect()
fail("Cannot reach here. $e should be thrown.")
} shouldBe e

val (ii, ex) = exit.await()
ii shouldBe i
ex.shouldBeTypeOf<ExitCase.Cancelled>()
}
}

"parMap - Cancelling parMap cancels all running jobs" {
checkAll(Arb.int(), Arb.int()) { i, i2 ->
val latch = CompletableDeferred<Unit>()
val exitA = CompletableDeferred<Pair<Int, ExitCase>>()
val exitB = CompletableDeferred<Pair<Int, ExitCase>>()

val job = launch {
flowOf(1, 2).parMap { index ->
guaranteeCase({
if (index == 2) latch.complete(Unit)
never<Unit>()
}, { ex ->
if (index == 1) exitA.complete(Pair(i, ex))
else exitB.complete(Pair(i2, ex))
})
}.collect()
}

latch.await()
job.cancel()

val (ii, ex) = exitA.await()
ii shouldBe i
ex.shouldBeTypeOf<ExitCase.Cancelled>()

val (ii2, ex2) = exitB.await()
ii2 shouldBe i2
ex2.shouldBeTypeOf<ExitCase.Cancelled>()
}
}
"parMapUnordered - concurrency = 1 - identity" {
checkAll(Arb.flow(Arb.int())) { flow ->
flow.parMapUnordered(concurrency = 1) { it }
.toSet() shouldBe flow.toSet()
}
}

"parMapUnordered - runs in parallel" {
checkAll(Arb.int(), Arb.int(1..2)) { i, n ->
val latch = CompletableDeferred<Int>()
flowOf(1, 2).parMapUnordered { index ->
if (index == n) latch.await()
else {
latch.complete(i)
null
}
}.toSet().filterNotNull() shouldBe setOf(i)
}
}

"parMapUnordered - triggers cancel signal" {
checkAll(Arb.int(), Arb.int(1..2)) { i, n ->
val latch = CompletableDeferred<Unit>()
val exit = CompletableDeferred<Pair<Int, ExitCase>>()

assertThrowable {
flowOf(1, 2).parMapUnordered { index ->
if (index == n) {
guaranteeCase({
latch.complete(Unit)
never<Unit>()
}, { ex -> exit.complete(Pair(i, ex)) })
} else {
latch.await()
throw CancellationException(null, null)
}
}.collect()
fail("Cannot reach here. CancellationException should be thrown.")
}.shouldBeTypeOf<CancellationException>()

val (ii, ex) = exit.await()
ii shouldBe i
ex.shouldBeTypeOf<ExitCase.Cancelled>()
}
}

"parMapUnordered - exception in parMap cancels all running tasks" {
checkAll(Arb.int(), Arb.throwable(), Arb.int(1..2)) { i, e, n ->
val latch = CompletableDeferred<Unit>()
val exit = CompletableDeferred<Pair<Int, ExitCase>>()

assertThrowable {
flowOf(1, 2).parMapUnordered { index ->
if (index == n) {
guaranteeCase({
latch.complete(Unit)
never<Unit>()
}, { ex -> exit.complete(Pair(i, ex)) })
} else {
latch.await()
throw e
}
}.collect()
fail("Cannot reach here. $e should be thrown.")
} shouldBe e

val (ii, ex) = exit.await()
ii shouldBe i
ex.shouldBeTypeOf<ExitCase.Cancelled>()
}
}

"parMapUnordered - Cancelling parMap cancels all running jobs" {
checkAll(Arb.int(), Arb.int()) { i, i2 ->
val latch = CompletableDeferred<Unit>()
val exitA = CompletableDeferred<Pair<Int, ExitCase>>()
val exitB = CompletableDeferred<Pair<Int, ExitCase>>()

val job = launch {
flowOf(1, 2).parMapUnordered { index ->
guaranteeCase({
if (index == 2) latch.complete(Unit)
never<Unit>()
}, { ex ->
if (index == 1) exitA.complete(Pair(i, ex))
else exitB.complete(Pair(i2, ex))
})
}.collect()
}

latch.await()
job.cancel()

val (ii, ex) = exitA.await()
ii shouldBe i
ex.shouldBeTypeOf<ExitCase.Cancelled>()

val (ii2, ex2) = exitB.await()
ii2 shouldBe i2
ex2.shouldBeTypeOf<ExitCase.Cancelled>()
}
}
}
)
Loading

0 comments on commit 6eef516

Please sign in to comment.