Skip to content

Commit

Permalink
New special and bench modules + StateIO
Browse files Browse the repository at this point in the history
  • Loading branch information
Luka Jacobowitz committed Oct 12, 2018
1 parent 830e1e8 commit a4d95d2
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 8 deletions.
36 changes: 36 additions & 0 deletions bench/src/main/scala/cats/mtl/bench/StateBench.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package cats.mtl.bench

import cats.data.StateT
import cats.effect.IO
import org.openjdk.jmh.annotations._
import java.util.concurrent.TimeUnit

import cats.mtl.special.StateIO

@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
class StateBench {

def leftFlatMapSpecial(bound: Int): Int = {
def loop(i: Int): StateIO[Int, Int] =
if (i > bound) StateIO.pure(i)
else StateIO.pure[Int, Int](i + 1).flatMap(loop)

StateIO.pure[Int, Int](0).flatMap(loop).unsafeRunSyncS(0)
}

def leftFlatMapStateT(bound: Int): Int = {
def loop(i: Int): StateT[IO, Int, Int] =
if (i > bound) StateT.pure[IO, Int, Int](i)
else StateT.pure[IO, Int, Int](i + 1).flatMap(loop)

StateT.pure[IO, Int, Int](0).flatMap(loop).runS(0).unsafeRunSync()
}

@Benchmark
def leftAssociatedBindSpecial(): Int = leftFlatMapSpecial(100000)

@Benchmark
def leftAssociatedBindStateT(): Int = leftFlatMapStateT(100000)

}
31 changes: 26 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,32 @@ val laws = crossProject.crossType(CrossType.Pure)
val lawsJVM = laws.jvm
val lawsJS = laws.js

val special = crossProject.crossType(CrossType.Pure)
.dependsOn(core)
.settings(moduleName := "cats-mtl-special", name := "Cats MTL Special")
.settings(Settings.coreSettings:_*)
.settings(Dependencies.catsEffect: _*)
.jsSettings(Settings.commonJsSettings:_*)
.jvmSettings(Settings.commonJvmSettings:_*)
.jsSettings(coverageEnabled := false)

val specialJVM = special.jvm
val specialJS = special.js

lazy val bench = project.dependsOn(coreJVM, specialJVM)
.settings(moduleName := "cats-mtl-bench")
.settings(Settings.coreSettings:_*)
.settings(Publishing.noPublishSettings)
.settings(Settings.commonJvmSettings:_*)
.settings(coverageEnabled := false)
.enablePlugins(JmhPlugin)

val tests = crossProject.crossType(CrossType.Pure)
.dependsOn(core, laws)
.dependsOn(core, laws, special)
.settings(moduleName := "cats-mtl-tests")
.settings(Settings.coreSettings:_*)
.settings(Dependencies.catsBundle:_*)
.settings(Dependencies.catsEffectLaws:_*)
.settings(Dependencies.discipline:_*)
.settings(Dependencies.scalaCheck:_*)
.settings(Publishing.noPublishSettings:_*)
Expand Down Expand Up @@ -155,16 +176,16 @@ val catsMtlJVM = project.in(file(".catsJVM"))
.settings(Settings.coreSettings)
.settings(Settings.commonJvmSettings)
.settings(Publishing.noPublishSettings)
.aggregate(coreJVM, lawsJVM, testsJVM, jvm, docs)
.dependsOn(coreJVM, lawsJVM, testsJVM % "test-internal -> test", jvm)
.aggregate(coreJVM, lawsJVM, specialJVM, testsJVM, jvm, docs)
.dependsOn(coreJVM, lawsJVM, specialJVM, testsJVM % "test-internal -> test", jvm)

val catsMtlJS = project.in(file(".catsJS"))
.settings(moduleName := "cats-mtl")
.settings(Settings.coreSettings)
.settings(Settings.commonJsSettings)
.settings(Publishing.noPublishSettings)
.aggregate(coreJS, lawsJS, testsJS, js)
.dependsOn(coreJS, lawsJS, testsJS % "test-internal -> test", js)
.aggregate(coreJS, lawsJS, specialJS, testsJS, js)
.dependsOn(coreJS, lawsJS, specialJS, testsJS % "test-internal -> test", js)
.enablePlugins(ScalaJSPlugin)

val catsMtl = project.in(file("."))
Expand Down
10 changes: 7 additions & 3 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object Dependencies {
val simulacrum = "0.13.0"
val machinist = "0.6.5"
val cats = "1.4.0"
val shapeless = "2.3.3"
val catsEffect = "1.0.0"
}

val acyclic: Seq[Def.Setting[_]] = Def.settings(
Expand Down Expand Up @@ -54,8 +54,12 @@ object Dependencies {
"org.typelevel" %%% "cats-core" % Versions.cats
))

val shapeless: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
"com.chuusai" %%% "shapeless" % Versions.shapeless
val catsEffect: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
"org.typelevel" %%% "cats-effect" % Versions.catsEffect
))

val catsEffectLaws: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
"org.typelevel" %%% "cats-effect-laws" % Versions.catsEffect
))

val simulacrumAndMachinist: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
Expand Down
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ addSbtPlugin("com.47deg" % "sbt-microsites" % "0.7.23")
addSbtPlugin("com.dwijnand" % "sbt-travisci" % "1.1.3")
addSbtPlugin("org.lyranthe.sbt" % "partial-unification" % "1.1.2")
addSbtPlugin("com.lucidchart" % "sbt-scalafmt-coursier" % "1.15")
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.4")
130 changes: 130 additions & 0 deletions special/src/main/scala/cats/mtl/special/StateIO.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package cats.mtl.special

import cats.Monad
import cats.effect.concurrent.Ref
import cats.effect._
import cats.mtl.MonadState

/* Performant counterpart to StateT[IO, S, A] */
case class StateIO[S, A]private (private val value: IO[A]) {
def unsafeRunAsync(s: S)(f: Either[Throwable, (S, A)] => Unit): Unit =
StateIO.refInstance[S].set(s)
.flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.unsafeRunAsync(f)

def unsafeRunSync(s: S): (S, A) =
StateIO.refInstance[S].set(s)
.flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.unsafeRunSync()

def unsafeRunSyncA(s: S): A =
StateIO.refInstance[S].set(s).flatMap(_ => value).unsafeRunSync()

def unsafeRunSyncS(s: S): S =
StateIO.refInstance[S].set(s)
.flatMap(_ => value.flatMap(_ => StateIO.refInstance[S].get))
.unsafeRunSync()

def flatMap[B](f: A => StateIO[S, B]): StateIO[S, B] =
StateIO(value.flatMap(a => f(a).value))

def map[B](f: A => B): StateIO[S, B] =
StateIO(value.map(f))

def get: StateIO[S, S] =
flatMap(_ => StateIO.get[S])
}



object StateIO extends StateIOImpl {

def liftF[S, A](fa: IO[A]): StateIO[S, A] = StateIO(fa)

def get[S]: StateIO[S, S] = StateIO(refInstance[S].get)

def modify[S](f: S => S): StateIO[S, Unit] = StateIO(refInstance[S].update(f))

def set[S](s: S): StateIO[S, Unit] = StateIO(refInstance[S].set(s))

def pure[S, A](a: A): StateIO[S, A] = StateIO(IO.pure(a))

def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] =
StateIO(refInstance[S].get.flatMap(s => f(s).flatMap {
case (s, a) => refInstance[S].set(s).map(_ => a)
}))

implicit def monadStateStateIO[S](implicit ctx: ContextShift[IO]): MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] {
val monad: Monad[StateIO[S, ?]] = monadErrorStateIO

def get: StateIO[S, S] = StateIO.get[S]

def modify(f: S => S): StateIO[S, Unit] = StateIO.modify(f)

def set(s: S): StateIO[S, Unit] = StateIO.set(s)

def inspect[A](f: S => A): StateIO[S, A] = StateIO.get.map(f)
}

implicit def monadErrorStateIO[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] =
new Concurrent[StateIO[S, ?]] {

type Fiber[A] = cats.effect.Fiber[StateIO[S, ?], A]

def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] =
fa.flatMap(f)

override def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] =
StateIO.liftF(IO.suspend(thunk.value))

def bracketCase[A, B](acquire: StateIO[S, A])
(use: A => StateIO[S, B])
(release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] =
StateIO(acquire.value.bracketCase(a => use(a).value)((a, ec) => release(a, ec).value))

def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] =
StateIO.liftF(IO.async(k))

def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] =
StateIO.liftF(IO.asyncF(cb => k(cb).value))

override def cancelable[A](k: (Either[Throwable, A] => Unit) => CancelToken[StateIO[S, ?]]): StateIO[S, A] =
StateIO.liftF(IO.cancelable[A](cb => k(cb).value))

def start[A](fa: StateIO[S, A]): StateIO[S, Fiber[A]] =
StateIO.liftF[S, Fiber[A]](fa.value.start.map(fiberT))


def racePair[A, B](fa: StateIO[S, A], fb: StateIO[S, B]): StateIO[S, Either[(A, Fiber[B]), (Fiber[A], B)]] =
StateIO(IO.racePair(fa.value, fb.value).map {
case Left((a, fib)) => Left((a, fiberT(fib)))
case Right((fib, b)) => Right((fiberT(fib), b))
})


override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = fa.map(f)

def pure[A](x: A): StateIO[S, A] = StateIO.pure(x)

def raiseError[A](e: Throwable): StateIO[S, A] = StateIO(IO.raiseError(e))

def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] =
StateIO(fa.value.handleErrorWith(e => f(e).value))

def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] =
StateIO(Monad[IO].tailRecM(a)(a => f(a).value))

protected def fiberT[A](fiber: cats.effect.Fiber[IO, A]): Fiber[A] =
Fiber(StateIO(fiber.join), StateIO(fiber.cancel))
}
}



private[special] sealed trait StateIOImpl {

/* There be dragons */
private val ref: Ref[IO, Any] = Ref.unsafe[IO, Any](null)

private[special] def refInstance[S]: Ref[IO, S] = ref.asInstanceOf[Ref[IO, S]]
}

0 comments on commit a4d95d2

Please sign in to comment.