diff --git a/bench/src/main/scala/cats/mtl/bench/StateBench.scala b/bench/src/main/scala/cats/mtl/bench/StateBench.scala index a4d69d47..0ff86073 100644 --- a/bench/src/main/scala/cats/mtl/bench/StateBench.scala +++ b/bench/src/main/scala/cats/mtl/bench/StateBench.scala @@ -1,5 +1,6 @@ package cats.mtl.bench +import cats.implicits._ import cats.data.StateT import cats.effect.IO import org.openjdk.jmh.annotations._ diff --git a/special/src/main/scala/cats/mtl/special/Newtype2.scala b/special/src/main/scala/cats/mtl/special/Newtype2.scala new file mode 100644 index 00000000..a82e26f2 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/Newtype2.scala @@ -0,0 +1,8 @@ +package cats.mtl.special + +private[special] trait Newtype2 { self => + private[special] type Base + private[special] trait Tag extends Any + type Type[A, B] <: Base with Tag +} + diff --git a/special/src/main/scala/cats/mtl/special/StateIO.scala b/special/src/main/scala/cats/mtl/special/StateIO.scala index 2b7bd420..0bd5a776 100644 --- a/special/src/main/scala/cats/mtl/special/StateIO.scala +++ b/special/src/main/scala/cats/mtl/special/StateIO.scala @@ -5,57 +5,95 @@ import cats.effect.concurrent.Ref import cats.effect._ import cats.mtl.MonadState + /* Performant counterpart to StateT[IO, S, A] */ -case class StateIO[S, A]private (private val value: IO[A]) { +object StateIO extends StateIOInstances with StateIOImpl with Newtype2 { + + @inline private[cats] def wrap[S, A](s: IO[A]): Type[S, A] = + s.asInstanceOf[Type[S, A]] + + @inline private[cats] def unwrap[S, A](e: Type[S, A]): IO[A] = + e.asInstanceOf[IO[A]] + + @inline + def liftF[S, A](fa: IO[A]): StateIO[S, A] = wrap(fa) + + def get[S]: StateIO[S, S] = liftF(refInstance[S].get) + + def modify[S](f: S => S): StateIO[S, Unit] = liftF(refInstance[S].update(f)) + + def set[S](s: S): StateIO[S, Unit] = liftF(refInstance[S].set(s)) + + def pure[S, A](a: A): StateIO[S, A] = liftF(IO.pure(a)) + + def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] = + liftF(refInstance[S].get.flatMap(s => f(s).flatMap { + case (s, a) => refInstance[S].set(s).map(_ => a) + })) + + implicit def stateIOOps[S, A](v: StateIO[S, A]): StateIOOps[S, A] = + new StateIOOps[S, A](v) +} + +private[special] class StateIOOps[S, A](val sp: StateIO[S, A]) extends AnyVal { + def unsafeRunAsync(s: S)(f: Either[Throwable, (S, A)] => Unit): Unit = StateIO.refInstance[S].set(s) - .flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) .unsafeRunAsync(f) def unsafeRunSync(s: S): (S, A) = StateIO.refInstance[S].set(s) - .flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) .unsafeRunSync() def unsafeRunSyncA(s: S): A = - StateIO.refInstance[S].set(s).flatMap(_ => value).unsafeRunSync() + StateIO.refInstance[S].set(s).flatMap(_ => StateIO.unwrap(sp)).unsafeRunSync() def unsafeRunSyncS(s: S): S = StateIO.refInstance[S].set(s) - .flatMap(_ => value.flatMap(_ => StateIO.refInstance[S].get)) + .flatMap(_ => StateIO.unwrap(sp).flatMap(_ => StateIO.refInstance[S].get)) .unsafeRunSync() +} - def flatMap[B](f: A => StateIO[S, B]): StateIO[S, B] = - StateIO(value.flatMap(a => f(a).value)) - def map[B](f: A => B): StateIO[S, B] = - StateIO(value.map(f)) +private[special] abstract class StateIOAsync[S] extends Async[StateIO[S, ?]] { + @inline + def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] = + StateIO.wrap(StateIO.unwrap(fa).flatMap(a => StateIO.unwrap(f(a)))) - def get: StateIO[S, S] = - flatMap(_ => StateIO.get[S]) -} + def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] = + StateIO.liftF(IO.suspend(StateIO.unwrap(thunk))) + def bracketCase[A, B](acquire: StateIO[S, A]) + (use: A => StateIO[S, B]) + (release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] = + StateIO.liftF(StateIO.unwrap(acquire) + .bracketCase(a => StateIO.unwrap(use(a)))((a, ec) => StateIO.unwrap(release(a, ec)))) + def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] = + StateIO.liftF(IO.async(k)) -object StateIO extends StateIOImpl { + def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] = + StateIO.liftF(IO.asyncF(cb => StateIO.unwrap(k(cb)))) - def liftF[S, A](fa: IO[A]): StateIO[S, A] = StateIO(fa) + override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = + StateIO.liftF(StateIO.unwrap(fa).map(f)) - def get[S]: StateIO[S, S] = StateIO(refInstance[S].get) + def pure[A](x: A): StateIO[S, A] = StateIO.pure(x) - def modify[S](f: S => S): StateIO[S, Unit] = StateIO(refInstance[S].update(f)) + def raiseError[A](e: Throwable): StateIO[S, A] = StateIO.liftF(IO.raiseError(e)) - def set[S](s: S): StateIO[S, Unit] = StateIO(refInstance[S].set(s)) + def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] = + StateIO.liftF(StateIO.unwrap(fa).handleErrorWith(e => StateIO.unwrap(f(e)))) - def pure[S, A](a: A): StateIO[S, A] = StateIO(IO.pure(a)) - - def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] = - StateIO(refInstance[S].get.flatMap(s => f(s).flatMap { - case (s, a) => refInstance[S].set(s).map(_ => a) - })) + def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] = + StateIO.liftF(Monad[IO].tailRecM(a)(a => StateIO.unwrap(f(a)))) +} - implicit def monadStateStateIO[S](implicit ctx: ContextShift[IO]): MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] { - val monad: Monad[StateIO[S, ?]] = monadErrorStateIO +private[special] sealed abstract class StateIOInstances extends StateIOInstances0 { + implicit def stateIOMonadState[S]: MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] { + val monad: Monad[StateIO[S, ?]] = stateIOAsync[S] def get: StateIO[S, S] = StateIO.get[S] @@ -63,62 +101,36 @@ object StateIO extends StateIOImpl { def set(s: S): StateIO[S, Unit] = StateIO.set(s) - def inspect[A](f: S => A): StateIO[S, A] = StateIO.get.map(f) + def inspect[A](f: S => A): StateIO[S, A] = monad.map(StateIO.get)(f) } - implicit def monadErrorStateIO[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] = - new Concurrent[StateIO[S, ?]] { + implicit def stateIOConcurrent[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] = + new StateIOAsync[S] with Concurrent[StateIO[S, ?]] { type Fiber[A] = cats.effect.Fiber[StateIO[S, ?], A] - def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] = - fa.flatMap(f) - - override def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] = - StateIO.liftF(IO.suspend(thunk.value)) - - def bracketCase[A, B](acquire: StateIO[S, A]) - (use: A => StateIO[S, B]) - (release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] = - StateIO(acquire.value.bracketCase(a => use(a).value)((a, ec) => release(a, ec).value)) - - def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] = - StateIO.liftF(IO.async(k)) - - def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] = - StateIO.liftF(IO.asyncF(cb => k(cb).value)) - override def cancelable[A](k: (Either[Throwable, A] => Unit) => CancelToken[StateIO[S, ?]]): StateIO[S, A] = - StateIO.liftF(IO.cancelable[A](cb => k(cb).value)) + StateIO.liftF(IO.cancelable[A](cb => StateIO.unwrap(k(cb)))) def start[A](fa: StateIO[S, A]): StateIO[S, Fiber[A]] = - StateIO.liftF[S, Fiber[A]](fa.value.start.map(fiberT)) + StateIO.liftF[S, Fiber[A]](StateIO.unwrap(fa).start.map(fiberT)) def racePair[A, B](fa: StateIO[S, A], fb: StateIO[S, B]): StateIO[S, Either[(A, Fiber[B]), (Fiber[A], B)]] = - StateIO(IO.racePair(fa.value, fb.value).map { + StateIO.liftF(IO.racePair(StateIO.unwrap(fa), StateIO.unwrap(fb)).map { case Left((a, fib)) => Left((a, fiberT(fib))) case Right((fib, b)) => Right((fiberT(fib), b)) }) - - override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = fa.map(f) - - def pure[A](x: A): StateIO[S, A] = StateIO.pure(x) - - def raiseError[A](e: Throwable): StateIO[S, A] = StateIO(IO.raiseError(e)) - - def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] = - StateIO(fa.value.handleErrorWith(e => f(e).value)) - - def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] = - StateIO(Monad[IO].tailRecM(a)(a => f(a).value)) - protected def fiberT[A](fiber: cats.effect.Fiber[IO, A]): Fiber[A] = - Fiber(StateIO(fiber.join), StateIO(fiber.cancel)) + Fiber(StateIO.liftF(fiber.join), StateIO.liftF(fiber.cancel)) } } +private[special] sealed abstract class StateIOInstances0 { + implicit def stateIOAsync[S]: Async[StateIO[S, ?]] = + new StateIOAsync[S]{} +} private[special] sealed trait StateIOImpl { diff --git a/special/src/main/scala/cats/mtl/special/package.scala b/special/src/main/scala/cats/mtl/special/package.scala new file mode 100644 index 00000000..d4d0cff6 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/package.scala @@ -0,0 +1,5 @@ +package cats.mtl + +package object special { + type StateIO[S, A] = StateIO.Type[S, A] +}