From a4d95d26c77b43acb5a6d5babd23a4c1a1ece241 Mon Sep 17 00:00:00 2001 From: Luka Jacobowitz Date: Fri, 12 Oct 2018 15:52:24 +0200 Subject: [PATCH 1/3] New special and bench modules + StateIO --- .../scala/cats/mtl/bench/StateBench.scala | 36 +++++ build.sbt | 31 ++++- project/Dependencies.scala | 10 +- project/plugins.sbt | 1 + .../main/scala/cats/mtl/special/StateIO.scala | 130 ++++++++++++++++++ 5 files changed, 200 insertions(+), 8 deletions(-) create mode 100644 bench/src/main/scala/cats/mtl/bench/StateBench.scala create mode 100644 special/src/main/scala/cats/mtl/special/StateIO.scala diff --git a/bench/src/main/scala/cats/mtl/bench/StateBench.scala b/bench/src/main/scala/cats/mtl/bench/StateBench.scala new file mode 100644 index 00000000..a4d69d47 --- /dev/null +++ b/bench/src/main/scala/cats/mtl/bench/StateBench.scala @@ -0,0 +1,36 @@ +package cats.mtl.bench + +import cats.data.StateT +import cats.effect.IO +import org.openjdk.jmh.annotations._ +import java.util.concurrent.TimeUnit + +import cats.mtl.special.StateIO + +@BenchmarkMode(Array(Mode.Throughput)) +@OutputTimeUnit(TimeUnit.SECONDS) +class StateBench { + + def leftFlatMapSpecial(bound: Int): Int = { + def loop(i: Int): StateIO[Int, Int] = + if (i > bound) StateIO.pure(i) + else StateIO.pure[Int, Int](i + 1).flatMap(loop) + + StateIO.pure[Int, Int](0).flatMap(loop).unsafeRunSyncS(0) + } + + def leftFlatMapStateT(bound: Int): Int = { + def loop(i: Int): StateT[IO, Int, Int] = + if (i > bound) StateT.pure[IO, Int, Int](i) + else StateT.pure[IO, Int, Int](i + 1).flatMap(loop) + + StateT.pure[IO, Int, Int](0).flatMap(loop).runS(0).unsafeRunSync() + } + + @Benchmark + def leftAssociatedBindSpecial(): Int = leftFlatMapSpecial(100000) + + @Benchmark + def leftAssociatedBindStateT(): Int = leftFlatMapStateT(100000) + +} diff --git a/build.sbt b/build.sbt index 8add1bf5..e622e004 100644 --- a/build.sbt +++ b/build.sbt @@ -118,11 +118,32 @@ val laws = crossProject.crossType(CrossType.Pure) val lawsJVM = laws.jvm val lawsJS = laws.js +val special = crossProject.crossType(CrossType.Pure) + .dependsOn(core) + .settings(moduleName := "cats-mtl-special", name := "Cats MTL Special") + .settings(Settings.coreSettings:_*) + .settings(Dependencies.catsEffect: _*) + .jsSettings(Settings.commonJsSettings:_*) + .jvmSettings(Settings.commonJvmSettings:_*) + .jsSettings(coverageEnabled := false) + +val specialJVM = special.jvm +val specialJS = special.js + +lazy val bench = project.dependsOn(coreJVM, specialJVM) + .settings(moduleName := "cats-mtl-bench") + .settings(Settings.coreSettings:_*) + .settings(Publishing.noPublishSettings) + .settings(Settings.commonJvmSettings:_*) + .settings(coverageEnabled := false) + .enablePlugins(JmhPlugin) + val tests = crossProject.crossType(CrossType.Pure) - .dependsOn(core, laws) + .dependsOn(core, laws, special) .settings(moduleName := "cats-mtl-tests") .settings(Settings.coreSettings:_*) .settings(Dependencies.catsBundle:_*) + .settings(Dependencies.catsEffectLaws:_*) .settings(Dependencies.discipline:_*) .settings(Dependencies.scalaCheck:_*) .settings(Publishing.noPublishSettings:_*) @@ -155,16 +176,16 @@ val catsMtlJVM = project.in(file(".catsJVM")) .settings(Settings.coreSettings) .settings(Settings.commonJvmSettings) .settings(Publishing.noPublishSettings) - .aggregate(coreJVM, lawsJVM, testsJVM, jvm, docs) - .dependsOn(coreJVM, lawsJVM, testsJVM % "test-internal -> test", jvm) + .aggregate(coreJVM, lawsJVM, specialJVM, testsJVM, jvm, docs) + .dependsOn(coreJVM, lawsJVM, specialJVM, testsJVM % "test-internal -> test", jvm) val catsMtlJS = project.in(file(".catsJS")) .settings(moduleName := "cats-mtl") .settings(Settings.coreSettings) .settings(Settings.commonJsSettings) .settings(Publishing.noPublishSettings) - .aggregate(coreJS, lawsJS, testsJS, js) - .dependsOn(coreJS, lawsJS, testsJS % "test-internal -> test", js) + .aggregate(coreJS, lawsJS, specialJS, testsJS, js) + .dependsOn(coreJS, lawsJS, specialJS, testsJS % "test-internal -> test", js) .enablePlugins(ScalaJSPlugin) val catsMtl = project.in(file(".")) diff --git a/project/Dependencies.scala b/project/Dependencies.scala index a8d3c47b..1637f8db 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -15,7 +15,7 @@ object Dependencies { val simulacrum = "0.13.0" val machinist = "0.6.5" val cats = "1.4.0" - val shapeless = "2.3.3" + val catsEffect = "1.0.0" } val acyclic: Seq[Def.Setting[_]] = Def.settings( @@ -54,8 +54,12 @@ object Dependencies { "org.typelevel" %%% "cats-core" % Versions.cats )) - val shapeless: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( - "com.chuusai" %%% "shapeless" % Versions.shapeless + val catsEffect: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( + "org.typelevel" %%% "cats-effect" % Versions.catsEffect + )) + + val catsEffectLaws: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( + "org.typelevel" %%% "cats-effect-laws" % Versions.catsEffect )) val simulacrumAndMachinist: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( diff --git a/project/plugins.sbt b/project/plugins.sbt index 71d603af..47f2e6bc 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -12,3 +12,4 @@ addSbtPlugin("com.47deg" % "sbt-microsites" % "0.7.23") addSbtPlugin("com.dwijnand" % "sbt-travisci" % "1.1.3") addSbtPlugin("org.lyranthe.sbt" % "partial-unification" % "1.1.2") addSbtPlugin("com.lucidchart" % "sbt-scalafmt-coursier" % "1.15") +addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.4") diff --git a/special/src/main/scala/cats/mtl/special/StateIO.scala b/special/src/main/scala/cats/mtl/special/StateIO.scala new file mode 100644 index 00000000..2b7bd420 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/StateIO.scala @@ -0,0 +1,130 @@ +package cats.mtl.special + +import cats.Monad +import cats.effect.concurrent.Ref +import cats.effect._ +import cats.mtl.MonadState + +/* Performant counterpart to StateT[IO, S, A] */ +case class StateIO[S, A]private (private val value: IO[A]) { + def unsafeRunAsync(s: S)(f: Either[Throwable, (S, A)] => Unit): Unit = + StateIO.refInstance[S].set(s) + .flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .unsafeRunAsync(f) + + def unsafeRunSync(s: S): (S, A) = + StateIO.refInstance[S].set(s) + .flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .unsafeRunSync() + + def unsafeRunSyncA(s: S): A = + StateIO.refInstance[S].set(s).flatMap(_ => value).unsafeRunSync() + + def unsafeRunSyncS(s: S): S = + StateIO.refInstance[S].set(s) + .flatMap(_ => value.flatMap(_ => StateIO.refInstance[S].get)) + .unsafeRunSync() + + def flatMap[B](f: A => StateIO[S, B]): StateIO[S, B] = + StateIO(value.flatMap(a => f(a).value)) + + def map[B](f: A => B): StateIO[S, B] = + StateIO(value.map(f)) + + def get: StateIO[S, S] = + flatMap(_ => StateIO.get[S]) +} + + + +object StateIO extends StateIOImpl { + + def liftF[S, A](fa: IO[A]): StateIO[S, A] = StateIO(fa) + + def get[S]: StateIO[S, S] = StateIO(refInstance[S].get) + + def modify[S](f: S => S): StateIO[S, Unit] = StateIO(refInstance[S].update(f)) + + def set[S](s: S): StateIO[S, Unit] = StateIO(refInstance[S].set(s)) + + def pure[S, A](a: A): StateIO[S, A] = StateIO(IO.pure(a)) + + def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] = + StateIO(refInstance[S].get.flatMap(s => f(s).flatMap { + case (s, a) => refInstance[S].set(s).map(_ => a) + })) + + implicit def monadStateStateIO[S](implicit ctx: ContextShift[IO]): MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] { + val monad: Monad[StateIO[S, ?]] = monadErrorStateIO + + def get: StateIO[S, S] = StateIO.get[S] + + def modify(f: S => S): StateIO[S, Unit] = StateIO.modify(f) + + def set(s: S): StateIO[S, Unit] = StateIO.set(s) + + def inspect[A](f: S => A): StateIO[S, A] = StateIO.get.map(f) + } + + implicit def monadErrorStateIO[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] = + new Concurrent[StateIO[S, ?]] { + + type Fiber[A] = cats.effect.Fiber[StateIO[S, ?], A] + + def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] = + fa.flatMap(f) + + override def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] = + StateIO.liftF(IO.suspend(thunk.value)) + + def bracketCase[A, B](acquire: StateIO[S, A]) + (use: A => StateIO[S, B]) + (release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] = + StateIO(acquire.value.bracketCase(a => use(a).value)((a, ec) => release(a, ec).value)) + + def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] = + StateIO.liftF(IO.async(k)) + + def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] = + StateIO.liftF(IO.asyncF(cb => k(cb).value)) + + override def cancelable[A](k: (Either[Throwable, A] => Unit) => CancelToken[StateIO[S, ?]]): StateIO[S, A] = + StateIO.liftF(IO.cancelable[A](cb => k(cb).value)) + + def start[A](fa: StateIO[S, A]): StateIO[S, Fiber[A]] = + StateIO.liftF[S, Fiber[A]](fa.value.start.map(fiberT)) + + + def racePair[A, B](fa: StateIO[S, A], fb: StateIO[S, B]): StateIO[S, Either[(A, Fiber[B]), (Fiber[A], B)]] = + StateIO(IO.racePair(fa.value, fb.value).map { + case Left((a, fib)) => Left((a, fiberT(fib))) + case Right((fib, b)) => Right((fiberT(fib), b)) + }) + + + override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = fa.map(f) + + def pure[A](x: A): StateIO[S, A] = StateIO.pure(x) + + def raiseError[A](e: Throwable): StateIO[S, A] = StateIO(IO.raiseError(e)) + + def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] = + StateIO(fa.value.handleErrorWith(e => f(e).value)) + + def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] = + StateIO(Monad[IO].tailRecM(a)(a => f(a).value)) + + protected def fiberT[A](fiber: cats.effect.Fiber[IO, A]): Fiber[A] = + Fiber(StateIO(fiber.join), StateIO(fiber.cancel)) + } +} + + + +private[special] sealed trait StateIOImpl { + + /* There be dragons */ + private val ref: Ref[IO, Any] = Ref.unsafe[IO, Any](null) + + private[special] def refInstance[S]: Ref[IO, S] = ref.asInstanceOf[Ref[IO, S]] +} From 1d5a4664bb464e23dd69bf8adaa36bb23c43f7fa Mon Sep 17 00:00:00 2001 From: Luka Jacobowitz Date: Fri, 12 Oct 2018 17:07:29 +0200 Subject: [PATCH 2/3] Convert to newtype --- .../scala/cats/mtl/bench/StateBench.scala | 5 +- .../scala/cats/mtl/special/Newtype2.scala | 8 ++ .../main/scala/cats/mtl/special/StateIO.scala | 136 ++++++++++-------- .../main/scala/cats/mtl/special/package.scala | 5 + 4 files changed, 91 insertions(+), 63 deletions(-) create mode 100644 special/src/main/scala/cats/mtl/special/Newtype2.scala create mode 100644 special/src/main/scala/cats/mtl/special/package.scala diff --git a/bench/src/main/scala/cats/mtl/bench/StateBench.scala b/bench/src/main/scala/cats/mtl/bench/StateBench.scala index a4d69d47..401cf253 100644 --- a/bench/src/main/scala/cats/mtl/bench/StateBench.scala +++ b/bench/src/main/scala/cats/mtl/bench/StateBench.scala @@ -1,5 +1,6 @@ package cats.mtl.bench +import cats.implicits._ import cats.data.StateT import cats.effect.IO import org.openjdk.jmh.annotations._ @@ -16,7 +17,7 @@ class StateBench { if (i > bound) StateIO.pure(i) else StateIO.pure[Int, Int](i + 1).flatMap(loop) - StateIO.pure[Int, Int](0).flatMap(loop).unsafeRunSyncS(0) + StateIO.pure[Int, Int](0).flatMap(loop).unsafeRunSyncA(0) } def leftFlatMapStateT(bound: Int): Int = { @@ -24,7 +25,7 @@ class StateBench { if (i > bound) StateT.pure[IO, Int, Int](i) else StateT.pure[IO, Int, Int](i + 1).flatMap(loop) - StateT.pure[IO, Int, Int](0).flatMap(loop).runS(0).unsafeRunSync() + StateT.pure[IO, Int, Int](0).flatMap(loop).runA(0).unsafeRunSync() } @Benchmark diff --git a/special/src/main/scala/cats/mtl/special/Newtype2.scala b/special/src/main/scala/cats/mtl/special/Newtype2.scala new file mode 100644 index 00000000..a82e26f2 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/Newtype2.scala @@ -0,0 +1,8 @@ +package cats.mtl.special + +private[special] trait Newtype2 { self => + private[special] type Base + private[special] trait Tag extends Any + type Type[A, B] <: Base with Tag +} + diff --git a/special/src/main/scala/cats/mtl/special/StateIO.scala b/special/src/main/scala/cats/mtl/special/StateIO.scala index 2b7bd420..5912c563 100644 --- a/special/src/main/scala/cats/mtl/special/StateIO.scala +++ b/special/src/main/scala/cats/mtl/special/StateIO.scala @@ -5,57 +5,95 @@ import cats.effect.concurrent.Ref import cats.effect._ import cats.mtl.MonadState + /* Performant counterpart to StateT[IO, S, A] */ -case class StateIO[S, A]private (private val value: IO[A]) { +object StateIO extends StateIOInstances with StateIOImpl with Newtype2 { + + private[cats] def wrap[S, A](s: IO[A]): Type[S, A] = + s.asInstanceOf[Type[S, A]] + + private[cats] def unwrap[S, A](e: Type[S, A]): IO[A] = + e.asInstanceOf[IO[A]] + + @inline + def liftF[S, A](fa: IO[A]): StateIO[S, A] = wrap(fa) + + def get[S]: StateIO[S, S] = liftF(refInstance[S].get) + + def modify[S](f: S => S): StateIO[S, Unit] = liftF(refInstance[S].update(f)) + + def set[S](s: S): StateIO[S, Unit] = liftF(refInstance[S].set(s)) + + def pure[S, A](a: A): StateIO[S, A] = liftF(IO.pure(a)) + + def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] = + liftF(refInstance[S].get.flatMap(s => f(s).flatMap { + case (s, a) => refInstance[S].set(s).map(_ => a) + })) + + implicit def stateIOOps[S, A](v: StateIO[S, A]): StateIOOps[S, A] = + new StateIOOps[S, A](v) +} + +private[special] class StateIOOps[S, A](val sp: StateIO[S, A]) extends AnyVal { + def unsafeRunAsync(s: S)(f: Either[Throwable, (S, A)] => Unit): Unit = StateIO.refInstance[S].set(s) - .flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) .unsafeRunAsync(f) def unsafeRunSync(s: S): (S, A) = StateIO.refInstance[S].set(s) - .flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) .unsafeRunSync() def unsafeRunSyncA(s: S): A = - StateIO.refInstance[S].set(s).flatMap(_ => value).unsafeRunSync() + StateIO.refInstance[S].set(s).flatMap(_ => StateIO.unwrap(sp)).unsafeRunSync() def unsafeRunSyncS(s: S): S = StateIO.refInstance[S].set(s) - .flatMap(_ => value.flatMap(_ => StateIO.refInstance[S].get)) + .flatMap(_ => StateIO.unwrap(sp).flatMap(_ => StateIO.refInstance[S].get)) .unsafeRunSync() +} - def flatMap[B](f: A => StateIO[S, B]): StateIO[S, B] = - StateIO(value.flatMap(a => f(a).value)) - def map[B](f: A => B): StateIO[S, B] = - StateIO(value.map(f)) +private[special] abstract class StateIOAsync[S] extends Async[StateIO[S, ?]] { - def get: StateIO[S, S] = - flatMap(_ => StateIO.get[S]) -} + def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] = + StateIO.wrap(StateIO.unwrap(fa).flatMap(a => StateIO.unwrap(f(a)))) + def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] = + StateIO.liftF(IO.suspend(StateIO.unwrap(thunk))) + def bracketCase[A, B](acquire: StateIO[S, A]) + (use: A => StateIO[S, B]) + (release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] = + StateIO.liftF(StateIO.unwrap(acquire) + .bracketCase(a => StateIO.unwrap(use(a)))((a, ec) => StateIO.unwrap(release(a, ec)))) -object StateIO extends StateIOImpl { + def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] = + StateIO.liftF(IO.async(k)) - def liftF[S, A](fa: IO[A]): StateIO[S, A] = StateIO(fa) + def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] = + StateIO.liftF(IO.asyncF(cb => StateIO.unwrap(k(cb)))) - def get[S]: StateIO[S, S] = StateIO(refInstance[S].get) + override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = + StateIO.liftF(StateIO.unwrap(fa).map(f)) - def modify[S](f: S => S): StateIO[S, Unit] = StateIO(refInstance[S].update(f)) + def pure[A](x: A): StateIO[S, A] = StateIO.pure(x) - def set[S](s: S): StateIO[S, Unit] = StateIO(refInstance[S].set(s)) + def raiseError[A](e: Throwable): StateIO[S, A] = StateIO.liftF(IO.raiseError(e)) - def pure[S, A](a: A): StateIO[S, A] = StateIO(IO.pure(a)) + def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] = + StateIO.liftF(StateIO.unwrap(fa).handleErrorWith(e => StateIO.unwrap(f(e)))) - def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] = - StateIO(refInstance[S].get.flatMap(s => f(s).flatMap { - case (s, a) => refInstance[S].set(s).map(_ => a) - })) + def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] = + StateIO.liftF(Monad[IO].tailRecM(a)(a => StateIO.unwrap(f(a)))) +} - implicit def monadStateStateIO[S](implicit ctx: ContextShift[IO]): MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] { - val monad: Monad[StateIO[S, ?]] = monadErrorStateIO +private[special] sealed abstract class StateIOInstances extends StateIOInstances0 { + implicit def stateIOMonadState[S]: MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] { + val monad: Monad[StateIO[S, ?]] = stateIOAsync[S] def get: StateIO[S, S] = StateIO.get[S] @@ -63,68 +101,44 @@ object StateIO extends StateIOImpl { def set(s: S): StateIO[S, Unit] = StateIO.set(s) - def inspect[A](f: S => A): StateIO[S, A] = StateIO.get.map(f) + def inspect[A](f: S => A): StateIO[S, A] = monad.map(StateIO.get)(f) } - implicit def monadErrorStateIO[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] = - new Concurrent[StateIO[S, ?]] { + implicit def stateIOConcurrent[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] = + new StateIOAsync[S] with Concurrent[StateIO[S, ?]] { type Fiber[A] = cats.effect.Fiber[StateIO[S, ?], A] - def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] = - fa.flatMap(f) - - override def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] = - StateIO.liftF(IO.suspend(thunk.value)) - - def bracketCase[A, B](acquire: StateIO[S, A]) - (use: A => StateIO[S, B]) - (release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] = - StateIO(acquire.value.bracketCase(a => use(a).value)((a, ec) => release(a, ec).value)) - - def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] = - StateIO.liftF(IO.async(k)) - - def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] = - StateIO.liftF(IO.asyncF(cb => k(cb).value)) - override def cancelable[A](k: (Either[Throwable, A] => Unit) => CancelToken[StateIO[S, ?]]): StateIO[S, A] = - StateIO.liftF(IO.cancelable[A](cb => k(cb).value)) + StateIO.liftF(IO.cancelable[A](cb => StateIO.unwrap(k(cb)))) def start[A](fa: StateIO[S, A]): StateIO[S, Fiber[A]] = - StateIO.liftF[S, Fiber[A]](fa.value.start.map(fiberT)) + StateIO.liftF[S, Fiber[A]](StateIO.unwrap(fa).start.map(fiberT)) def racePair[A, B](fa: StateIO[S, A], fb: StateIO[S, B]): StateIO[S, Either[(A, Fiber[B]), (Fiber[A], B)]] = - StateIO(IO.racePair(fa.value, fb.value).map { + StateIO.liftF(IO.racePair(StateIO.unwrap(fa), StateIO.unwrap(fb)).map { case Left((a, fib)) => Left((a, fiberT(fib))) case Right((fib, b)) => Right((fiberT(fib), b)) }) - - override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = fa.map(f) - - def pure[A](x: A): StateIO[S, A] = StateIO.pure(x) - - def raiseError[A](e: Throwable): StateIO[S, A] = StateIO(IO.raiseError(e)) - - def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] = - StateIO(fa.value.handleErrorWith(e => f(e).value)) - - def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] = - StateIO(Monad[IO].tailRecM(a)(a => f(a).value)) - protected def fiberT[A](fiber: cats.effect.Fiber[IO, A]): Fiber[A] = - Fiber(StateIO(fiber.join), StateIO(fiber.cancel)) + Fiber(StateIO.liftF(fiber.join), StateIO.liftF(fiber.cancel)) } } +private[special] sealed abstract class StateIOInstances0 { + implicit def stateIOAsync[S]: Async[StateIO[S, ?]] = + new StateIOAsync[S]{} +} private[special] sealed trait StateIOImpl { /* There be dragons */ + // scalastyle:off null private val ref: Ref[IO, Any] = Ref.unsafe[IO, Any](null) + // scalastyle:on null private[special] def refInstance[S]: Ref[IO, S] = ref.asInstanceOf[Ref[IO, S]] } diff --git a/special/src/main/scala/cats/mtl/special/package.scala b/special/src/main/scala/cats/mtl/special/package.scala new file mode 100644 index 00000000..d4d0cff6 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/package.scala @@ -0,0 +1,5 @@ +package cats.mtl + +package object special { + type StateIO[S, A] = StateIO.Type[S, A] +} From f74c9370d1c217d8fc71ffe3d072eabd79f014a1 Mon Sep 17 00:00:00 2001 From: Luka Jacobowitz Date: Fri, 12 Oct 2018 22:44:47 +0200 Subject: [PATCH 3/3] Add Lawtests --- .../scala/cats/mtl/tests/StateIOTests.scala | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tests/src/test/scala/cats/mtl/tests/StateIOTests.scala diff --git a/tests/src/test/scala/cats/mtl/tests/StateIOTests.scala b/tests/src/test/scala/cats/mtl/tests/StateIOTests.scala new file mode 100644 index 00000000..2381913d --- /dev/null +++ b/tests/src/test/scala/cats/mtl/tests/StateIOTests.scala @@ -0,0 +1,49 @@ +package cats +package mtl +package tests + + +import cats.implicits._ +import cats.laws.discipline.arbitrary._ +import cats.laws.discipline._ +import cats.effect.IO +import cats.effect.concurrent.Deferred +import cats.effect.laws.discipline.arbitrary._ +import cats.effect.laws.util.TestInstances._ +import cats.effect.laws.util.TestContext +import cats.laws.discipline.eq._ +import cats.mtl.laws.discipline.MonadStateTests +import cats.mtl.special.StateIO +import org.scalacheck._ + +class StateIOTests extends BaseSuite { + + implicit val ec = TestContext() + + implicit def arbStateIO[S: Arbitrary, A: Arbitrary: Cogen]: Arbitrary[StateIO[S, A]] = + Arbitrary(Gen.oneOf( + catsEffectLawsArbitraryForIO[A].arbitrary.map(StateIO.liftF[S, A]), + implicitly[Arbitrary[A]].arbitrary.map(a => StateIO.get[S].as(a)), + implicitly[Arbitrary[A]].arbitrary.flatMap(a => + implicitly[Arbitrary[S]].arbitrary.map(s => StateIO.set(s).as(a))), + implicitly[Arbitrary[A]].arbitrary.flatMap(a => + implicitly[Arbitrary[S]].arbitrary.map(s => StateIO.modify[S](_ => s).as(a))) + )) + + + implicit def eqStateIO[S: Arbitrary: Eq, A: Eq](implicit ctx: TestContext, e: Eq[IO[(S, A)]]): Eq[StateIO[S, A]] = + Eq.by(sio => { (s: S) => + Deferred.uncancelable[IO, (S, A)].flatMap(deferred => + IO(sio.unsafeRunAsync(s) { + case Right(sa) => deferred.complete(sa).unsafeRunSync() + case Left(t) => throw t + }).flatMap(_ => deferred.get)) + }) + + checkAll("StateIO[String, Int]", + MonadStateTests[StateIO[String, ?], String] + .monadState[Int]) + checkAll("MonadState[StateIO[String, ?]]", + SerializableTests.serializable(MonadState[StateIO[String, ?], String])) + +}