diff --git a/core/src/main/scala/retries.scala b/core/src/main/scala/retries.scala index 7a0067a..e2b820d 100644 --- a/core/src/main/scala/retries.scala +++ b/core/src/main/scala/retries.scala @@ -57,12 +57,14 @@ trait CountingRetry { promise: () => Future[T], success: Success[T], orElse: Int => Future[T] - )(implicit executor: ExecutionContext) = { + )(implicit executor: ExecutionContext): Future[T] = { val fut = promise() fut.flatMap { res => if (max < 1 || success.predicate(res)) fut - else orElse(max - 1) - } + else orElse(max - 1) + } recoverWith { + case e: Throwable => if (max < 1) fut else orElse(max - 1) + } } } diff --git a/core/src/test/scala/RetrySpec.scala b/core/src/test/scala/RetrySpec.scala index 7626494..e5e349a 100644 --- a/core/src/test/scala/RetrySpec.scala +++ b/core/src/test/scala/RetrySpec.scala @@ -2,7 +2,7 @@ package retry import org.scalatest.FunSpec import scala.annotation.tailrec -import scala.concurrent.{ Future, Await } +import scala.concurrent._ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global @@ -72,4 +72,22 @@ class RetrySpec extends FunSpec { "took more time than expected: %s" format took) } } + + describe("retry.CountingRetry") { + it ("should retry if an exception was thrown") { + import retry.Defaults.timer + implicit val success = new Success[Int](_ == 2) + def fut = () => future { 1 / 0 } + val took = time { + val result = try { + Await.result(retry.Pause(3, 30.millis)(fut), + 90.millis + 20.millis) + } catch { + case e: Throwable => e + } + } + assert(took >= 90.millis === true, + "took less time than expected: %s" format took) + } + } }