diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index 84e688d8eb..fa4d05079c 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -1,7 +1,7 @@ package cats import simulacrum.typeclass -import syntax.either._ + /** * Monad. * @@ -30,7 +30,7 @@ import syntax.either._ Left(G.combineK(xs, G.pure(bv))) } }, - ifFalse = pure(xs.asRight[G[A]]) + ifFalse = pure(Right(xs)) )) } @@ -76,28 +76,37 @@ import syntax.either._ * Execute an action repeatedly until its result fails to satisfy the given predicate * and return that result, discarding all others. */ - def iterateWhile[A](f: F[A])(p: A => Boolean): F[A] = { + def iterateWhile[A](f: F[A])(p: A => Boolean): F[A] = flatMap(f) { i => - tailRecM(i) { a => - if (p(a)) - map(f)(_.asLeft[A]) - else pure(a.asRight[A]) - } + iterateWhileM(i)(_ => f)(p) } - } /** * Execute an action repeatedly until its result satisfies the given predicate * and return that result, discarding all others. */ - def iterateUntil[A](f: F[A])(p: A => Boolean): F[A] = { + def iterateUntil[A](f: F[A])(p: A => Boolean): F[A] = flatMap(f) { i => - tailRecM(i) { a => - if (p(a)) - pure(a.asRight[A]) - else map(f)(_.asLeft[A]) - } + iterateUntilM(i)(_ => f)(p) } - } + + /** + * Apply a monadic function iteratively until its result fails + * to satisfy the given predicate and return that result. + */ + def iterateWhileM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] = + tailRecM(init) { a => + if (p(a)) + map(f(a))(Left(_)) + else + pure(Right(a)) + } + + /** + * Apply a monadic function iteratively until its result satisfies + * the given predicate and return that result. + */ + def iterateUntilM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] = + iterateWhileM(init)(f)(!p(_)) } diff --git a/core/src/main/scala/cats/syntax/monad.scala b/core/src/main/scala/cats/syntax/monad.scala index 2772c5a8ed..efdd15235f 100644 --- a/core/src/main/scala/cats/syntax/monad.scala +++ b/core/src/main/scala/cats/syntax/monad.scala @@ -3,6 +3,9 @@ package syntax trait MonadSyntax { implicit final def catsSyntaxMonad[F[_], A](fa: F[A]): MonadOps[F, A] = new MonadOps(fa) + + implicit final def catsSyntaxMonadIdOps[A](a: A): MonadIdOps[A] = + new MonadIdOps[A](a) } final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal { @@ -13,3 +16,16 @@ final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal { def iterateWhile(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhile(fa)(p) def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p) } + +final class MonadIdOps[A](val a: A) extends AnyVal { + + /** + * Iterative application of `f` while `p` holds. + */ + def iterateWhileM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhileM(a)(f)(p) + + /** + * Iterative application of `f` until `p` holds. + */ + def iterateUntilM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntilM(a)(f)(p) +} diff --git a/tests/src/test/scala/cats/tests/MonadTest.scala b/tests/src/test/scala/cats/tests/MonadTest.scala index 063f389f9a..d38287556d 100644 --- a/tests/src/test/scala/cats/tests/MonadTest.scala +++ b/tests/src/test/scala/cats/tests/MonadTest.scala @@ -7,18 +7,20 @@ import org.scalacheck.Gen class MonadTest extends CatsSuite { implicit val testInstance: Monad[StateT[Id, Int, ?]] = StateT.catsDataMonadForStateT[Id, Int] + val smallPosInt = Gen.choose(1, 5000) + val increment: StateT[Id, Int, Unit] = StateT.modify(_ + 1) val incrementAndGet: StateT[Id, Int, Int] = increment >> StateT.get test("whileM_") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = increment.whileM_(StateT.inspect(i => !(i >= max))).run(0) result should ===(Math.max(0, max)) } } test("whileM") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, aggregation) = incrementAndGet.whileM[Vector](StateT.inspect(i => !(i >= max))).run(0) result should ===(Math.max(0, max)) aggregation should === ( if(max > 0) (1 to max).toVector else Vector.empty ) @@ -26,14 +28,14 @@ class MonadTest extends CatsSuite { } test("untilM_") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = increment.untilM_(StateT.inspect(_ >= max)).run(-1) result should ===(max) } } test("untilM") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, aggregation) = incrementAndGet.untilM[Vector](StateT.inspect(_ >= max)).run(-1) result should ===(max) aggregation should === ((0 to max).toVector) @@ -51,7 +53,7 @@ class MonadTest extends CatsSuite { } test("iterateWhile") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = incrementAndGet.iterateWhile(_ < max).run(-1) result should ===(Math.max(0, max)) } @@ -63,7 +65,7 @@ class MonadTest extends CatsSuite { } test("iterateUntil") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = incrementAndGet.iterateUntil(_ == max).run(-1) result should ===(Math.max(0, max)) } @@ -74,4 +76,28 @@ class MonadTest extends CatsSuite { result should ===(50000) } + test("iterateWhileM") { + forAll(smallPosInt) { (max: Int) => + val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < max).run(0) + sum should ===(n * (n + 1) / 2) + } + } + + test("iterateWhileM is stack safe") { + val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < 50000000).run(0) + sum should ===(n * (n + 1) / 2) + } + + test("iterateUntilM") { + forAll(smallPosInt) { (max: Int) => + val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > max).run(0) + sum should ===(n * (n + 1) / 2) + } + } + + test("iterateUntilM is stack safe") { + val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > 50000000).run(0) + sum should ===(n * (n + 1) / 2) + } + }