From b625d0ee9c079a498b69dbb936e8ae1960fa3126 Mon Sep 17 00:00:00 2001 From: Adam Fraser Date: Tue, 21 Nov 2023 17:18:35 -0800 Subject: [PATCH] optimize zquery#foreachbatched --- .../src/main/scala/zio/query/ZQuery.scala | 16 +++----- .../scala/zio/query/internal/Continue.scala | 36 ++++++++++++++++++ .../scala/zio/query/internal/Result.scala | 38 +++++++++++++++++++ 3 files changed, 80 insertions(+), 10 deletions(-) diff --git a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala index f8a8f013..b2e6425f 100644 --- a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala +++ b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala @@ -1091,16 +1091,12 @@ object ZQuery { f: A => ZQuery[R, E, B] )(implicit bf: BuildFrom[Collection[A], B, Collection[B]], trace: Trace): ZQuery[R, E, Collection[B]] = if (as.isEmpty) ZQuery.succeed(bf.newBuilder(as).result()) - else { - val iterator = as.iterator - var builder: ZQuery[R, E, Builder[B, Collection[B]]] = null - while (iterator.hasNext) { - val a = iterator.next() - if (builder eq null) builder = f(a).map(bf.newBuilder(as) += _) - else builder = builder.zipWithBatched(f(a))(_ += _) - } - builder.map(_.result()) - } + else + ZQuery( + ZIO + .foreach[R, Nothing, A, Result[R, E, B], Iterable](as)(f(_).step) + .map(Result.collectAllBatched(_).map(bf.fromSpecific(as))) + ) final def foreachBatched[R, E, A, B](as: Set[A])(fn: A => ZQuery[R, E, B])(implicit trace: Trace diff --git a/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala b/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala index 9dcbe44a..53fa09f4 100644 --- a/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala +++ b/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala @@ -169,6 +169,42 @@ private[query] object Continue { ): Continue[R, E, B] = Continue.get(promise.await) + /** + * Collects a collection of continuation into a continuation returning a + * collection of their results, batching requests to data sources. + */ + def collectAllBatched[R, E, A, Collection[+Element] <: Iterable[Element]]( + continues: Collection[Continue[R, E, A]] + )(implicit + bf: BuildFrom[Collection[Continue[R, E, A]], A, Collection[A]], + trace: Trace + ): Continue[R, E, Collection[A]] = + continues.zipWithIndex + .foldLeft[(Chunk[(ZQuery[R, E, A], Int)], Chunk[(IO[E, A], Int)])]((Chunk.empty, Chunk.empty)) { + case ((queries, ios), (continue, index)) => + continue match { + case Effect(query) => (queries :+ ((query, index)), ios) + case Get(io) => (queries, ios :+ ((io, index))) + } + } match { + case (Chunk(), ios) => + get(ZIO.collectAll(ios.map(_._1)).map(bf.fromSpecific(continues))) + case (queries, ios) => + val query = ZQuery.collectAllBatched(queries.map(_._1)).flatMap { as => + val array = Array.ofDim[AnyRef](continues.size) + as.zip(queries.map(_._2)).foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + ZQuery.fromZIO(ZIO.collectAll(ios.map(_._1))).map { as => + as.zip(ios.map(_._2)).foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + bf.fromSpecific(continues)(array.asInstanceOf[Array[A]]) + } + } + effect(query) + } + /** * Collects a collection of continuation into a continuation returning a * collection of their results, in parallel. diff --git a/zio-query/shared/src/main/scala/zio/query/internal/Result.scala b/zio-query/shared/src/main/scala/zio/query/internal/Result.scala index 39a77e79..be140491 100644 --- a/zio-query/shared/src/main/scala/zio/query/internal/Result.scala +++ b/zio-query/shared/src/main/scala/zio/query/internal/Result.scala @@ -111,6 +111,44 @@ private[query] object Result { def blocked[R, E, A](blockedRequests: BlockedRequests[R], continue: Continue[R, E, A]): Result[R, E, A] = Blocked(blockedRequests, continue) + /** + * Collects a collection of results into a single result. Blocked requests + * will be batched. + */ + def collectAllBatched[R, E, A, Collection[+Element] <: Iterable[Element]](results: Collection[Result[R, E, A]])( + implicit + bf: BuildFrom[Collection[Result[R, E, A]], A, Collection[A]], + trace: Trace + ): Result[R, E, Collection[A]] = + results.zipWithIndex + .foldLeft[(Chunk[((BlockedRequests[R], Continue[R, E, A]), Int)], Chunk[(A, Int)], Chunk[(Cause[E], Int)])]( + (Chunk.empty, Chunk.empty, Chunk.empty) + ) { case ((blocked, done, fails), (result, index)) => + result match { + case Blocked(br, c) => (blocked :+ (((br, c), index)), done, fails) + case Done(a) => (blocked, done :+ ((a, index)), fails) + case Fail(e) => (blocked, done, fails :+ ((e, index))) + } + } match { + case (Chunk(), done, Chunk()) => + Result.done(bf.fromSpecific(results)(done.map(_._1))) + case (blocked, done, Chunk()) => + val blockedRequests = blocked.map(_._1._1).foldLeft[BlockedRequests[R]](BlockedRequests.empty)(_ && _) + val continue = Continue.collectAllBatched(blocked.map(_._1._2)).map { as => + val array = Array.ofDim[AnyRef](results.size) + as.zip(blocked.map(_._2)).foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + done.foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + bf.fromSpecific(results)(array.asInstanceOf[Array[A]]) + } + Result.blocked(blockedRequests, continue) + case (_, _, fail) => + Result.fail(fail.map(_._1).foldLeft[Cause[E]](Cause.empty)(_ && _)) + } + /** * Collects a collection of results into a single result. Blocked requests and * their continuations will be executed in parallel.