Skip to content

Unify timeout checks in exception handling methods. #1449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions bson/src/main/org/bson/assertions/Assertions.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ public static <T> T assertNotNull(@Nullable final T value) throws AssertionError
return value;
}

/**
* Throw AssertionError if the condition if false.
*
* @param name the name of the state that is being checked
* @param condition the condition about the parameter to check
* @throws AssertionError if the condition is false
*/
public static void assertTrue(final String name, final boolean condition) {
if (!condition) {
throw new AssertionError("state should be: " + assertNotNull(name));
}
}

/**
* Cast an object to the given class and return it, or throw IllegalArgumentException if it's not assignable to that class.
*
Expand Down
8 changes: 4 additions & 4 deletions driver-core/src/main/com/mongodb/internal/TimeoutContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.Objects;
import java.util.function.LongConsumer;

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.assertions.Assertions.assertNull;
import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;
Expand Down Expand Up @@ -253,9 +252,10 @@ public int getConnectTimeoutMs() {
() -> throwMongoTimeoutException("The operation exceeded the timeout limit.")));
}

public void resetTimeout() {
assertNotNull(timeout);
timeout = startTimeout(timeoutSettings.getTimeoutMS());
public void resetTimeoutIfPresent() {
if (hasTimeoutMS()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment: The method naming makes the assertNotNull seem restrictive. Should it be part of the if statement?

I guess its a bug if this is called if there isnt a timeout set - its just slightly unclear thats the case.

timeout = startTimeout(timeoutSettings.getTimeoutMS());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ public void selectServerAsync(final ServerSelector serverSelector, final Operati
final SingleResultCallback<ServerTuple> callback) {
isTrue("open", !isClosed());

//TODO (CSOT) is it safe to put this before phase.get()? P.S it was after in pre-CSOT state.
Timeout computedServerSelectionTimeout = operationContext.getTimeoutContext().computeServerSelectionTimeout();
ServerSelectionRequest request = new ServerSelectionRequest(
serverSelector, operationContext, computedServerSelectionTimeout, callback);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ public void sendMessageAsync(
stream.writeAsync(byteBuffers, operationContext, c.asHandler());
}, Exception.class, (e, c) -> {
close();
//TODO-m propably should be solved after merge
throwTranslatedWriteException(e, operationContext);
}).finish(errorHandlingCallback(callback, LOGGER));
}
Expand Down Expand Up @@ -771,10 +770,8 @@ private void updateSessionContext(final SessionContext sessionContext, final Res
}

private void throwTranslatedWriteException(final Throwable e, final OperationContext operationContext) {
if (e instanceof MongoSocketWriteTimeoutException) {
operationContext.getTimeoutContext().onExpired(() -> {
throw createMongoTimeoutException(e);
});
if (e instanceof MongoSocketWriteTimeoutException && operationContext.getTimeoutContext().hasTimeoutMS()) {
throw createMongoTimeoutException(e);
}

if (e instanceof MongoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void next(final SingleResultCallback<List<T>> callback) {

@Override
public void close() {
resetTimeout();
timeoutContext.resetTimeoutIfPresent();
if (isClosed.compareAndSet(false, true)) {
try {
nullifyAndCloseWrapped();
Expand Down Expand Up @@ -181,7 +181,7 @@ private interface AsyncBlock {
}

private void resumeableOperation(final AsyncBlock asyncBlock, final SingleResultCallback<List<T>> callback, final boolean tryNext) {
resetTimeout();
timeoutContext.resetTimeoutIfPresent();
SingleResultCallback<List<T>> errHandlingCallback = errorHandlingCallback(callback, LOGGER);
if (isClosed()) {
errHandlingCallback.onResult(null, new MongoException(format("%s called after the cursor was closed.",
Expand Down Expand Up @@ -242,10 +242,4 @@ private void retryOperation(final AsyncBlock asyncBlock, final SingleResultCallb
}
});
}

private void resetTimeout() {
if (timeoutContext.hasTimeoutMS()) {
timeoutContext.resetTimeout();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public List<T> tryNext() {
@Override
public void close() {
if (!closed.getAndSet(true)) {
resetTimeout();
timeoutContext.resetTimeoutIfPresent();
wrapped.close();
binding.release();
}
Expand Down Expand Up @@ -211,7 +211,7 @@ static <T> List<T> convertAndProduceLastId(final List<RawBsonDocument> rawDocume
}

<R> R resumeableOperation(final Function<AggregateResponseBatchCursor<RawBsonDocument>, R> function) {
resetTimeout();
timeoutContext.resetTimeoutIfPresent();
try {
R result = execute(function);
lastOperationTimedOut = false;
Expand Down Expand Up @@ -254,12 +254,6 @@ private boolean hasPreviousNextTimedOut() {
return lastOperationTimedOut && !closed.get();
}

private void resetTimeout() {
if (timeoutContext.hasTimeoutMS()) {
timeoutContext.resetTimeout();
}
}

private static boolean isTimeoutException(final Throwable exception) {
return exception instanceof MongoOperationTimeoutException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ void checkTimeoutModeAndResetTimeoutContextIfIteration() {
}

void resetTimeout() {
if (!closeWithoutTimeoutReset && timeoutContext.hasTimeoutMS()) {
timeoutContext.resetTimeout();
if (!closeWithoutTimeoutReset) {
timeoutContext.resetTimeoutIfPresent();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ protected void setTimeoutContext(@Nullable final TimeoutContext timeoutContext)
}

protected void resetTimeout() {
if (timeoutContext != null && timeoutContext.hasTimeoutMS()) {
timeoutContext.resetTimeout();
if (timeoutContext != null) {
timeoutContext.resetTimeoutIfPresent();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ static Timeout nullAsInfinite(@Nullable final Timeout timeout) {
*/
@NotNull
static Timeout expiresIn(final long duration, final TimeUnit unit, final ZeroSemantics zeroSemantics) {
// TODO (CSOT) confirm that all usages in final PR always supply a non-negative duration
if (duration < 0) {
throw new AssertionError("Timeouts must not be in the past");
} else if (duration == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,10 @@ public static ReadWriteBinding getBinding(final TimeoutSettings timeoutSettings)
return getBinding(getCluster(), ReadPreference.primary(), createNewOperationContext(timeoutSettings));
}

public static ReadWriteBinding getBinding(final OperationContext operationContext) {
return getBinding(getCluster(), ReadPreference.primary(), operationContext);
}

public static ReadWriteBinding getBinding(final ReadPreference readPreference) {
return getBinding(getCluster(), readPreference, OPERATION_CONTEXT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ import com.mongodb.ReadPreference
import com.mongodb.ReadPreferenceHedgeOptions
import com.mongodb.ServerAddress
import com.mongodb.async.FutureResultCallback
import com.mongodb.client.model.CreateCollectionOptions
import com.mongodb.connection.ClusterId
import com.mongodb.connection.ConnectionDescription
import com.mongodb.connection.ConnectionId
import com.mongodb.connection.ServerId
import com.mongodb.internal.TimeoutContext
import com.mongodb.internal.binding.AsyncClusterBinding
import com.mongodb.internal.binding.AsyncConnectionSource
import com.mongodb.internal.binding.AsyncReadBinding
Expand All @@ -53,13 +55,16 @@ import spock.lang.IgnoreIf
import static com.mongodb.ClusterFixture.OPERATION_CONTEXT
import static com.mongodb.ClusterFixture.executeAsync
import static com.mongodb.ClusterFixture.executeSync
import static com.mongodb.ClusterFixture.getAsyncBinding
import static com.mongodb.ClusterFixture.getAsyncCluster
import static com.mongodb.ClusterFixture.getBinding
import static com.mongodb.ClusterFixture.getCluster
import static com.mongodb.ClusterFixture.isSharded
import static com.mongodb.ClusterFixture.serverVersionAtLeast
import static com.mongodb.ClusterFixture.serverVersionLessThan
import static com.mongodb.CursorType.NonTailable
import static com.mongodb.CursorType.Tailable
import static com.mongodb.CursorType.TailableAwait
import static com.mongodb.connection.ServerType.STANDALONE
import static com.mongodb.internal.operation.OperationReadConcernHelper.appendReadConcernToCommand
import static com.mongodb.internal.operation.ServerVersionHelper.MIN_WIRE_VERSION
Expand Down Expand Up @@ -643,30 +648,33 @@ class FindOperationSpecification extends OperationFunctionalSpecification {
}

// sanity check that the server accepts tailable and await data flags
// TODO (CSOT) JAVA-4058
/*
def 'should pass tailable and await data flags through'() {
given:
def (cursorType, maxAwaitTimeMS, maxTimeMSForCursor) = cursorDetails
def (cursorType, long maxAwaitTimeMS, long maxTimeMSForCursor) = cursorDetails
def timeoutSettings = ClusterFixture.TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT.withMaxAwaitTimeMS(maxAwaitTimeMS)
def timeoutContext = Spy(TimeoutContext, constructorArgs: [timeoutSettings])
def operationContext = OPERATION_CONTEXT.withTimeoutContext(timeoutContext)

collectionHelper.create(getCollectionName(), new CreateCollectionOptions().capped(true).sizeInBytes(1000))
def operation = new FindOperation<BsonDocument>(TIMEOUT_SETTINGS_WITH_MAX_TIME, namespace, new BsonDocumentCodec())
def operation = new FindOperation<BsonDocument>(namespace, new BsonDocumentCodec())
.cursorType(cursorType)

when:
def cursor = execute(operation, async)
if (async) {
execute(operation, getBinding(operationContext))
} else {
execute(operation, getAsyncBinding(operationContext))
}

then:
println cursor
// TODO (CSOT) JAVA-4058
cursor.maxTimeMS == maxTimeMSForCursor
timeoutContext.setMaxTimeOverride(maxTimeMSForCursor)

where:
[async, cursorDetails] << [
[true, false],
[[NonTailable, 100, 0], [Tailable, 100, 0], [TailableAwait, 100, 100]]
].combinations()
}
*/

// sanity check that the server accepts the miscallaneous flags
def 'should pass miscallaneous flags through'() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void shouldReturnResultOnNext() {

//then
assertEquals(RESULT_FROM_NEW_CURSOR, next);
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verify(commandBatchCursor, times(1)).next();
verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken();
verifyNoMoreInteractions(commandBatchCursor);
Expand All @@ -98,7 +98,7 @@ void shouldThrowTimeoutExceptionWithoutResumeAttemptOnNext() {
assertThrows(MongoOperationTimeoutException.class, cursor::next);

//then
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verify(commandBatchCursor, times(1)).next();
verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken();
verifyNoMoreInteractions(commandBatchCursor);
Expand All @@ -115,7 +115,7 @@ void shouldPerformResumeAttemptOnNextWhenResumableErrorIsThrown() {

//then
assertEquals(RESULT_FROM_NEW_CURSOR, next);
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verify(commandBatchCursor, times(1)).next();
verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken();
verifyResumeAttemptCalled();
Expand All @@ -136,7 +136,7 @@ void shouldResumeOnlyOnceOnSubsequentCallsAfterTimeoutError() {
assertThrows(MongoOperationTimeoutException.class, cursor::next);

//then
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verify(commandBatchCursor, times(1)).next();
verify(commandBatchCursor, atLeastOnce()).getPostBatchResumeToken();
verifyNoMoreInteractions(commandBatchCursor);
Expand All @@ -148,7 +148,7 @@ void shouldResumeOnlyOnceOnSubsequentCallsAfterTimeoutError() {

//then
assertEquals(Collections.emptyList(), next);
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verify(commandBatchCursor, times(1)).close();
verifyNoMoreInteractions(commandBatchCursor);
verify(changeStreamOperation).setChangeStreamOptionsForResume(resumeToken, maxWireVersion);
Expand All @@ -165,7 +165,7 @@ void shouldResumeOnlyOnceOnSubsequentCallsAfterTimeoutError() {
//then
assertEquals(Collections.emptyList(), next2);
verifyNoInteractions(commandBatchCursor);
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verify(newCommandBatchCursor, times(1)).next();
verify(newCommandBatchCursor, atLeastOnce()).getPostBatchResumeToken();
verifyNoMoreInteractions(newCommandBatchCursor);
Expand All @@ -189,7 +189,7 @@ void shouldPropagateAnyErrorsOccurredInAggregateOperation() {
assertThrows(MongoNotPrimaryException.class, cursor::next);

//then
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verifyResumeAttemptCalled();
verifyNoMoreInteractions(changeStreamOperation);
verifyNoInteractions(newCommandBatchCursor);
Expand Down Expand Up @@ -219,7 +219,7 @@ void shouldResumeAfterTimeoutInAggregateOnNextCall() {

//then
assertEquals(RESULT_FROM_NEW_CURSOR, next);
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();

verifyResumeAttemptCalled();
verify(changeStreamOperation, times(1)).getDecoder();
Expand All @@ -246,7 +246,7 @@ void shouldCloseChangeStreamWhenResumeOperationFailsDueToNonTimeoutError() {
assertThrows(MongoNotPrimaryException.class, cursor::next);

//then
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verifyResumeAttemptCalled();
verifyNoMoreInteractions(changeStreamOperation);
verifyNoInteractions(newCommandBatchCursor);
Expand All @@ -259,7 +259,7 @@ void shouldCloseChangeStreamWhenResumeOperationFailsDueToNonTimeoutError() {

//then
assertEquals(MESSAGE_IF_CLOSED_AS_CURSOR, mongoException.getMessage());
verify(timeoutContext, times(1)).resetTimeout();
verify(timeoutContext, times(1)).resetTimeoutIfPresent();
verifyNoResumeAttemptCalled();
}

Expand Down Expand Up @@ -293,7 +293,7 @@ void setUp() {

timeoutContext = mock(TimeoutContext.class);
when(timeoutContext.hasTimeoutMS()).thenReturn(true);
doNothing().when(timeoutContext).resetTimeout();
doNothing().when(timeoutContext).resetTimeoutIfPresent();

operationContext = mock(OperationContext.class);
when(operationContext.getTimeoutContext()).thenReturn(timeoutContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.bson.assertions.Assertions.assertTrue;

class KeyManagementService implements Closeable {
private static final Logger LOGGER = Loggers.getLogger("client");
Expand All @@ -61,6 +62,7 @@ class KeyManagementService implements Closeable {
private final TlsChannelStreamFactoryFactory tlsChannelStreamFactoryFactory;

KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
assertTrue("timeoutMillis > 0", timeoutMillis > 0);
this.kmsProviderSslContextMap = kmsProviderSslContextMap;
this.tlsChannelStreamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
this.timeoutMillis = timeoutMillis;
Expand Down Expand Up @@ -166,7 +168,6 @@ private OperationContext createOperationContext(@Nullable final Timeout operatio
timeoutSettings = createTimeoutSettings(socketSettings, null);
} else {
timeoutSettings = operationTimeout.call(MILLISECONDS,
// TODO (CSOT) JAVA-5104 correct that cannot be infinite? Possibly a path here from: Timeout operationTimeout = operationContext.getTimeoutContext().getTimeout();
() -> {
throw new AssertionError("operationTimeout cannot be infinite");
},
Expand Down