Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize init and restoration of FiberRefs when running queries #507

Merged
merged 1 commit into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions benchmarks/src/main/scala/zio/query/ZQueryBenchmark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@ class ZQueryBenchmark {
def zQueryRunSucceedNowBenchmark() =
unsafeRunZIO(ZIO.collectAllDiscard(qs1))

@Benchmark
def zQuerySingleRunSucceedNowBenchmark() =
unsafeRunZIO(qs1.head)

@Benchmark
@OperationsPerInvocation(1000)
def zQueryRunSucceedBenchmark() =
unsafeRunZIO(ZIO.collectAllDiscard(qs2))

@Benchmark
def zQuerySingleRunSucceedBenchmark() =
unsafeRunZIO(qs2.head)
}
46 changes: 34 additions & 12 deletions zio-query/shared/src/main/scala/zio/query/ZQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package zio.query

import zio._
import zio.query.ZQuery.disabledCache
import zio.query.internal._
import zio.stacktracer.TracingImplicits.disableAutoTrace

Expand Down Expand Up @@ -550,25 +551,46 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] = {
import ZQuery.{currentCache, currentScope}

def setRef[V](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], newValue: V): V = {
val oldValue = state.getFiberRefOrNull(fiberRef)
state.setFiberRef(fiberRef, newValue)
oldValue
def resetRef[V <: AnyRef](
fid: FiberId.Runtime,
oldRefs: FiberRefs,
newRefs: FiberRefs
)(
fiberRef: FiberRef[V]
): FiberRefs = {
val oldValue = oldRefs.getOrNull(fiberRef)
if (oldValue ne null) newRefs.updatedAs(fid)(fiberRef, oldValue) else newRefs.delete(fiberRef)
}

def resetRef[V <: AnyRef](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], oldValue: V): Unit =
if (oldValue ne null) state.setFiberRef(fiberRef, oldValue) else state.deleteFiberRef(fiberRef)

asExitOrElse(null) match {
case null =>
ZIO.uninterruptibleMask { restore =>
ZIO.withFiberRuntime[R, E, A] { (state, _) =>
val scope = QueryScope.make()
val oldCache = setRef(state, currentCache, Some(cache))
val oldScope = setRef(state, currentScope, scope)
// NOTE: Running a ZQuery requires up to 3 FiberRefs, which can be expensive to use `locally` with for simple queries.
// Therefore, we handle them all together to avoid the added penalty of running `locally` 3 times
val fid = state.id
val scope = QueryScope.make()
val oldRefs = state.getFiberRefs(false)
val newRefs = {
val refs = oldRefs.updatedAs(fid)(currentCache, Some(cache)).updatedAs(fid)(currentScope, scope)
if (refs.getOrNull(disabledCache) ne null)
refs.delete(disabledCache)
else refs
}
state.setFiberRefs(newRefs)
restore(runToZIO).exitWith { exit =>
resetRef(state, currentCache, oldCache)
resetRef(state, currentScope, oldScope)
val curRefs = state.getFiberRefs(false)
if (curRefs eq newRefs) {
// Cheap and common: FiberRefs were not modified during the execution so we just replace them with the old ones
state.setFiberRefs(oldRefs)
} else {
// FiberRefs were mdified so we need to manually revert each one
var revertedRefs = oldRefs
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentCache)
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentScope)
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(disabledCache)
state.setFiberRefs(revertedRefs)
}
scope.closeAndExitWith(exit)
}
}
Expand Down
10 changes: 10 additions & 0 deletions zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,16 @@ object ZQuerySpec extends ZIOBaseSpec {

q.run.map { case (c1, c2) => assertTrue(c1.isDefined, c1 == c2) }
},
test("disabling caching is reentrant safe") {
val q =
for {
c1 <- ZQuery.fromZIO(ZQuery.currentCache.get)
c2 <- ZQuery.fromZIO(ZQuery.fromZIO(ZQuery.currentCache.get).cached.run).uncached
c3 <- ZQuery.fromZIO(ZQuery.currentCache.get)
} yield (c1, c2, c3)

q.run.map { case (c1, c2, c3) => assertTrue(c1.isDefined, c2.isDefined, c1 == c3, c1 != c2) }
},
test("scope is reentrant safe") {
val q =
for {
Expand Down