Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: New special and bench modules + StateIO #86

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions bench/src/main/scala/cats/mtl/bench/StateBench.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package cats.mtl.bench

import cats.implicits._
import cats.data.StateT
import cats.effect.IO
import org.openjdk.jmh.annotations._
import java.util.concurrent.TimeUnit

import cats.mtl.special.StateIO

@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
class StateBench {

def leftFlatMapSpecial(bound: Int): Int = {
def loop(i: Int): StateIO[Int, Int] =
if (i > bound) StateIO.pure(i)
else StateIO.pure[Int, Int](i + 1).flatMap(loop)

StateIO.pure[Int, Int](0).flatMap(loop).unsafeRunSyncA(0)
}

def leftFlatMapStateT(bound: Int): Int = {
def loop(i: Int): StateT[IO, Int, Int] =
if (i > bound) StateT.pure[IO, Int, Int](i)
else StateT.pure[IO, Int, Int](i + 1).flatMap(loop)

StateT.pure[IO, Int, Int](0).flatMap(loop).runA(0).unsafeRunSync()
}

@Benchmark
def leftAssociatedBindSpecial(): Int = leftFlatMapSpecial(100000)

@Benchmark
def leftAssociatedBindStateT(): Int = leftFlatMapStateT(100000)

}
31 changes: 26 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,32 @@ val laws = crossProject.crossType(CrossType.Pure)
val lawsJVM = laws.jvm
val lawsJS = laws.js

val special = crossProject.crossType(CrossType.Pure)
.dependsOn(core)
.settings(moduleName := "cats-mtl-special", name := "Cats MTL Special")
.settings(Settings.coreSettings:_*)
.settings(Dependencies.catsEffect: _*)
.jsSettings(Settings.commonJsSettings:_*)
.jvmSettings(Settings.commonJvmSettings:_*)
.jsSettings(coverageEnabled := false)

val specialJVM = special.jvm
val specialJS = special.js

lazy val bench = project.dependsOn(coreJVM, specialJVM)
.settings(moduleName := "cats-mtl-bench")
.settings(Settings.coreSettings:_*)
.settings(Publishing.noPublishSettings)
.settings(Settings.commonJvmSettings:_*)
.settings(coverageEnabled := false)
.enablePlugins(JmhPlugin)

val tests = crossProject.crossType(CrossType.Pure)
.dependsOn(core, laws)
.dependsOn(core, laws, special)
.settings(moduleName := "cats-mtl-tests")
.settings(Settings.coreSettings:_*)
.settings(Dependencies.catsBundle:_*)
.settings(Dependencies.catsEffectLaws:_*)
.settings(Dependencies.discipline:_*)
.settings(Dependencies.scalaCheck:_*)
.settings(Publishing.noPublishSettings:_*)
Expand Down Expand Up @@ -155,16 +176,16 @@ val catsMtlJVM = project.in(file(".catsJVM"))
.settings(Settings.coreSettings)
.settings(Settings.commonJvmSettings)
.settings(Publishing.noPublishSettings)
.aggregate(coreJVM, lawsJVM, testsJVM, jvm, docs)
.dependsOn(coreJVM, lawsJVM, testsJVM % "test-internal -> test", jvm)
.aggregate(coreJVM, lawsJVM, specialJVM, testsJVM, jvm, docs)
.dependsOn(coreJVM, lawsJVM, specialJVM, testsJVM % "test-internal -> test", jvm)

val catsMtlJS = project.in(file(".catsJS"))
.settings(moduleName := "cats-mtl")
.settings(Settings.coreSettings)
.settings(Settings.commonJsSettings)
.settings(Publishing.noPublishSettings)
.aggregate(coreJS, lawsJS, testsJS, js)
.dependsOn(coreJS, lawsJS, testsJS % "test-internal -> test", js)
.aggregate(coreJS, lawsJS, specialJS, testsJS, js)
.dependsOn(coreJS, lawsJS, specialJS, testsJS % "test-internal -> test", js)
.enablePlugins(ScalaJSPlugin)

val catsMtl = project.in(file("."))
Expand Down
10 changes: 7 additions & 3 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ object Dependencies {
val simulacrum = "0.13.0"
val machinist = "0.6.5"
val cats = "1.4.0"
val shapeless = "2.3.3"
val catsEffect = "1.0.0"
}

val acyclic: Seq[Def.Setting[_]] = Def.settings(
Expand Down Expand Up @@ -54,8 +54,12 @@ object Dependencies {
"org.typelevel" %%% "cats-core" % Versions.cats
))

val shapeless: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
"com.chuusai" %%% "shapeless" % Versions.shapeless
val catsEffect: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
"org.typelevel" %%% "cats-effect" % Versions.catsEffect
))

val catsEffectLaws: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
"org.typelevel" %%% "cats-effect-laws" % Versions.catsEffect
))

val simulacrumAndMachinist: Seq[Setting[_]] = Def.settings(libraryDependencies ++= Seq(
Expand Down
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ addSbtPlugin("com.47deg" % "sbt-microsites" % "0.7.23")
addSbtPlugin("com.dwijnand" % "sbt-travisci" % "1.1.3")
addSbtPlugin("org.lyranthe.sbt" % "partial-unification" % "1.1.2")
addSbtPlugin("com.lucidchart" % "sbt-scalafmt-coursier" % "1.15")
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.4")
8 changes: 8 additions & 0 deletions special/src/main/scala/cats/mtl/special/Newtype2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package cats.mtl.special

private[special] trait Newtype2 { self =>
private[special] type Base
private[special] trait Tag extends Any
type Type[A, B] <: Base with Tag
}

144 changes: 144 additions & 0 deletions special/src/main/scala/cats/mtl/special/StateIO.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package cats.mtl.special

import cats.Monad
import cats.effect.concurrent.Ref
import cats.effect._
import cats.mtl.MonadState


/* Performant counterpart to StateT[IO, S, A] */
object StateIO extends StateIOInstances with StateIOImpl with Newtype2 {

private[cats] def wrap[S, A](s: IO[A]): Type[S, A] =
s.asInstanceOf[Type[S, A]]

private[cats] def unwrap[S, A](e: Type[S, A]): IO[A] =
e.asInstanceOf[IO[A]]

@inline
def liftF[S, A](fa: IO[A]): StateIO[S, A] = wrap(fa)

def get[S]: StateIO[S, S] = liftF(refInstance[S].get)

def modify[S](f: S => S): StateIO[S, Unit] = liftF(refInstance[S].update(f))

def set[S](s: S): StateIO[S, Unit] = liftF(refInstance[S].set(s))

def pure[S, A](a: A): StateIO[S, A] = liftF(IO.pure(a))

def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] =
liftF(refInstance[S].get.flatMap(s => f(s).flatMap {
case (s, a) => refInstance[S].set(s).map(_ => a)
}))

implicit def stateIOOps[S, A](v: StateIO[S, A]): StateIOOps[S, A] =
new StateIOOps[S, A](v)
}

private[special] class StateIOOps[S, A](val sp: StateIO[S, A]) extends AnyVal {

def unsafeRunAsync(s: S)(f: Either[Throwable, (S, A)] => Unit): Unit =
StateIO.refInstance[S].set(s)
.flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.unsafeRunAsync(f)

def unsafeRunSync(s: S): (S, A) =
StateIO.refInstance[S].set(s)
.flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.unsafeRunSync()

def unsafeRunSyncA(s: S): A =
StateIO.refInstance[S].set(s).flatMap(_ => StateIO.unwrap(sp)).unsafeRunSync()

def unsafeRunSyncS(s: S): S =
StateIO.refInstance[S].set(s)
.flatMap(_ => StateIO.unwrap(sp).flatMap(_ => StateIO.refInstance[S].get))
.unsafeRunSync()
}


private[special] abstract class StateIOAsync[S] extends Async[StateIO[S, ?]] {

def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] =
StateIO.wrap(StateIO.unwrap(fa).flatMap(a => StateIO.unwrap(f(a))))

def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] =
StateIO.liftF(IO.suspend(StateIO.unwrap(thunk)))

def bracketCase[A, B](acquire: StateIO[S, A])
(use: A => StateIO[S, B])
(release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] =
StateIO.liftF(StateIO.unwrap(acquire)
.bracketCase(a => StateIO.unwrap(use(a)))((a, ec) => StateIO.unwrap(release(a, ec))))

def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] =
StateIO.liftF(IO.async(k))

def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] =
StateIO.liftF(IO.asyncF(cb => StateIO.unwrap(k(cb))))

override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] =
StateIO.liftF(StateIO.unwrap(fa).map(f))

def pure[A](x: A): StateIO[S, A] = StateIO.pure(x)

def raiseError[A](e: Throwable): StateIO[S, A] = StateIO.liftF(IO.raiseError(e))

def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] =
StateIO.liftF(StateIO.unwrap(fa).handleErrorWith(e => StateIO.unwrap(f(e))))

def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] =
StateIO.liftF(Monad[IO].tailRecM(a)(a => StateIO.unwrap(f(a))))
}

private[special] sealed abstract class StateIOInstances extends StateIOInstances0 {
implicit def stateIOMonadState[S]: MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] {
val monad: Monad[StateIO[S, ?]] = stateIOAsync[S]

def get: StateIO[S, S] = StateIO.get[S]

def modify(f: S => S): StateIO[S, Unit] = StateIO.modify(f)

def set(s: S): StateIO[S, Unit] = StateIO.set(s)

def inspect[A](f: S => A): StateIO[S, A] = monad.map(StateIO.get)(f)
}

implicit def stateIOConcurrent[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] =
new StateIOAsync[S] with Concurrent[StateIO[S, ?]] {

type Fiber[A] = cats.effect.Fiber[StateIO[S, ?], A]

override def cancelable[A](k: (Either[Throwable, A] => Unit) => CancelToken[StateIO[S, ?]]): StateIO[S, A] =
StateIO.liftF(IO.cancelable[A](cb => StateIO.unwrap(k(cb))))

def start[A](fa: StateIO[S, A]): StateIO[S, Fiber[A]] =
StateIO.liftF[S, Fiber[A]](StateIO.unwrap(fa).start.map(fiberT))


def racePair[A, B](fa: StateIO[S, A], fb: StateIO[S, B]): StateIO[S, Either[(A, Fiber[B]), (Fiber[A], B)]] =
StateIO.liftF(IO.racePair(StateIO.unwrap(fa), StateIO.unwrap(fb)).map {
case Left((a, fib)) => Left((a, fiberT(fib)))
case Right((fib, b)) => Right((fiberT(fib), b))
})

protected def fiberT[A](fiber: cats.effect.Fiber[IO, A]): Fiber[A] =
Fiber(StateIO.liftF(fiber.join), StateIO.liftF(fiber.cancel))
}
}

private[special] sealed abstract class StateIOInstances0 {
implicit def stateIOAsync[S]: Async[StateIO[S, ?]] =
new StateIOAsync[S]{}
}


private[special] sealed trait StateIOImpl {

/* There be dragons */
// scalastyle:off null
private val ref: Ref[IO, Any] = Ref.unsafe[IO, Any](null)
// scalastyle:on null

private[special] def refInstance[S]: Ref[IO, S] = ref.asInstanceOf[Ref[IO, S]]
}
5 changes: 5 additions & 0 deletions special/src/main/scala/cats/mtl/special/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package cats.mtl

package object special {
type StateIO[S, A] = StateIO.Type[S, A]
}
49 changes: 49 additions & 0 deletions tests/src/test/scala/cats/mtl/tests/StateIOTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package cats
package mtl
package tests


import cats.implicits._
import cats.laws.discipline.arbitrary._
import cats.laws.discipline._
import cats.effect.IO
import cats.effect.concurrent.Deferred
import cats.effect.laws.discipline.arbitrary._
import cats.effect.laws.util.TestInstances._
import cats.effect.laws.util.TestContext
import cats.laws.discipline.eq._
import cats.mtl.laws.discipline.MonadStateTests
import cats.mtl.special.StateIO
import org.scalacheck._

class StateIOTests extends BaseSuite {

implicit val ec = TestContext()

implicit def arbStateIO[S: Arbitrary, A: Arbitrary: Cogen]: Arbitrary[StateIO[S, A]] =
Arbitrary(Gen.oneOf(
catsEffectLawsArbitraryForIO[A].arbitrary.map(StateIO.liftF[S, A]),
implicitly[Arbitrary[A]].arbitrary.map(a => StateIO.get[S].as(a)),
implicitly[Arbitrary[A]].arbitrary.flatMap(a =>
implicitly[Arbitrary[S]].arbitrary.map(s => StateIO.set(s).as(a))),
implicitly[Arbitrary[A]].arbitrary.flatMap(a =>
implicitly[Arbitrary[S]].arbitrary.map(s => StateIO.modify[S](_ => s).as(a)))
))


implicit def eqStateIO[S: Arbitrary: Eq, A: Eq](implicit ctx: TestContext, e: Eq[IO[(S, A)]]): Eq[StateIO[S, A]] =
Eq.by(sio => { (s: S) =>
Deferred.uncancelable[IO, (S, A)].flatMap(deferred =>
IO(sio.unsafeRunAsync(s) {
case Right(sa) => deferred.complete(sa).unsafeRunSync()
case Left(t) => throw t
}).flatMap(_ => deferred.get))
})

checkAll("StateIO[String, Int]",
MonadStateTests[StateIO[String, ?], String]
.monadState[Int])
checkAll("MonadState[StateIO[String, ?]]",
SerializableTests.serializable(MonadState[StateIO[String, ?], String]))

}