diff --git a/src/main/scala/com/github/yruslan/channel/ChannelDecoratorFilter.scala b/src/main/scala/com/github/yruslan/channel/ChannelDecoratorFilter.scala index f8fe576..e979662 100644 --- a/src/main/scala/com/github/yruslan/channel/ChannelDecoratorFilter.scala +++ b/src/main/scala/com/github/yruslan/channel/ChannelDecoratorFilter.scala @@ -53,21 +53,18 @@ class ChannelDecoratorFilter[T](inputChannel: ReadChannel[T], pred: T => Boolean val timeoutMilli = if (timeout.isFinite) timeout.toMillis else 0L val startInstant = Instant.now() - var valueOpt = inputChannel.tryRecv(timeout) - var found = valueOpt.isEmpty || valueOpt.forall(v => pred(v)) - var elapsedTime = java.time.Duration.between(startInstant, now).toMillis + var elapsedTime = 0L - if (found || elapsedTime >= timeoutMilli) { - valueOpt - } else { - while (!found && elapsedTime < timeoutMilli) { - val newTimeout = Duration(timeoutMilli - elapsedTime, MILLISECONDS) - valueOpt = inputChannel.tryRecv(newTimeout) - found = valueOpt.isEmpty || valueOpt.forall(v => pred(v)) - elapsedTime = java.time.Duration.between(startInstant, now).toMillis + while (elapsedTime <= timeoutMilli) { + val newTimeout = Duration(timeoutMilli - elapsedTime, MILLISECONDS) + val valueOpt = inputChannel.tryRecv(newTimeout) + val found = valueOpt.isEmpty || valueOpt.forall(v => pred(v)) + elapsedTime = java.time.Duration.between(startInstant, now).toMillis + if (found) { + return valueOpt } - valueOpt } + None } override def recver(action: T => Unit): Selector = inputChannel.recver(t => if (pred(t)) action(t)) diff --git a/src/test/scala/com/github/yruslan/channel/ChannelFilterSuite.scala b/src/test/scala/com/github/yruslan/channel/ChannelFilterSuite.scala index 748b268..ed33038 100644 --- a/src/test/scala/com/github/yruslan/channel/ChannelFilterSuite.scala +++ b/src/test/scala/com/github/yruslan/channel/ChannelFilterSuite.scala @@ -17,6 +17,7 @@ package com.github.yruslan.channel import org.scalatest.wordspec.AnyWordSpec +import java.time.Instant import java.util.concurrent.Executors import scala.concurrent._ import scala.concurrent.duration.{Duration, MILLISECONDS} @@ -95,7 +96,7 @@ class ChannelFilterSuite extends AnyWordSpec { } "filter input channel on tryRecv(duration)" when { - val timeout = Duration(2, MILLISECONDS) + val timeout = Duration(200, MILLISECONDS) "values either available or not" in { val ch1 = Channel.make[Int](3) @@ -142,6 +143,62 @@ class ChannelFilterSuite extends AnyWordSpec { assert(v1.contains(3)) } + "filter the correct value even with 0 millisecond timeout" in { + val ch1 = Channel.make[Int](2) + + val ch2 = ch1.filter(v => v == 2) + + ch1.send(1) + ch1.send(2) + + val v1 = ch2.tryRecv(Duration.Zero) + ch1.close() + + assert(v1.contains(2)) + } + + "return None if no values match and zero timeout" in { + val ch1 = Channel.make[Int](2) + + val ch2 = ch1.filter(v => v == 3) + + ch1.send(1) + ch1.send(2) + + val v1 = ch2.tryRecv(Duration.Zero) + ch1.close() + + assert(v1.isEmpty) + } + + "return instantly on empty channel and zero timeout" in { + val ch1 = Channel.make[Int](2) + + val ch2 = ch1.filter(v => v == 3) + + val start = Instant.now() + val v1 = ch2.tryRecv(Duration.Zero) + val finish = Instant.now() + + assert(v1.isEmpty) + assert(java.time.Duration.between(start, finish).toMillis <= 10L) + } + + "return None after proper wait for a non-zero timeout" in { + val ch1 = Channel.make[Int](2) + + val ch2 = ch1.filter(v => v == 3) + + ch1.send(1) + ch1.send(2) + + val start = Instant.now() + val v1 = ch2.tryRecv(Duration(10, MILLISECONDS)) + val finish = Instant.now() + + assert(v1.isEmpty) + assert(java.time.Duration.between(start, finish).toMillis >= 10L) + } } "filter input channel on recver()" in {