From 09061e2661e216194aa82fd4c5991d86de5bd128 Mon Sep 17 00:00:00 2001 From: "slav.babanin" Date: Wed, 20 Mar 2024 18:31:24 -0700 Subject: [PATCH] Override collection/database with infinite timeout. JAVA-5384 --- .../client/internal/TimeoutHelper.java | 11 +++--- .../client/internal/TimeoutHelperTest.java | 39 ++++++++++++------- .../client/internal/TimeoutHelper.java | 11 +++--- .../client/internal/TimeoutHelperTest.java | 27 ++++++++----- 4 files changed, 51 insertions(+), 37 deletions(-) diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java index f2039fdf2f9..5b67890c53a 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/TimeoutHelper.java @@ -36,6 +36,9 @@ private TimeoutHelper() { } public static long getRemainingMs(final String message, final Timeout timeout) { + if (timeout.isInfinite()) { + return 0; + } long remainingMs = timeout.remaining(MILLISECONDS); if (remainingMs <= 0) { throw new MongoOperationTimeoutException(message); @@ -45,7 +48,7 @@ public static long getRemainingMs(final String message, final Timeout timeout) { public static MongoCollection collectionWithTimeout(final MongoCollection collection, @Nullable final Timeout timeout) { - if (shouldOverrideTimeout(timeout)) { + if (timeout != null) { long remainingMs = getRemainingMs(DEFAULT_TIMEOUT_MESSAGE, timeout); return collection.withTimeout(remainingMs, MILLISECONDS); } @@ -75,7 +78,7 @@ public static MongoDatabase databaseWithTimeout(final MongoDatabase database, public static MongoDatabase databaseWithTimeout(final MongoDatabase database, final String message, @Nullable final Timeout timeout) { - if (shouldOverrideTimeout(timeout)) { + if (timeout != null) { long remainingMs = getRemainingMs(message, timeout); return database.withTimeout(remainingMs, MILLISECONDS); } @@ -102,8 +105,4 @@ public static Mono databaseWithTimeoutDeferred(final MongoDatabas @Nullable final Timeout timeout) { return Mono.defer(() -> databaseWithTimeoutMono(database, message, timeout)); } - - private static boolean shouldOverrideTimeout(@Nullable final Timeout timeout) { - return timeout != null && !timeout.isInfinite(); - } } diff --git a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/TimeoutHelperTest.java b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/TimeoutHelperTest.java index 4012a52b6e2..87f0ef16c76 100644 --- a/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/TimeoutHelperTest.java +++ b/driver-reactive-streams/src/test/unit/com/mongodb/reactivestreams/client/internal/TimeoutHelperTest.java @@ -32,7 +32,6 @@ import static com.mongodb.reactivestreams.client.internal.TimeoutHelper.databaseWithTimeout; import static com.mongodb.reactivestreams.client.internal.TimeoutHelper.databaseWithTimeoutDeferred; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; @@ -81,7 +80,10 @@ void shouldNotSetRemainingTimeoutDatabaseWhenTimeoutIsNull() { @Test void shouldNotSetRemainingTimeoutOnCollectionWhenTimeoutIsInfinite() { //given - MongoCollection collection = mock(MongoCollection.class); + MongoCollection collectionWithTimeout = mock(MongoCollection.class); + MongoCollection collection = mock(MongoCollection.class, mongoCollection -> { + when(mongoCollection.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(collectionWithTimeout); + }); //when MongoCollection result = collectionWithTimeout(collection, Timeout.infinite()); @@ -89,23 +91,30 @@ void shouldNotSetRemainingTimeoutOnCollectionWhenTimeoutIsInfinite() { MongoCollection monoResultDeferred = collectionWithTimeoutDeferred(collection, Timeout.infinite()).block(); //then - assertEquals(collection, result); - assertEquals(collection, monoResult); - assertEquals(collection, monoResultDeferred); + assertEquals(collectionWithTimeout, result); + assertEquals(collectionWithTimeout, monoResult); + assertEquals(collectionWithTimeout, monoResultDeferred); + verify(collection, times(3)) + .withTimeout(0L, TimeUnit.MILLISECONDS); } @Test void shouldNotSetRemainingTimeoutOnDatabaseWhenTimeoutIsInfinite() { //given - MongoDatabase database = mock(MongoDatabase.class); + MongoDatabase databaseWithTimeout = mock(MongoDatabase.class); + MongoDatabase database = mock(MongoDatabase.class, mongoDatabase -> { + when(mongoDatabase.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(databaseWithTimeout); + }); //when MongoDatabase result = databaseWithTimeout(database, TIMEOUT_ERROR_MESSAGE, Timeout.infinite()); MongoDatabase monoResultDeferred = databaseWithTimeoutDeferred(database, TIMEOUT_ERROR_MESSAGE, Timeout.infinite()).block(); //then - assertEquals(database, result); - assertEquals(database, monoResultDeferred); + assertEquals(databaseWithTimeout, result); + assertEquals(databaseWithTimeout, monoResultDeferred); + verify(database, times(2)) + .withTimeout(0L, TimeUnit.MILLISECONDS); } @Test @@ -113,7 +122,7 @@ void shouldSetRemainingTimeoutOnCollectionWhenTimeout() { //given MongoCollection collectionWithTimeout = mock(MongoCollection.class); MongoCollection collection = mock(MongoCollection.class, mongoCollection -> { - when(mongoCollection.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(mongoCollection); + when(mongoCollection.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(collectionWithTimeout); }); Timeout timeout = Timeout.expiresIn(1, TimeUnit.DAYS); @@ -125,9 +134,9 @@ void shouldSetRemainingTimeoutOnCollectionWhenTimeout() { //then verify(collection, times(3)) .withTimeout(longThat(remaining -> remaining > 0), eq(TimeUnit.MILLISECONDS)); - assertNotEquals(collectionWithTimeout, result); - assertNotEquals(collectionWithTimeout, monoResult); - assertNotEquals(collectionWithTimeout, monoResultDeferred); + assertEquals(collectionWithTimeout, result); + assertEquals(collectionWithTimeout, monoResult); + assertEquals(collectionWithTimeout, monoResultDeferred); } @Test @@ -135,7 +144,7 @@ void shouldSetRemainingTimeoutOnDatabaseWhenTimeout() { //given MongoDatabase databaseWithTimeout = mock(MongoDatabase.class); MongoDatabase database = mock(MongoDatabase.class, mongoDatabase -> { - when(mongoDatabase.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(mongoDatabase); + when(mongoDatabase.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(databaseWithTimeout); }); Timeout timeout = Timeout.expiresIn(1, TimeUnit.DAYS); @@ -146,8 +155,8 @@ void shouldSetRemainingTimeoutOnDatabaseWhenTimeout() { //then verify(database, times(2)) .withTimeout(longThat(remaining -> remaining > 0), eq(TimeUnit.MILLISECONDS)); - assertNotEquals(databaseWithTimeout, result); - assertNotEquals(databaseWithTimeout, monoResultDeferred); + assertEquals(databaseWithTimeout, result); + assertEquals(databaseWithTimeout, monoResultDeferred); } @Test diff --git a/driver-sync/src/main/com/mongodb/client/internal/TimeoutHelper.java b/driver-sync/src/main/com/mongodb/client/internal/TimeoutHelper.java index b1822c64fdd..6dba03cc0e2 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/TimeoutHelper.java +++ b/driver-sync/src/main/com/mongodb/client/internal/TimeoutHelper.java @@ -37,7 +37,7 @@ private TimeoutHelper() { public static MongoCollection collectionWithTimeout(final MongoCollection collection, final String message, @Nullable final Timeout timeout) { - if (shouldOverrideTimeout(timeout)) { + if (timeout != null) { long remainingMs = getRemainingMs(timeout, message); return collection.withTimeout(remainingMs, MILLISECONDS); } @@ -52,7 +52,7 @@ public static MongoCollection collectionWithTimeout(final MongoCollection public static MongoDatabase databaseWithTimeout(final MongoDatabase database, final String message, @Nullable final Timeout timeout) { - if (shouldOverrideTimeout(timeout)) { + if (timeout != null) { long remainingMs = getRemainingMs(timeout, message); return database.withTimeout(remainingMs, MILLISECONDS); } @@ -65,14 +65,13 @@ public static MongoDatabase databaseWithTimeout(final MongoDatabase database, } private static long getRemainingMs(final Timeout timeout, final String message) { + if (timeout.isInfinite()) { + return 0; + } long remainingMs = timeout.remaining(MILLISECONDS); if (remainingMs <= 0) { throw TimeoutContext.createMongoTimeoutException(message); } return remainingMs; } - - private static boolean shouldOverrideTimeout(@Nullable final Timeout timeout) { - return timeout != null && !timeout.isInfinite(); - } } diff --git a/driver-sync/src/test/unit/com/mongodb/client/internal/TimeoutHelperTest.java b/driver-sync/src/test/unit/com/mongodb/client/internal/TimeoutHelperTest.java index f26e130e63d..b1fbceeda3c 100644 --- a/driver-sync/src/test/unit/com/mongodb/client/internal/TimeoutHelperTest.java +++ b/driver-sync/src/test/unit/com/mongodb/client/internal/TimeoutHelperTest.java @@ -29,7 +29,6 @@ import static com.mongodb.client.internal.TimeoutHelper.databaseWithTimeout; import static com.mongodb.internal.mockito.MongoMockito.mock; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; @@ -68,27 +67,35 @@ void shouldNotSetRemainingTimeoutDatabaseWhenTimeoutIsNull() { } @Test - void shouldNotSetRemainingTimeoutOnCollectionWhenTimeoutIsInfinite() { + void shouldSetRemainingTimeoutOnCollectionWhenTimeoutIsInfinite() { //given - MongoCollection collection = mock(MongoCollection.class); + MongoCollection collectionWithTimeout = mock(MongoCollection.class); + MongoCollection collection = mock(MongoCollection.class, mongoCollection -> { + when(mongoCollection.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(collectionWithTimeout); + }); //when MongoCollection result = collectionWithTimeout(collection, TIMEOUT_ERROR_MESSAGE, Timeout.infinite()); //then - assertEquals(collection, result); + assertEquals(collectionWithTimeout, result); + verify(collection).withTimeout(0L, TimeUnit.MILLISECONDS); } @Test void shouldNotSetRemainingTimeoutOnDatabaseWhenTimeoutIsInfinite() { //given - MongoDatabase database = mock(MongoDatabase.class); + MongoDatabase databaseWithTimeout = mock(MongoDatabase.class); + MongoDatabase database = mock(MongoDatabase.class, mongoDatabase -> { + when(mongoDatabase.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(databaseWithTimeout); + }); //when MongoDatabase result = databaseWithTimeout(database, TIMEOUT_ERROR_MESSAGE, Timeout.infinite()); //then - assertEquals(database, result); + assertEquals(databaseWithTimeout, result); + verify(database).withTimeout(0L, TimeUnit.MILLISECONDS); } @Test @@ -96,7 +103,7 @@ void shouldSetRemainingTimeoutOnCollectionWhenTimeout() { //given MongoCollection collectionWithTimeout = mock(MongoCollection.class); MongoCollection collection = mock(MongoCollection.class, mongoCollection -> { - when(mongoCollection.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(mongoCollection); + when(mongoCollection.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(collectionWithTimeout); }); Timeout timeout = Timeout.expiresIn(1, TimeUnit.DAYS); @@ -105,7 +112,7 @@ void shouldSetRemainingTimeoutOnCollectionWhenTimeout() { //then verify(collection).withTimeout(longThat(remaining -> remaining > 0), eq(TimeUnit.MILLISECONDS)); - assertNotEquals(collectionWithTimeout, result); + assertEquals(collectionWithTimeout, result); } @Test @@ -113,7 +120,7 @@ void shouldSetRemainingTimeoutOnDatabaseWhenTimeout() { //given MongoDatabase databaseWithTimeout = mock(MongoDatabase.class); MongoDatabase database = mock(MongoDatabase.class, mongoDatabase -> { - when(mongoDatabase.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(mongoDatabase); + when(mongoDatabase.withTimeout(anyLong(), eq(TimeUnit.MILLISECONDS))).thenReturn(databaseWithTimeout); }); Timeout timeout = Timeout.expiresIn(1, TimeUnit.DAYS); @@ -122,7 +129,7 @@ void shouldSetRemainingTimeoutOnDatabaseWhenTimeout() { //then verify(database).withTimeout(longThat(remaining -> remaining > 0), eq(TimeUnit.MILLISECONDS)); - assertNotEquals(databaseWithTimeout, result); + assertEquals(databaseWithTimeout, result); } @Test