From 14e16950c9b2fdfac18ca191ddeedfa4624d1673 Mon Sep 17 00:00:00 2001 From: "David R. Bild" Date: Mon, 14 Aug 2017 10:46:36 -0500 Subject: [PATCH 1/4] Limit runtime of monad loop tests --- tests/src/test/scala/cats/tests/MonadTest.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/src/test/scala/cats/tests/MonadTest.scala b/tests/src/test/scala/cats/tests/MonadTest.scala index 063f389f9a..75c9ccdc55 100644 --- a/tests/src/test/scala/cats/tests/MonadTest.scala +++ b/tests/src/test/scala/cats/tests/MonadTest.scala @@ -7,18 +7,20 @@ import org.scalacheck.Gen class MonadTest extends CatsSuite { implicit val testInstance: Monad[StateT[Id, Int, ?]] = StateT.catsDataMonadForStateT[Id, Int] + val smallPosInt = Gen.choose(1, 5000) + val increment: StateT[Id, Int, Unit] = StateT.modify(_ + 1) val incrementAndGet: StateT[Id, Int, Int] = increment >> StateT.get test("whileM_") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = increment.whileM_(StateT.inspect(i => !(i >= max))).run(0) result should ===(Math.max(0, max)) } } test("whileM") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, aggregation) = incrementAndGet.whileM[Vector](StateT.inspect(i => !(i >= max))).run(0) result should ===(Math.max(0, max)) aggregation should === ( if(max > 0) (1 to max).toVector else Vector.empty ) @@ -26,14 +28,14 @@ class MonadTest extends CatsSuite { } test("untilM_") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = increment.untilM_(StateT.inspect(_ >= max)).run(-1) result should ===(max) } } test("untilM") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, aggregation) = incrementAndGet.untilM[Vector](StateT.inspect(_ >= max)).run(-1) result should ===(max) aggregation should === ((0 to max).toVector) @@ -51,7 +53,7 @@ class MonadTest extends CatsSuite { } test("iterateWhile") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = incrementAndGet.iterateWhile(_ < max).run(-1) result should ===(Math.max(0, max)) } @@ -63,7 +65,7 @@ class MonadTest extends CatsSuite { } test("iterateUntil") { - forAll(Gen.posNum[Int]) { (max: Int) => + forAll(smallPosInt) { (max: Int) => val (result, _) = incrementAndGet.iterateUntil(_ == max).run(-1) result should ===(Math.max(0, max)) } From 9a31560b765cead5e29a2fc4e6c75c4d5bd9181c Mon Sep 17 00:00:00 2001 From: "David R. Bild" Date: Tue, 15 Aug 2017 10:05:00 -0500 Subject: [PATCH 2/4] Do not use Either syntax in monad loops --- core/src/main/scala/cats/Monad.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index 84e688d8eb..3b65ef763c 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -1,7 +1,7 @@ package cats import simulacrum.typeclass -import syntax.either._ + /** * Monad. * @@ -30,7 +30,7 @@ import syntax.either._ Left(G.combineK(xs, G.pure(bv))) } }, - ifFalse = pure(xs.asRight[G[A]]) + ifFalse = pure(Right(xs)) )) } @@ -80,8 +80,8 @@ import syntax.either._ flatMap(f) { i => tailRecM(i) { a => if (p(a)) - map(f)(_.asLeft[A]) - else pure(a.asRight[A]) + map(f)(Left(_)) + else pure(Right(a)) } } } @@ -94,8 +94,8 @@ import syntax.either._ flatMap(f) { i => tailRecM(i) { a => if (p(a)) - pure(a.asRight[A]) - else map(f)(_.asLeft[A]) + pure(Right(a)) + else map(f)(Left(_)) } } } From 3f7dd4ddc49fa6f1fbcc0a9936589418c8a3e935 Mon Sep 17 00:00:00 2001 From: "David R. Bild" Date: Wed, 9 Aug 2017 12:59:53 -0500 Subject: [PATCH 3/4] Add iterateWhileM and iterateUntilM --- core/src/main/scala/cats/Monad.scala | 19 +++++++++++++++ core/src/main/scala/cats/syntax/monad.scala | 16 +++++++++++++ .../src/test/scala/cats/tests/MonadTest.scala | 24 +++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index 3b65ef763c..f3f364a2a0 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -100,4 +100,23 @@ import simulacrum.typeclass } } + /** + * Apply a monadic function iteratively until its result fails + * to satisfy the given predicate and return that result. + */ + def iterateWhileM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] = + tailRecM(init) { a => + if (p(a)) + map(f(a))(Left(_)) + else + pure(Right(a)) + } + + /** + * Apply a monadic function iteratively until its result satisfies + * the given predicate and return that result. + */ + def iterateUntilM[A](init: A)(f: A => F[A])(p: A => Boolean): F[A] = + iterateWhileM(init)(f)(!p(_)) + } diff --git a/core/src/main/scala/cats/syntax/monad.scala b/core/src/main/scala/cats/syntax/monad.scala index 2772c5a8ed..efdd15235f 100644 --- a/core/src/main/scala/cats/syntax/monad.scala +++ b/core/src/main/scala/cats/syntax/monad.scala @@ -3,6 +3,9 @@ package syntax trait MonadSyntax { implicit final def catsSyntaxMonad[F[_], A](fa: F[A]): MonadOps[F, A] = new MonadOps(fa) + + implicit final def catsSyntaxMonadIdOps[A](a: A): MonadIdOps[A] = + new MonadIdOps[A](a) } final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal { @@ -13,3 +16,16 @@ final class MonadOps[F[_], A](val fa: F[A]) extends AnyVal { def iterateWhile(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhile(fa)(p) def iterateUntil(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntil(fa)(p) } + +final class MonadIdOps[A](val a: A) extends AnyVal { + + /** + * Iterative application of `f` while `p` holds. + */ + def iterateWhileM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateWhileM(a)(f)(p) + + /** + * Iterative application of `f` until `p` holds. + */ + def iterateUntilM[F[_]](f: A => F[A])(p: A => Boolean)(implicit M: Monad[F]): F[A] = M.iterateUntilM(a)(f)(p) +} diff --git a/tests/src/test/scala/cats/tests/MonadTest.scala b/tests/src/test/scala/cats/tests/MonadTest.scala index 75c9ccdc55..d38287556d 100644 --- a/tests/src/test/scala/cats/tests/MonadTest.scala +++ b/tests/src/test/scala/cats/tests/MonadTest.scala @@ -76,4 +76,28 @@ class MonadTest extends CatsSuite { result should ===(50000) } + test("iterateWhileM") { + forAll(smallPosInt) { (max: Int) => + val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < max).run(0) + sum should ===(n * (n + 1) / 2) + } + } + + test("iterateWhileM is stack safe") { + val (n, sum) = 0.iterateWhileM(s => incrementAndGet map (_ + s))(_ < 50000000).run(0) + sum should ===(n * (n + 1) / 2) + } + + test("iterateUntilM") { + forAll(smallPosInt) { (max: Int) => + val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > max).run(0) + sum should ===(n * (n + 1) / 2) + } + } + + test("iterateUntilM is stack safe") { + val (n, sum) = 0.iterateUntilM(s => incrementAndGet map (_ + s))(_ > 50000000).run(0) + sum should ===(n * (n + 1) / 2) + } + } From 60d8e5af0c1c6b0677d2892869894a4b804be445 Mon Sep 17 00:00:00 2001 From: "David R. Bild" Date: Wed, 9 Aug 2017 13:00:41 -0500 Subject: [PATCH 4/4] Implement iterateWhile in terms of iterateWhileM --- core/src/main/scala/cats/Monad.scala | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/cats/Monad.scala b/core/src/main/scala/cats/Monad.scala index f3f364a2a0..fa4d05079c 100644 --- a/core/src/main/scala/cats/Monad.scala +++ b/core/src/main/scala/cats/Monad.scala @@ -76,29 +76,19 @@ import simulacrum.typeclass * Execute an action repeatedly until its result fails to satisfy the given predicate * and return that result, discarding all others. */ - def iterateWhile[A](f: F[A])(p: A => Boolean): F[A] = { + def iterateWhile[A](f: F[A])(p: A => Boolean): F[A] = flatMap(f) { i => - tailRecM(i) { a => - if (p(a)) - map(f)(Left(_)) - else pure(Right(a)) - } + iterateWhileM(i)(_ => f)(p) } - } /** * Execute an action repeatedly until its result satisfies the given predicate * and return that result, discarding all others. */ - def iterateUntil[A](f: F[A])(p: A => Boolean): F[A] = { + def iterateUntil[A](f: F[A])(p: A => Boolean): F[A] = flatMap(f) { i => - tailRecM(i) { a => - if (p(a)) - pure(Right(a)) - else map(f)(Left(_)) - } + iterateUntilM(i)(_ => f)(p) } - } /** * Apply a monadic function iteratively until its result fails