diff --git a/bench/src/main/scala/cats/mtl/bench/StateBench.scala b/bench/src/main/scala/cats/mtl/bench/StateBench.scala new file mode 100644 index 00000000..401cf253 --- /dev/null +++ b/bench/src/main/scala/cats/mtl/bench/StateBench.scala @@ -0,0 +1,37 @@ +package cats.mtl.bench + +import cats.implicits._ +import cats.data.StateT +import cats.effect.IO +import org.openjdk.jmh.annotations._ +import java.util.concurrent.TimeUnit + +import cats.mtl.special.StateIO + +@BenchmarkMode(Array(Mode.Throughput)) +@OutputTimeUnit(TimeUnit.SECONDS) +class StateBench { + + def leftFlatMapSpecial(bound: Int): Int = { + def loop(i: Int): StateIO[Int, Int] = + if (i > bound) StateIO.pure(i) + else StateIO.pure[Int, Int](i + 1).flatMap(loop) + + StateIO.pure[Int, Int](0).flatMap(loop).unsafeRunSyncA(0) + } + + def leftFlatMapStateT(bound: Int): Int = { + def loop(i: Int): StateT[IO, Int, Int] = + if (i > bound) StateT.pure[IO, Int, Int](i) + else StateT.pure[IO, Int, Int](i + 1).flatMap(loop) + + StateT.pure[IO, Int, Int](0).flatMap(loop).runA(0).unsafeRunSync() + } + + @Benchmark + def leftAssociatedBindSpecial(): Int = leftFlatMapSpecial(100000) + + @Benchmark + def leftAssociatedBindStateT(): Int = leftFlatMapStateT(100000) + +} diff --git a/build.sbt b/build.sbt index 8add1bf5..e622e004 100644 --- a/build.sbt +++ b/build.sbt @@ -118,11 +118,32 @@ val laws = crossProject.crossType(CrossType.Pure) val lawsJVM = laws.jvm val lawsJS = laws.js +val special = crossProject.crossType(CrossType.Pure) + .dependsOn(core) + .settings(moduleName := "cats-mtl-special", name := "Cats MTL Special") + .settings(Settings.coreSettings:_*) + .settings(Dependencies.catsEffect: _*) + .jsSettings(Settings.commonJsSettings:_*) + .jvmSettings(Settings.commonJvmSettings:_*) + .jsSettings(coverageEnabled := false) + +val specialJVM = special.jvm +val specialJS = special.js + +lazy val bench = project.dependsOn(coreJVM, specialJVM) + .settings(moduleName := "cats-mtl-bench") + .settings(Settings.coreSettings:_*) + .settings(Publishing.noPublishSettings) + .settings(Settings.commonJvmSettings:_*) + .settings(coverageEnabled := false) + .enablePlugins(JmhPlugin) + val tests = crossProject.crossType(CrossType.Pure) - .dependsOn(core, laws) + .dependsOn(core, laws, special) .settings(moduleName := "cats-mtl-tests") .settings(Settings.coreSettings:_*) .settings(Dependencies.catsBundle:_*) + .settings(Dependencies.catsEffectLaws:_*) .settings(Dependencies.discipline:_*) .settings(Dependencies.scalaCheck:_*) .settings(Publishing.noPublishSettings:_*) @@ -155,16 +176,16 @@ val catsMtlJVM = project.in(file(".catsJVM")) .settings(Settings.coreSettings) .settings(Settings.commonJvmSettings) .settings(Publishing.noPublishSettings) - .aggregate(coreJVM, lawsJVM, testsJVM, jvm, docs) - .dependsOn(coreJVM, lawsJVM, testsJVM % "test-internal -> test", jvm) + .aggregate(coreJVM, lawsJVM, specialJVM, testsJVM, jvm, docs) + .dependsOn(coreJVM, lawsJVM, specialJVM, testsJVM % "test-internal -> test", jvm) val catsMtlJS = project.in(file(".catsJS")) .settings(moduleName := "cats-mtl") .settings(Settings.coreSettings) .settings(Settings.commonJsSettings) .settings(Publishing.noPublishSettings) - .aggregate(coreJS, lawsJS, testsJS, js) - .dependsOn(coreJS, lawsJS, testsJS % "test-internal -> test", js) + .aggregate(coreJS, lawsJS, specialJS, testsJS, js) + .dependsOn(coreJS, lawsJS, specialJS, testsJS % "test-internal -> test", js) .enablePlugins(ScalaJSPlugin) val catsMtl = project.in(file(".")) diff --git a/project/Dependencies.scala b/project/Dependencies.scala index a8d3c47b..1637f8db 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -15,7 +15,7 @@ object Dependencies { val simulacrum = "0.13.0" val machinist = "0.6.5" val cats = "1.4.0" - val shapeless = "2.3.3" + val catsEffect = "1.0.0" } val acyclic: Seq[Def.Setting[_]] = Def.settings( @@ -54,8 +54,12 @@ object Dependencies { "org.typelevel" %%% "cats-core" % Versions.cats )) - val shapeless: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( - "com.chuusai" %%% "shapeless" % Versions.shapeless + val catsEffect: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( + "org.typelevel" %%% "cats-effect" % Versions.catsEffect + )) + + val catsEffectLaws: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( + "org.typelevel" %%% "cats-effect-laws" % Versions.catsEffect )) val simulacrumAndMachinist: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq( diff --git a/project/plugins.sbt b/project/plugins.sbt index 71d603af..47f2e6bc 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -12,3 +12,4 @@ addSbtPlugin("com.47deg" % "sbt-microsites" % "0.7.23") addSbtPlugin("com.dwijnand" % "sbt-travisci" % "1.1.3") addSbtPlugin("org.lyranthe.sbt" % "partial-unification" % "1.1.2") addSbtPlugin("com.lucidchart" % "sbt-scalafmt-coursier" % "1.15") +addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.4") diff --git a/special/src/main/scala/cats/mtl/special/Newtype2.scala b/special/src/main/scala/cats/mtl/special/Newtype2.scala new file mode 100644 index 00000000..a82e26f2 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/Newtype2.scala @@ -0,0 +1,8 @@ +package cats.mtl.special + +private[special] trait Newtype2 { self => + private[special] type Base + private[special] trait Tag extends Any + type Type[A, B] <: Base with Tag +} + diff --git a/special/src/main/scala/cats/mtl/special/StateIO.scala b/special/src/main/scala/cats/mtl/special/StateIO.scala new file mode 100644 index 00000000..5912c563 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/StateIO.scala @@ -0,0 +1,144 @@ +package cats.mtl.special + +import cats.Monad +import cats.effect.concurrent.Ref +import cats.effect._ +import cats.mtl.MonadState + + +/* Performant counterpart to StateT[IO, S, A] */ +object StateIO extends StateIOInstances with StateIOImpl with Newtype2 { + + private[cats] def wrap[S, A](s: IO[A]): Type[S, A] = + s.asInstanceOf[Type[S, A]] + + private[cats] def unwrap[S, A](e: Type[S, A]): IO[A] = + e.asInstanceOf[IO[A]] + + @inline + def liftF[S, A](fa: IO[A]): StateIO[S, A] = wrap(fa) + + def get[S]: StateIO[S, S] = liftF(refInstance[S].get) + + def modify[S](f: S => S): StateIO[S, Unit] = liftF(refInstance[S].update(f)) + + def set[S](s: S): StateIO[S, Unit] = liftF(refInstance[S].set(s)) + + def pure[S, A](a: A): StateIO[S, A] = liftF(IO.pure(a)) + + def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] = + liftF(refInstance[S].get.flatMap(s => f(s).flatMap { + case (s, a) => refInstance[S].set(s).map(_ => a) + })) + + implicit def stateIOOps[S, A](v: StateIO[S, A]): StateIOOps[S, A] = + new StateIOOps[S, A](v) +} + +private[special] class StateIOOps[S, A](val sp: StateIO[S, A]) extends AnyVal { + + def unsafeRunAsync(s: S)(f: Either[Throwable, (S, A)] => Unit): Unit = + StateIO.refInstance[S].set(s) + .flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .unsafeRunAsync(f) + + def unsafeRunSync(s: S): (S, A) = + StateIO.refInstance[S].set(s) + .flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a)))) + .unsafeRunSync() + + def unsafeRunSyncA(s: S): A = + StateIO.refInstance[S].set(s).flatMap(_ => StateIO.unwrap(sp)).unsafeRunSync() + + def unsafeRunSyncS(s: S): S = + StateIO.refInstance[S].set(s) + .flatMap(_ => StateIO.unwrap(sp).flatMap(_ => StateIO.refInstance[S].get)) + .unsafeRunSync() +} + + +private[special] abstract class StateIOAsync[S] extends Async[StateIO[S, ?]] { + + def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] = + StateIO.wrap(StateIO.unwrap(fa).flatMap(a => StateIO.unwrap(f(a)))) + + def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] = + StateIO.liftF(IO.suspend(StateIO.unwrap(thunk))) + + def bracketCase[A, B](acquire: StateIO[S, A]) + (use: A => StateIO[S, B]) + (release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] = + StateIO.liftF(StateIO.unwrap(acquire) + .bracketCase(a => StateIO.unwrap(use(a)))((a, ec) => StateIO.unwrap(release(a, ec)))) + + def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] = + StateIO.liftF(IO.async(k)) + + def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] = + StateIO.liftF(IO.asyncF(cb => StateIO.unwrap(k(cb)))) + + override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = + StateIO.liftF(StateIO.unwrap(fa).map(f)) + + def pure[A](x: A): StateIO[S, A] = StateIO.pure(x) + + def raiseError[A](e: Throwable): StateIO[S, A] = StateIO.liftF(IO.raiseError(e)) + + def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] = + StateIO.liftF(StateIO.unwrap(fa).handleErrorWith(e => StateIO.unwrap(f(e)))) + + def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] = + StateIO.liftF(Monad[IO].tailRecM(a)(a => StateIO.unwrap(f(a)))) +} + +private[special] sealed abstract class StateIOInstances extends StateIOInstances0 { + implicit def stateIOMonadState[S]: MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] { + val monad: Monad[StateIO[S, ?]] = stateIOAsync[S] + + def get: StateIO[S, S] = StateIO.get[S] + + def modify(f: S => S): StateIO[S, Unit] = StateIO.modify(f) + + def set(s: S): StateIO[S, Unit] = StateIO.set(s) + + def inspect[A](f: S => A): StateIO[S, A] = monad.map(StateIO.get)(f) + } + + implicit def stateIOConcurrent[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] = + new StateIOAsync[S] with Concurrent[StateIO[S, ?]] { + + type Fiber[A] = cats.effect.Fiber[StateIO[S, ?], A] + + override def cancelable[A](k: (Either[Throwable, A] => Unit) => CancelToken[StateIO[S, ?]]): StateIO[S, A] = + StateIO.liftF(IO.cancelable[A](cb => StateIO.unwrap(k(cb)))) + + def start[A](fa: StateIO[S, A]): StateIO[S, Fiber[A]] = + StateIO.liftF[S, Fiber[A]](StateIO.unwrap(fa).start.map(fiberT)) + + + def racePair[A, B](fa: StateIO[S, A], fb: StateIO[S, B]): StateIO[S, Either[(A, Fiber[B]), (Fiber[A], B)]] = + StateIO.liftF(IO.racePair(StateIO.unwrap(fa), StateIO.unwrap(fb)).map { + case Left((a, fib)) => Left((a, fiberT(fib))) + case Right((fib, b)) => Right((fiberT(fib), b)) + }) + + protected def fiberT[A](fiber: cats.effect.Fiber[IO, A]): Fiber[A] = + Fiber(StateIO.liftF(fiber.join), StateIO.liftF(fiber.cancel)) + } +} + +private[special] sealed abstract class StateIOInstances0 { + implicit def stateIOAsync[S]: Async[StateIO[S, ?]] = + new StateIOAsync[S]{} +} + + +private[special] sealed trait StateIOImpl { + + /* There be dragons */ + // scalastyle:off null + private val ref: Ref[IO, Any] = Ref.unsafe[IO, Any](null) + // scalastyle:on null + + private[special] def refInstance[S]: Ref[IO, S] = ref.asInstanceOf[Ref[IO, S]] +} diff --git a/special/src/main/scala/cats/mtl/special/package.scala b/special/src/main/scala/cats/mtl/special/package.scala new file mode 100644 index 00000000..d4d0cff6 --- /dev/null +++ b/special/src/main/scala/cats/mtl/special/package.scala @@ -0,0 +1,5 @@ +package cats.mtl + +package object special { + type StateIO[S, A] = StateIO.Type[S, A] +} diff --git a/tests/src/test/scala/cats/mtl/tests/StateIOTests.scala b/tests/src/test/scala/cats/mtl/tests/StateIOTests.scala new file mode 100644 index 00000000..2381913d --- /dev/null +++ b/tests/src/test/scala/cats/mtl/tests/StateIOTests.scala @@ -0,0 +1,49 @@ +package cats +package mtl +package tests + + +import cats.implicits._ +import cats.laws.discipline.arbitrary._ +import cats.laws.discipline._ +import cats.effect.IO +import cats.effect.concurrent.Deferred +import cats.effect.laws.discipline.arbitrary._ +import cats.effect.laws.util.TestInstances._ +import cats.effect.laws.util.TestContext +import cats.laws.discipline.eq._ +import cats.mtl.laws.discipline.MonadStateTests +import cats.mtl.special.StateIO +import org.scalacheck._ + +class StateIOTests extends BaseSuite { + + implicit val ec = TestContext() + + implicit def arbStateIO[S: Arbitrary, A: Arbitrary: Cogen]: Arbitrary[StateIO[S, A]] = + Arbitrary(Gen.oneOf( + catsEffectLawsArbitraryForIO[A].arbitrary.map(StateIO.liftF[S, A]), + implicitly[Arbitrary[A]].arbitrary.map(a => StateIO.get[S].as(a)), + implicitly[Arbitrary[A]].arbitrary.flatMap(a => + implicitly[Arbitrary[S]].arbitrary.map(s => StateIO.set(s).as(a))), + implicitly[Arbitrary[A]].arbitrary.flatMap(a => + implicitly[Arbitrary[S]].arbitrary.map(s => StateIO.modify[S](_ => s).as(a))) + )) + + + implicit def eqStateIO[S: Arbitrary: Eq, A: Eq](implicit ctx: TestContext, e: Eq[IO[(S, A)]]): Eq[StateIO[S, A]] = + Eq.by(sio => { (s: S) => + Deferred.uncancelable[IO, (S, A)].flatMap(deferred => + IO(sio.unsafeRunAsync(s) { + case Right(sa) => deferred.complete(sa).unsafeRunSync() + case Left(t) => throw t + }).flatMap(_ => deferred.get)) + }) + + checkAll("StateIO[String, Int]", + MonadStateTests[StateIO[String, ?], String] + .monadState[Int]) + checkAll("MonadState[StateIO[String, ?]]", + SerializableTests.serializable(MonadState[StateIO[String, ?], String])) + +}