Skip to content

Commit

Permalink
Convert to newtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Luka Jacobowitz committed Oct 12, 2018
1 parent a4d95d2 commit 5c4caaf
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 62 deletions.
1 change: 1 addition & 0 deletions bench/src/main/scala/cats/mtl/bench/StateBench.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package cats.mtl.bench

import cats.implicits._
import cats.data.StateT
import cats.effect.IO
import org.openjdk.jmh.annotations._
Expand Down
8 changes: 8 additions & 0 deletions special/src/main/scala/cats/mtl/special/Newtype2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package cats.mtl.special

private[special] trait Newtype2 { self =>
private[special] type Base
private[special] trait Tag extends Any
type Type[A, B] <: Base with Tag
}

136 changes: 74 additions & 62 deletions special/src/main/scala/cats/mtl/special/StateIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,120 +5,132 @@ import cats.effect.concurrent.Ref
import cats.effect._
import cats.mtl.MonadState


/* Performant counterpart to StateT[IO, S, A] */
case class StateIO[S, A]private (private val value: IO[A]) {
object StateIO extends StateIOInstances with StateIOImpl with Newtype2 {

@inline private[cats] def wrap[S, A](s: IO[A]): Type[S, A] =
s.asInstanceOf[Type[S, A]]

@inline private[cats] def unwrap[S, A](e: Type[S, A]): IO[A] =
e.asInstanceOf[IO[A]]

@inline
def liftF[S, A](fa: IO[A]): StateIO[S, A] = wrap(fa)

def get[S]: StateIO[S, S] = liftF(refInstance[S].get)

def modify[S](f: S => S): StateIO[S, Unit] = liftF(refInstance[S].update(f))

def set[S](s: S): StateIO[S, Unit] = liftF(refInstance[S].set(s))

def pure[S, A](a: A): StateIO[S, A] = liftF(IO.pure(a))

def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] =
liftF(refInstance[S].get.flatMap(s => f(s).flatMap {
case (s, a) => refInstance[S].set(s).map(_ => a)
}))

implicit def stateIOOps[S, A](v: StateIO[S, A]): StateIOOps[S, A] =
new StateIOOps[S, A](v)
}

private[special] class StateIOOps[S, A](val sp: StateIO[S, A]) extends AnyVal {

def unsafeRunAsync(s: S)(f: Either[Throwable, (S, A)] => Unit): Unit =
StateIO.refInstance[S].set(s)
.flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.unsafeRunAsync(f)

def unsafeRunSync(s: S): (S, A) =
StateIO.refInstance[S].set(s)
.flatMap(_ => value.flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.flatMap(_ => StateIO.unwrap(sp).flatMap(a => StateIO.refInstance[S].get.map(s => (s, a))))
.unsafeRunSync()

def unsafeRunSyncA(s: S): A =
StateIO.refInstance[S].set(s).flatMap(_ => value).unsafeRunSync()
StateIO.refInstance[S].set(s).flatMap(_ => StateIO.unwrap(sp)).unsafeRunSync()

def unsafeRunSyncS(s: S): S =
StateIO.refInstance[S].set(s)
.flatMap(_ => value.flatMap(_ => StateIO.refInstance[S].get))
.flatMap(_ => StateIO.unwrap(sp).flatMap(_ => StateIO.refInstance[S].get))
.unsafeRunSync()
}

def flatMap[B](f: A => StateIO[S, B]): StateIO[S, B] =
StateIO(value.flatMap(a => f(a).value))

def map[B](f: A => B): StateIO[S, B] =
StateIO(value.map(f))
private[special] abstract class StateIOAsync[S] extends Async[StateIO[S, ?]] {
@inline
def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] =
StateIO.wrap(StateIO.unwrap(fa).flatMap(a => StateIO.unwrap(f(a))))

def get: StateIO[S, S] =
flatMap(_ => StateIO.get[S])
}
def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] =
StateIO.liftF(IO.suspend(StateIO.unwrap(thunk)))

def bracketCase[A, B](acquire: StateIO[S, A])
(use: A => StateIO[S, B])
(release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] =
StateIO.liftF(StateIO.unwrap(acquire)
.bracketCase(a => StateIO.unwrap(use(a)))((a, ec) => StateIO.unwrap(release(a, ec))))

def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] =
StateIO.liftF(IO.async(k))

object StateIO extends StateIOImpl {
def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] =
StateIO.liftF(IO.asyncF(cb => StateIO.unwrap(k(cb))))

def liftF[S, A](fa: IO[A]): StateIO[S, A] = StateIO(fa)
override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] =
StateIO.liftF(StateIO.unwrap(fa).map(f))

def get[S]: StateIO[S, S] = StateIO(refInstance[S].get)
def pure[A](x: A): StateIO[S, A] = StateIO.pure(x)

def modify[S](f: S => S): StateIO[S, Unit] = StateIO(refInstance[S].update(f))
def raiseError[A](e: Throwable): StateIO[S, A] = StateIO.liftF(IO.raiseError(e))

def set[S](s: S): StateIO[S, Unit] = StateIO(refInstance[S].set(s))
def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] =
StateIO.liftF(StateIO.unwrap(fa).handleErrorWith(e => StateIO.unwrap(f(e))))

def pure[S, A](a: A): StateIO[S, A] = StateIO(IO.pure(a))

def create[S, A](f: S => IO[(S, A)]): StateIO[S, A] =
StateIO(refInstance[S].get.flatMap(s => f(s).flatMap {
case (s, a) => refInstance[S].set(s).map(_ => a)
}))
def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] =
StateIO.liftF(Monad[IO].tailRecM(a)(a => StateIO.unwrap(f(a))))
}

implicit def monadStateStateIO[S](implicit ctx: ContextShift[IO]): MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] {
val monad: Monad[StateIO[S, ?]] = monadErrorStateIO
private[special] sealed abstract class StateIOInstances extends StateIOInstances0 {
implicit def stateIOMonadState[S]: MonadState[StateIO[S, ?], S] = new MonadState[StateIO[S, ?], S] {
val monad: Monad[StateIO[S, ?]] = stateIOAsync[S]

def get: StateIO[S, S] = StateIO.get[S]

def modify(f: S => S): StateIO[S, Unit] = StateIO.modify(f)

def set(s: S): StateIO[S, Unit] = StateIO.set(s)

def inspect[A](f: S => A): StateIO[S, A] = StateIO.get.map(f)
def inspect[A](f: S => A): StateIO[S, A] = monad.map(StateIO.get)(f)
}

implicit def monadErrorStateIO[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] =
new Concurrent[StateIO[S, ?]] {
implicit def stateIOConcurrent[S](implicit ctx: ContextShift[IO]): Concurrent[StateIO[S, ?]] =
new StateIOAsync[S] with Concurrent[StateIO[S, ?]] {

type Fiber[A] = cats.effect.Fiber[StateIO[S, ?], A]

def flatMap[A, B](fa: StateIO[S, A])(f: A => StateIO[S, B]): StateIO[S, B] =
fa.flatMap(f)

override def suspend[A](thunk: => StateIO[S, A]): StateIO[S, A] =
StateIO.liftF(IO.suspend(thunk.value))

def bracketCase[A, B](acquire: StateIO[S, A])
(use: A => StateIO[S, B])
(release: (A, ExitCase[Throwable]) => StateIO[S, Unit]): StateIO[S, B] =
StateIO(acquire.value.bracketCase(a => use(a).value)((a, ec) => release(a, ec).value))

def async[A](k: (Either[Throwable, A] => Unit) => Unit): StateIO[S, A] =
StateIO.liftF(IO.async(k))

def asyncF[A](k: (Either[Throwable, A] => Unit) => StateIO[S, Unit]): StateIO[S, A] =
StateIO.liftF(IO.asyncF(cb => k(cb).value))

override def cancelable[A](k: (Either[Throwable, A] => Unit) => CancelToken[StateIO[S, ?]]): StateIO[S, A] =
StateIO.liftF(IO.cancelable[A](cb => k(cb).value))
StateIO.liftF(IO.cancelable[A](cb => StateIO.unwrap(k(cb))))

def start[A](fa: StateIO[S, A]): StateIO[S, Fiber[A]] =
StateIO.liftF[S, Fiber[A]](fa.value.start.map(fiberT))
StateIO.liftF[S, Fiber[A]](StateIO.unwrap(fa).start.map(fiberT))


def racePair[A, B](fa: StateIO[S, A], fb: StateIO[S, B]): StateIO[S, Either[(A, Fiber[B]), (Fiber[A], B)]] =
StateIO(IO.racePair(fa.value, fb.value).map {
StateIO.liftF(IO.racePair(StateIO.unwrap(fa), StateIO.unwrap(fb)).map {
case Left((a, fib)) => Left((a, fiberT(fib)))
case Right((fib, b)) => Right((fiberT(fib), b))
})


override def map[A, B](fa: StateIO[S, A])(f: A => B): StateIO[S, B] = fa.map(f)

def pure[A](x: A): StateIO[S, A] = StateIO.pure(x)

def raiseError[A](e: Throwable): StateIO[S, A] = StateIO(IO.raiseError(e))

def handleErrorWith[A](fa: StateIO[S, A])(f: Throwable => StateIO[S, A]): StateIO[S, A] =
StateIO(fa.value.handleErrorWith(e => f(e).value))

def tailRecM[A, B](a: A)(f: A => StateIO[S, Either[A, B]]): StateIO[S, B] =
StateIO(Monad[IO].tailRecM(a)(a => f(a).value))

protected def fiberT[A](fiber: cats.effect.Fiber[IO, A]): Fiber[A] =
Fiber(StateIO(fiber.join), StateIO(fiber.cancel))
Fiber(StateIO.liftF(fiber.join), StateIO.liftF(fiber.cancel))
}
}

private[special] sealed abstract class StateIOInstances0 {
implicit def stateIOAsync[S]: Async[StateIO[S, ?]] =
new StateIOAsync[S]{}
}


private[special] sealed trait StateIOImpl {
Expand Down
5 changes: 5 additions & 0 deletions special/src/main/scala/cats/mtl/special/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package cats.mtl

package object special {
type StateIO[S, A] = StateIO.Type[S, A]
}

0 comments on commit 5c4caaf

Please sign in to comment.